package com.amazon.athena.jdbc.authentication;

import com.amazon.athena.jdbc.configuration.ConnectionParameter;
import com.amazon.athena.jdbc.configuration.ConnectionParameters;
import com.amazon.athena.jdbc.support.AuthenticationException;
import com.amazon.athena.jdbc.support.EndpointHelper;
import com.amazon.athena.jdbc.support.ProxyHelper;
import com.amazon.athena.logging.AthenaLogger;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.net.URI;
import java.time.Clock;
import java.time.Instant;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Optional;
import java.util.regex.Pattern;
import java.util.stream.Stream;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.parsers.ParserConfigurationException;
import javax.xml.xpath.XPathConstants;
import javax.xml.xpath.XPathExpressionException;
import javax.xml.xpath.XPathFactory;
import org.apache.commons.codec.binary.Base64;
import org.w3c.dom.Document;
import org.w3c.dom.NodeList;
import org.xml.sax.SAXException;
import software.amazon.awssdk.auth.credentials.AnonymousCredentialsProvider;
import software.amazon.awssdk.auth.credentials.AwsCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
import software.amazon.awssdk.core.internal.useragent.UserAgentConstant;
import software.amazon.awssdk.http.apache.ApacheHttpClient;
import software.amazon.awssdk.http.apache.ProxyConfiguration;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.lakeformation.LakeFormationClient;
import software.amazon.awssdk.services.lakeformation.LakeFormationClientBuilder;
import software.amazon.awssdk.services.lakeformation.model.AssumeDecoratedRoleWithSamlRequest;
import software.amazon.awssdk.services.lakeformation.model.AssumeDecoratedRoleWithSamlResponse;
import software.amazon.awssdk.services.sts.StsClient;
import software.amazon.awssdk.services.sts.StsClientBuilder;
import software.amazon.awssdk.services.sts.model.AssumeRoleWithSamlRequest;
import software.amazon.awssdk.services.sts.model.AssumeRoleWithSamlResponse;
import software.amazon.awssdk.services.sts.model.Credentials;
import software.amazon.awssdk.utils.Pair;

/* loaded from: input_file:com/amazon/athena/jdbc/authentication/SamlCredentialsProvider.class */
/**
 * Base class for AWS credentials providers that authenticate through a SAML identity provider.
 *
 * <p>Subclasses supply a base64-encoded SAML assertion through {@link #getSamlAssertion()}.
 *
 * <p>The values of the assertion's {@code https://aws.amazon.com/SAML/Attributes/Role} attribute
 * are parsed into IAM role ARN / SAML provider ARN pairs. One pair is then selected: the
 * configured preferred role if there is one, otherwise the first pair in assertion order.
 *
 * <p>The assertion is exchanged for temporary credentials in one of two ways:
 * <ul>
 *   <li>STS {@code AssumeRoleWithSAML}, by default;
 *   <li>Lake Formation {@code AssumeDecoratedRoleWithSAML}, when Lake Formation is enabled.
 * </ul>
 *
 * <p>The credentials are cached and fetched again once they are within
 * {@code EXPIRATION_THRESHOLD_SECS} seconds of expiry.
 *
 * <p>NOTE(review): the cached credentials live in plain (non-volatile, unsynchronized) fields.
 * Confirm whether {@link #resolveCredentials()} can be invoked concurrently.
 */
abstract class SamlCredentialsProvider extends IdpCredentialsProvider implements AwsCredentialsProvider {
    private static final AthenaLogger logger = AthenaLogger.of(SamlCredentialsProvider.class);
    private static final String ROLE_PATTERN = "arn:aws[-a-z]*:iam::\\d*:role/\\S+";
    private static final String SAML_PROVIDER_PATTERN = "arn:aws[-a-z]*:iam::\\d*:saml-provider/\\S+";
    // Compiled once instead of per call to String.matches().
    private static final Pattern ROLE_REGEX = Pattern.compile(ROLE_PATTERN);
    private static final Pattern SAML_PROVIDER_REGEX = Pattern.compile(SAML_PROVIDER_PATTERN);
    // Text content of every AttributeValue of the AWS "Role" attribute, regardless of namespace prefix.
    private static final String ROLE_ATTRIBUTE_VALUES_XPATH = "//*[local-name()='Attribute'][@Name='https://aws.amazon.com/SAML/Attributes/Role']/*[local-name()='AttributeValue']/text()";
    // Named character references understood by decodeHtmlCharacterReferences: {reference, replacement}.
    private static final String[][] NAMED_CHARACTER_REFERENCES = {
        {"&amp;", "&"}, {"&apos;", "'"}, {"&quot;", "\""}, {"&lt;", "<"}, {"&gt;", ">"}
    };
    private static final int EXPIRATION_THRESHOLD_SECS = 180;
    private final AssumeRoleWithSamlRequest.Builder assumeRoleWithSamlRequestFactory;
    private final AssumeDecoratedRoleWithSamlRequest.Builder assumeDecoratedRoleWithSamlRequestFactory;
    private final StsClientBuilder stsClientFactory;
    private final LakeFormationClientBuilder lakeFormationClientFactory;
    private final DocumentBuilderFactory documentBuilderFactory;
    private final String preferredRole;
    private final Integer roleSessionDuration;
    private final Region region;
    private final boolean lakeFormationEnabled;
    private final Map<ConnectionParameter<?>, String> parameters;
    private Credentials stsCredentials;
    private AssumeDecoratedRoleWithSamlResponse lakeFormationCredentials;
    protected final Clock clock;

    /**
     * Creates a provider. Any {@code null} builder, factory or clock argument is replaced by a default.
     *
     * @param assumeRoleWithSamlRequestBuilder builder for STS requests, or {@code null} for a new one
     * @param assumeDecoratedRoleWithSamlRequestBuilder builder for Lake Formation requests, or {@code null} for a new one
     * @param stsClientBuilder builder for the STS client, or {@code null} for {@link StsClient#builder()}
     * @param lakeFormationClientBuilder builder for the Lake Formation client, or {@code null} for {@link LakeFormationClient#builder()}
     * @param documentBuilderFactory XML parser factory used for the assertion, or {@code null} for a new instance
     * @param clock clock used for expiry checks, or {@code null} for the system default-zone clock
     * @param preferredRole role ARN to select from the assertion, or {@code null} to use the first role found
     * @param roleSessionDuration value sent as the request's {@code durationSeconds}
     * @param region region configured on the STS / Lake Formation clients
     * @param lakeFormationEnabled {@code true} to obtain credentials from Lake Formation instead of STS
     * @param parameters connection parameters consulted for endpoint overrides and proxy settings
     */
    protected SamlCredentialsProvider(AssumeRoleWithSamlRequest.Builder assumeRoleWithSamlRequestBuilder, AssumeDecoratedRoleWithSamlRequest.Builder assumeDecoratedRoleWithSamlRequestBuilder, StsClientBuilder stsClientBuilder, LakeFormationClientBuilder lakeFormationClientBuilder, DocumentBuilderFactory documentBuilderFactory, Clock clock, String preferredRole, Integer roleSessionDuration, Region region, boolean lakeFormationEnabled, Map<ConnectionParameter<?>, String> parameters) {
        this.assumeRoleWithSamlRequestFactory = assumeRoleWithSamlRequestBuilder == null ? AssumeRoleWithSamlRequest.builder() : assumeRoleWithSamlRequestBuilder;
        this.assumeDecoratedRoleWithSamlRequestFactory = assumeDecoratedRoleWithSamlRequestBuilder == null ? AssumeDecoratedRoleWithSamlRequest.builder() : assumeDecoratedRoleWithSamlRequestBuilder;
        this.stsClientFactory = stsClientBuilder == null ? StsClient.builder() : stsClientBuilder;
        this.lakeFormationClientFactory = lakeFormationClientBuilder == null ? LakeFormationClient.builder() : lakeFormationClientBuilder;
        this.documentBuilderFactory = documentBuilderFactory == null ? DocumentBuilderFactory.newInstance() : documentBuilderFactory;
        this.clock = clock == null ? Clock.systemDefaultZone() : clock;
        this.preferredRole = preferredRole;
        this.roleSessionDuration = roleSessionDuration;
        this.region = region;
        this.lakeFormationEnabled = lakeFormationEnabled;
        this.parameters = parameters;
    }

    /**
     * Returns the base64-encoded SAML assertion to exchange for AWS credentials.
     *
     * @return the assertion, base64-encoded
     */
    protected abstract String getSamlAssertion();

    /**
     * Returns session credentials, obtaining new ones when none are cached or the cached ones expire
     * within {@code EXPIRATION_THRESHOLD_SECS} seconds.
     *
     * @return session credentials from Lake Formation if enabled, otherwise from STS
     * @throws AuthenticationException if the SAML assertion cannot be parsed or contains no usable role
     */
    @Override
    public AwsCredentials resolveCredentials() {
        if (this.lakeFormationEnabled) {
            if (this.lakeFormationCredentials == null || isExpiringSoon(this.lakeFormationCredentials.expiration())) {
                this.lakeFormationCredentials = obtainCredentialsFromLakeFormation(getSamlAssertion());
            }
            return AwsSessionCredentials.create(this.lakeFormationCredentials.accessKeyId(), this.lakeFormationCredentials.secretAccessKey(), this.lakeFormationCredentials.sessionToken());
        }
        if (this.stsCredentials == null || isExpiringSoon(this.stsCredentials.expiration())) {
            this.stsCredentials = obtainCredentialsFromSts(getSamlAssertion());
        }
        return AwsSessionCredentials.create(this.stsCredentials.accessKeyId(), this.stsCredentials.secretAccessKey(), this.stsCredentials.sessionToken());
    }

    /** Returns whether {@code expiration} falls within the refresh threshold from now. */
    private boolean isExpiringSoon(Instant expiration) {
        return expiration.isBefore(this.clock.instant().plusSeconds(EXPIRATION_THRESHOLD_SECS));
    }

    /**
     * Exchanges the assertion for credentials via STS {@code AssumeRoleWithSAML}.
     *
     * <p>The call is unsigned (anonymous credentials provider). Any endpoint override and proxy
     * from the connection parameters are applied to the client.
     */
    private Credentials obtainCredentialsFromSts(String samlAssertion) {
        getStsEndpoint().ifPresent(this.stsClientFactory::endpointOverride);
        // Previously the proxy configuration was looked up and then discarded.
        ProxyHelper.getSyncProxyConfiguration(this.parameters).ifPresent(proxyConfiguration -> this.stsClientFactory.httpClientBuilder(getHttpClientBuilder(proxyConfiguration)));
        Pair<String, String> roleAndPrincipal = extractRoleAndPrincipal(samlAssertion);
        AssumeRoleWithSamlRequest request = this.assumeRoleWithSamlRequestFactory
                .samlAssertion(samlAssertion)
                .roleArn(roleAndPrincipal.left())
                .principalArn(roleAndPrincipal.right())
                .durationSeconds(this.roleSessionDuration)
                .build();
        logger.debug("Obtaining credentials from STS");
        // Never log the assertion itself: it is a bearer credential.
        logger.trace("Sending AssumeRoleWithSaml request: {}", describeMaskedRequest("AssumeRoleWithSamlRequest", request.roleArn(), request.principalArn(), request.durationSeconds()));
        // A client is built per refresh; close it so its HTTP client / connection pool is released.
        try (StsClient stsClient = this.stsClientFactory.region(this.region).credentialsProvider(AnonymousCredentialsProvider.create()).build()) {
            AssumeRoleWithSamlResponse response = stsClient.assumeRoleWithSAML(request);
            logger.info("Obtained credentials from STS");
            return response.credentials();
        }
    }

    /**
     * Exchanges the assertion for credentials via Lake Formation {@code AssumeDecoratedRoleWithSAML}.
     *
     * <p>The call is unsigned (anonymous credentials provider). Any endpoint override and proxy
     * from the connection parameters are applied to the client.
     */
    private AssumeDecoratedRoleWithSamlResponse obtainCredentialsFromLakeFormation(String samlAssertion) {
        getLakeFormationEndpoint().ifPresent(this.lakeFormationClientFactory::endpointOverride);
        // Previously the proxy configuration was looked up and then discarded.
        ProxyHelper.getSyncProxyConfiguration(this.parameters).ifPresent(proxyConfiguration -> this.lakeFormationClientFactory.httpClientBuilder(getHttpClientBuilder(proxyConfiguration)));
        Pair<String, String> roleAndPrincipal = extractRoleAndPrincipal(samlAssertion);
        AssumeDecoratedRoleWithSamlRequest request = this.assumeDecoratedRoleWithSamlRequestFactory
                .samlAssertion(samlAssertion)
                .roleArn(roleAndPrincipal.left())
                .principalArn(roleAndPrincipal.right())
                .durationSeconds(this.roleSessionDuration)
                .build();
        logger.debug("Obtaining credentials from Lake Formation");
        logger.trace("Sending AssumeDecoratedRoleWithSaml request: {}", describeMaskedRequest("AssumeDecoratedRoleWithSamlRequest", request.roleArn(), request.principalArn(), request.durationSeconds()));
        // A client is built per refresh; close it so its HTTP client / connection pool is released.
        try (LakeFormationClient lakeFormationClient = this.lakeFormationClientFactory.region(this.region).credentialsProvider(AnonymousCredentialsProvider.create()).build()) {
            AssumeDecoratedRoleWithSamlResponse response = lakeFormationClient.assumeDecoratedRoleWithSAML(request);
            logger.info("Obtained credentials from Lake Formation");
            return response;
        }
    }

    /** Renders a SAML-exchange request for logging, with the assertion masked out. */
    private static String describeMaskedRequest(String requestType, String roleArn, String principalArn, Integer durationSeconds) {
        return String.format("%s(SAMLAssertion=*******, RoleArn=%s, PrincipalArn=%s, DurationSeconds=%s)", requestType, roleArn, principalArn, durationSeconds);
    }

    /** Returns the STS endpoint override from the connection parameters, if one is configured. */
    private Optional<URI> getStsEndpoint() {
        return ConnectionParameters.STS_ENDPOINT_PARAMETER.findValue(this.parameters).map(endpoint -> EndpointHelper.constructEndpointUri(endpoint, "STS"));
    }

    /** Returns the Lake Formation endpoint override from the connection parameters, if one is configured. */
    private Optional<URI> getLakeFormationEndpoint() {
        return ConnectionParameters.LAKE_FORMATION_ENDPOINT_PARAMETER.findValue(this.parameters).map(endpoint -> EndpointHelper.constructEndpointUri(endpoint, "Lake Formation"));
    }

    /** Creates an Apache HTTP client builder that routes through the given proxy. */
    private ApacheHttpClient.Builder getHttpClientBuilder(ProxyConfiguration proxyConfiguration) {
        return ApacheHttpClient.builder().proxyConfiguration(proxyConfiguration);
    }

    /**
     * Selects the (role ARN, SAML provider ARN) pair to assume from a base64-encoded assertion.
     *
     * @throws AuthenticationException if no usable role is found
     */
    private Pair<String, String> extractRoleAndPrincipal(String samlAssertion) {
        return findPreferredRoleAndPrincipal(findIamRolesAndPrincipals(findRoleSamlAttributes(samlAssertion)));
    }

    /**
     * Parses XML bytes into a DOM.
     *
     * <p>DOCTYPE declarations, XInclude and external entities are disabled on the parser factory.
     *
     * @throws AuthenticationException if the parser cannot be configured or the XML cannot be parsed
     */
    private Document parseIntoDom(byte[] xml) {
        try {
            this.documentBuilderFactory.setFeature("http://apache.org/xml/features/disallow-doctype-decl", true);
            this.documentBuilderFactory.setXIncludeAware(false);
            this.documentBuilderFactory.setExpandEntityReferences(false);
            this.documentBuilderFactory.setFeature("http://xml.org/sax/features/external-parameter-entities", false);
            this.documentBuilderFactory.setFeature("http://xml.org/sax/features/external-general-entities", false);
            return this.documentBuilderFactory.newDocumentBuilder().parse(new ByteArrayInputStream(xml));
        } catch (IOException | ParserConfigurationException | SAXException e) {
            throw new AuthenticationException("An error occurred while parsing the SAML assertion into DOM", e);
        }
    }

    /**
     * Returns the text nodes of every AWS {@code Role} attribute value in the assertion.
     *
     * @throws AuthenticationException if there are none, or if the XPath cannot be evaluated
     */
    private NodeList findRoleSamlAttributes(String samlAssertion) {
        Document document = parseIntoDom(Base64.decodeBase64(samlAssertion));
        try {
            NodeList roleValues = (NodeList) XPathFactory.newInstance().newXPath().compile(ROLE_ATTRIBUTE_VALUES_XPATH).evaluate(document, XPathConstants.NODESET);
            if (roleValues.getLength() == 0) {
                // The assertion is deliberately not logged: it is a bearer credential.
                logger.error("No role attribute found in the SAML assertion");
                throw new AuthenticationException("No role attribute found in the SAML assertion");
            }
            return roleValues;
        } catch (XPathExpressionException e) {
            throw new AuthenticationException("An error occurred while attempting to find the SAML role attribute", e);
        }
    }

    /**
     * Maps each IAM role ARN to its SAML provider ARN.
     *
     * <p>Each role attribute value is a comma-separated list of ARNs. A value is kept only if it
     * contains both a role ARN and a SAML provider ARN.
     *
     * @param roleValues role attribute value text nodes; {@code null} yields an empty map
     * @return role ARN to provider ARN, iterated in the order the roles appear in the assertion
     */
    private static Map<String, String> findIamRolesAndPrincipals(NodeList roleValues) {
        // LinkedHashMap so that the "first role" default is the first role in the assertion,
        // not an arbitrary hash-order pick.
        Map<String, String> rolesToPrincipals = new LinkedHashMap<>();
        if (roleValues == null) {
            return rolesToPrincipals;
        }
        for (int i = 0; i < roleValues.getLength(); i++) {
            String[] arns = roleValues.item(i).getNodeValue().split(",");
            Optional<String> role = findFirstMatching(arns, ROLE_REGEX);
            Optional<String> principal = findFirstMatching(arns, SAML_PROVIDER_REGEX);
            if (role.isPresent() && principal.isPresent()) {
                rolesToPrincipals.put(role.get(), principal.get());
            }
        }
        return rolesToPrincipals;
    }

    /** Returns the first whitespace-trimmed candidate that fully matches {@code pattern}. */
    private static Optional<String> findFirstMatching(String[] candidates, Pattern pattern) {
        return Arrays.stream(candidates).map(String::trim).filter(candidate -> pattern.matcher(candidate).matches()).findFirst();
    }

    /**
     * Picks the configured preferred role, or the first role found when none is configured.
     *
     * @throws AuthenticationException if the map is empty, or the preferred role is not in it
     */
    private Pair<String, String> findPreferredRoleAndPrincipal(Map<String, String> rolesToPrincipals) {
        if (this.preferredRole == null) {
            return rolesToPrincipals.entrySet().stream()
                    .findFirst()
                    .map(entry -> Pair.of(entry.getKey(), entry.getValue()))
                    .orElseThrow(() -> new AuthenticationException("None of the role attributes in the SAML assertion contain a role"));
        }
        String principal = rolesToPrincipals.get(this.preferredRole);
        if (principal == null) {
            // Log the role ARNs that were found, never the assertion itself.
            logger.error("Preferred role {} not found in SAML assertion; roles found: {}", this.preferredRole, rolesToPrincipals.keySet());
            throw new AuthenticationException("Preferred role not found in SAML assertion");
        }
        return Pair.of(this.preferredRole, principal);
    }

    /**
     * Decodes HTML character references in {@code str}.
     *
     * <p>Handles the named references {@code &amp; &apos; &quot; &lt; &gt;} and well-formed
     * numeric references such as {@code &#43;}, {@code &#x2b;} and {@code &#x2B;}. Any other
     * {@code &} is copied through unchanged.
     *
     * @param str text to decode; must not be {@code null}
     * @return the decoded text
     */
    protected String decodeHtmlCharacterReferences(String str) {
        StringBuilder decoded = new StringBuilder(str.length());
        int i = 0;
        while (i < str.length()) {
            int consumed = str.charAt(i) == '&' ? appendCharacterReference(str, i, decoded) : 0;
            if (consumed == 0) {
                decoded.append(str.charAt(i));
                consumed = 1;
            }
            i += consumed;
        }
        return decoded.toString();
    }

    /**
     * Decodes the character reference starting at {@code start} (which holds '&') into {@code out}.
     *
     * @return characters consumed, or 0 if no recognised reference starts there (nothing appended)
     */
    private static int appendCharacterReference(String str, int start, StringBuilder out) {
        for (String[] reference : NAMED_CHARACTER_REFERENCES) {
            if (str.startsWith(reference[0], start)) {
                out.append(reference[1]);
                return reference[0].length();
            }
        }
        return str.startsWith("&#", start) ? appendNumericCharacterReference(str, start, out) : 0;
    }

    /**
     * Decodes a numeric reference ({@code &#NNN;} or {@code &#xHHHH;}) starting at {@code start}.
     *
     * <p>Only ASCII digits are accepted, at least one digit is required, the reference must end
     * with {@code ;}, and the value must not exceed {@link Character#MAX_CODE_POINT}.
     *
     * @return characters consumed, or 0 if the text is not such a reference (nothing appended)
     */
    private static int appendNumericCharacterReference(String str, int start, StringBuilder out) {
        int pos = start + 2; // skip "&#"
        int radix = 10;
        if (pos < str.length() && (str.charAt(pos) == 'x' || str.charAt(pos) == 'X')) {
            radix = 16;
            pos++;
        }
        int digitsStart = pos;
        int codePoint = 0;
        while (pos < str.length()) {
            char c = str.charAt(pos);
            int digit = c < 128 ? Character.digit(c, radix) : -1;
            if (digit < 0) {
                break;
            }
            codePoint = codePoint * radix + digit;
            if (codePoint > Character.MAX_CODE_POINT) {
                return 0;
            }
            pos++;
        }
        if (pos == digitsStart || pos >= str.length() || str.charAt(pos) != ';') {
            return 0;
        }
        out.appendCodePoint(codePoint);
        return pos + 1 - start;
    }
}
