/*
 * Decompiled with CFR 0.152.
 */
package com.amazon.athena.jdbc.authentication.oidc;

import com.amazon.athena.jdbc.authentication.oidc.OpenIdConnectWellKnownConfigurationService;
import com.amazon.athena.logging.AthenaLogger;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.time.DateTimeException;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Base64;
import java.util.List;
import java.util.Optional;
import software.amazon.awssdk.protocols.jsoncore.JsonNode;
import software.amazon.awssdk.protocols.jsoncore.JsonNodeParser;
import software.amazon.awssdk.utils.StringUtils;

/**
 * Performs client-side sanity checks on a web identity token in compact JWT form
 * ({@code header.payload.signature}).
 *
 * <p>The checks are structural and claim-based: the token must have exactly three dot-separated
 * parts, the header must name a signing algorithm ({@code alg}) other than {@code none}, and the
 * payload must carry an audience ({@code aud}) matching the configured client id plus expiration
 * ({@code exp}), issued-at ({@code iat}) and optional not-before ({@code nbf}) claims consistent with
 * the current time. The signature segment itself is never inspected by this class.
 */
public class JwtTokenValidator {
    private static final AthenaLogger logger = AthenaLogger.of(JwtTokenValidator.class);
    /** Seconds an issued-at time may lie in the future, to tolerate clock skew with the issuer. */
    private static final int CLOCK_SKEW_TOLERANCE = 300;
    /** Maximum age, in seconds, of a token's issued-at time (24 hours). */
    private static final int OLD_TOKEN_TOLERANCE = 86400;
    // NOTE(review): stored but not read by any check in this class; presumably intended for
    // signature verification against the issuer's published keys — confirm before relying on it.
    private final OpenIdConnectWellKnownConfigurationService openIdConnectWellKnownConfigurationService;
    private final String clientId;

    /**
     * Creates a validator for tokens issued to the given client.
     *
     * @param clientId the audience ({@code aud}) a token must contain to be accepted
     * @param openIdConnectWellKnownConfigurationService provider configuration service; stored but
     *     not consulted by the current checks
     */
    public JwtTokenValidator(String clientId, OpenIdConnectWellKnownConfigurationService openIdConnectWellKnownConfigurationService) {
        this.openIdConnectWellKnownConfigurationService = openIdConnectWellKnownConfigurationService;
        this.clientId = clientId;
    }

    /**
     * Checks whether the token passes the structural and claim checks described on this class.
     *
     * <p>Every validation failure is logged and reported as {@code false} rather than thrown. This
     * covers a {@code null} token, malformed Base64/JSON, a missing claim, a time claim that is not an
     * integral number of epoch seconds, and a time claim outside the range {@link Instant} supports.
     *
     * @param token the compact-serialized JWT; may be {@code null}
     * @return {@code true} if every check passes, {@code false} otherwise
     */
    public boolean verifyToken(String token) {
        if (token == null) {
            logger.warn("web identity token was null.", new Object[0]);
            return false;
        }
        try {
            String[] chunks = token.split("\\.");
            if (chunks.length != 3) {
                logger.warn("web identity token had {} parts, expected exactly three.", chunks.length);
                return false;
            }
            JsonNode headerJson = decodeToJsonNode(chunks[0]);
            JsonNode claims = decodeToJsonNode(chunks[1]);
            return verifySigned(headerJson) && verifyClaims(claims);
        }
        catch (JwtValidationException e) {
            logger.warn("web identity token verification failed.", e);
            return false;
        }
    }

    /**
     * Decodes one Base64url-encoded JWT segment and parses it as JSON.
     *
     * @param encodedString the header or payload segment of the token
     * @return the parsed JSON node
     * @throws JwtValidationException if Base64url decoding fails, or if the JSON parser throws
     *     {@link UncheckedIOException}; the original exception is kept as the cause
     */
    private JsonNode decodeToJsonNode(String encodedString) throws JwtValidationException {
        try {
            byte[] decodedBytes = Base64.getUrlDecoder().decode(encodedString);
            String decodedString = new String(decodedBytes, StandardCharsets.UTF_8);
            return JsonNodeParser.create().parse(decodedString);
        }
        catch (UncheckedIOException | IllegalArgumentException ex) {
            throw new JwtValidationException("Unable to decode the token " + ex.getMessage(), ex);
        }
    }

    /**
     * Checks that the header declares a signing algorithm other than {@code none}.
     * Only the header is examined; no signature is checked.
     *
     * @return always {@code true} when it does not throw
     * @throws JwtValidationException if {@code alg} is missing, empty, or {@code none} (any case)
     */
    private boolean verifySigned(JsonNode header) throws JwtValidationException {
        Optional<String> algHeaderField = extractStringField(header, "alg");
        if (!algHeaderField.isPresent() || algHeaderField.get().equalsIgnoreCase("none")) {
            logger.warn("web identity token verification failed, no algorithm header found or algorithm is none.", new Object[0]);
            throw new JwtValidationException("Unable to verify if token is signed");
        }
        return true;
    }

    /**
     * Validates the audience and time-based claims against the current time.
     *
     * <p>{@code aud} must include the configured client id. {@code exp} must not be in the past.
     * {@code iat} must be no more than {@value #CLOCK_SKEW_TOLERANCE} seconds in the future and no
     * more than {@value #OLD_TOKEN_TOLERANCE} seconds in the past. If present, {@code nbf} must not
     * be in the future.
     *
     * @return always {@code true} when it does not throw
     * @throws JwtValidationException if a required claim is missing or malformed, or any check fails
     */
    private boolean verifyClaims(JsonNode claims) throws JwtValidationException {
        List<String> audiences = extractAudiences(claims);
        long expirationTime = extractLongField(claims, "exp").orElseThrow(() -> new JwtValidationException("Missing expiration claim"));
        long issuedAt = extractLongField(claims, "iat").orElseThrow(() -> new JwtValidationException("Missing issued at claim"));
        Optional<Long> notBefore = extractLongField(claims, "nbf");
        // List.contains is null-safe, so a null clientId yields a rejection instead of an NPE.
        if (!audiences.contains(this.clientId)) {
            logger.warn("web identity token audience verification failed", new Object[0]);
            throw new JwtValidationException("web identity token audience verification failed");
        }
        // Sample the clock once so that all time checks agree on "now".
        Instant now = Instant.now();
        if (now.isAfter(toInstant(expirationTime, "exp"))) {
            logger.warn("web identity token has already expired", new Object[0]);
            throw new JwtValidationException("web identity token has already expired");
        }
        Instant issuedInstant = toInstant(issuedAt, "iat");
        if (issuedInstant.isAfter(now.plusSeconds(CLOCK_SKEW_TOLERANCE)) || issuedInstant.isBefore(now.minusSeconds(OLD_TOKEN_TOLERANCE))) {
            logger.warn("web identity token issued at time is either too old or in future", new Object[0]);
            throw new JwtValidationException("web identity token issued at time is either too old or in future");
        }
        if (notBefore.isPresent() && now.isBefore(toInstant(notBefore.get(), "nbf"))) {
            logger.warn("web identity token not yet valid", new Object[0]);
            throw new JwtValidationException("web identity token not yet valid");
        }
        return true;
    }

    /**
     * Collects the non-empty audience values from the {@code aud} claim.
     *
     * <p>{@code aud} may be a single value or an array of values (RFC 7519 section 4.1.3); both
     * forms are accepted.
     *
     * @throws JwtValidationException if the claim is absent or yields no non-empty value
     */
    private static List<String> extractAudiences(JsonNode claims) throws JwtValidationException {
        List<String> audiences = new ArrayList<>();
        Optional<JsonNode> audNode = claims.field("aud");
        if (audNode.isPresent()) {
            JsonNode node = audNode.get();
            if (node.isArray()) {
                for (JsonNode element : node.asArray()) {
                    addIfNotEmpty(audiences, element.text());
                }
            } else {
                addIfNotEmpty(audiences, node.text());
            }
        }
        if (audiences.isEmpty()) {
            throw new JwtValidationException("Missing audience claim");
        }
        return audiences;
    }

    /** Appends {@code value} to {@code target} unless it is null or empty. */
    private static void addIfNotEmpty(List<String> target, String value) {
        if (!StringUtils.isEmpty(value)) {
            target.add(value);
        }
    }

    /** Returns the field's text, or empty when the field is absent, has no text, or has empty text. */
    private static Optional<String> extractStringField(JsonNode jsonNode, String fieldName) {
        return jsonNode.field(fieldName).map(JsonNode::text).filter(text -> !StringUtils.isEmpty(text));
    }

    /**
     * Parses the field's text as a {@code long}.
     *
     * @return the parsed value, or empty when the field is absent
     * @throws JwtValidationException if the field is present but its text is not a base-10 integer
     *     (for example a fractional value or a JSON null)
     */
    private static Optional<Long> extractLongField(JsonNode jsonNode, String fieldName) throws JwtValidationException {
        Optional<JsonNode> field = jsonNode.field(fieldName);
        if (!field.isPresent()) {
            return Optional.empty();
        }
        String text = field.get().text();
        try {
            return Optional.of(Long.parseLong(text));
        }
        catch (NumberFormatException e) {
            throw new JwtValidationException("Claim " + fieldName + " is not an integral number of seconds: " + text, e);
        }
    }

    /**
     * Converts epoch seconds to an {@link Instant}.
     *
     * @throws JwtValidationException if the value exceeds the range {@link Instant} can represent
     */
    private static Instant toInstant(long epochSeconds, String claimName) throws JwtValidationException {
        try {
            return Instant.ofEpochSecond(epochSeconds);
        }
        catch (DateTimeException e) {
            throw new JwtValidationException("Claim " + claimName + " is outside the supported time range: " + epochSeconds, e);
        }
    }

    /** Internal signal that a token failed one of the checks; always converted to {@code false}. */
    private static class JwtValidationException
    extends Exception {
        private static final long serialVersionUID = 1L;

        public JwtValidationException(String message) {
            super(message);
        }

        public JwtValidationException(String message, Throwable cause) {
            super(message, cause);
        }
    }
}

