/*
 * Decompiled with CFR 0.152.
 */
package com.amazon.redshift.plugin;

import com.amazon.redshift.CredentialsHolder;
import com.amazon.redshift.IPlugin;
import com.amazon.redshift.RedshiftProperty;
import com.amazon.redshift.httpclient.log.IamCustomLogFactory;
import com.amazon.redshift.logger.RedshiftLogger;
import com.amazon.redshift.plugin.utils.RequestUtils;
import com.amazonaws.SdkClientException;
import com.amazonaws.auth.AWSCredentials;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.auth.AWSStaticCredentialsProvider;
import com.amazonaws.auth.AnonymousAWSCredentials;
import com.amazonaws.auth.BasicSessionCredentials;
import com.amazonaws.services.securitytoken.AWSSecurityTokenService;
import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceClientBuilder;
import com.amazonaws.services.securitytoken.model.AssumeRoleWithWebIdentityRequest;
import com.amazonaws.services.securitytoken.model.AssumeRoleWithWebIdentityResult;
import com.amazonaws.services.securitytoken.model.Credentials;
import com.amazonaws.util.StringUtils;
import com.amazonaws.util.json.Jackson;
import com.fasterxml.jackson.databind.JsonNode;
import java.io.IOException;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.Date;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.commons.codec.binary.Base64;
import org.apache.commons.logging.LogFactory;

/**
 * Base class for IAM authentication plugins that use a JSON Web Token (JWT).
 *
 * <p>Subclasses supply the token through {@link #processJwt(String)}. {@link #refresh()} exchanges
 * it for temporary AWS credentials with STS {@code AssumeRoleWithWebIdentity} and derives the
 * database user from the token's claims. It then stores the result in one of two places:
 * <ul>
 *   <li>a cache shared by all instances and keyed by {@link #getCacheKey()};</li>
 *   <li>when caching is disabled, this instance only, for hand-off to the next
 *       {@link #getCredentials()} call.</li>
 * </ul>
 */
public abstract class JwtCredentialsProvider
implements IPlugin {
    private static final String KEY_ROLE_ARN = "roleArn";
    private static final String KEY_WEB_IDENTITY_TOKEN = "webIdentityToken";
    private static final String KEY_DURATION = "duration";
    private static final String KEY_ROLE_SESSION_NAME = "roleSessionName";
    private static final String DEFAULT_ROLE_SESSION_NAME = "jwt_redshift_session";

    /** JWT payload claims checked, in this order, for the database user name. */
    private static final String[] DATABASE_USER_CLAIMS = {"DbUser", "upn", "preferred_username", "email"};

    protected String m_roleArn;
    protected String m_jwt;
    protected String m_roleSessionName = DEFAULT_ROLE_SESSION_NAME;
    /** Requested STS session duration in seconds; values {@code <= 0} leave it unset in the request. */
    protected int m_duration;
    protected String m_dbUser;
    protected String m_stsEndpoint;
    protected String m_region;
    protected RedshiftLogger m_log;
    protected Boolean m_disableCache = false;
    protected Boolean m_groupFederation = false;

    /**
     * Credentials shared by every instance, keyed by {@link #getCacheKey()}.
     *
     * <p>Instances only lock on themselves. Two different providers can therefore read and write
     * this map at the same time, so it must be a concurrent map.
     *
     * <p>NOTE(review): entries are never removed. Credentials for superseded tokens are kept for
     * the lifetime of this class.
     */
    private static final Map<String, CredentialsHolder> m_cache = new ConcurrentHashMap<String, CredentialsHolder>();

    /** Result of the last {@link #refresh()} when caching is disabled; consumed once by getCredentials(). */
    private CredentialsHolder m_lastRefreshCredentials;

    private static final Class<?> CUSTOM_LOG_FACTORY_CLASS = IamCustomLogFactory.class;
    private static final String LOG_PROPERTIES_FILE_NAME = "log-factory.properties";
    private static final String LOG_PROPERTIES_FILE_PATH = "META-INF/services/org.apache.commons.logging.LogFactory";

    /**
     * Context class loader installed while the token is processed and STS is called.
     *
     * <p>It does three things:
     * <ul>
     *   <li>resolves any class assignable to commons-logging {@link LogFactory} to
     *       {@link IamCustomLogFactory};</li>
     *   <li>hides every {@code commons-logging.properties} resource;</li>
     *   <li>serves this class's {@code log-factory.properties} as the LogFactory service
     *       descriptor.</li>
     * </ul>
     * Presumably this steers commons-logging discovery inside the AWS SDK towards the driver's
     * log factory. TODO confirm.
     */
    private static final ClassLoader CONTEXT_CLASS_LOADER = new ClassLoader(JwtCredentialsProvider.class.getClassLoader()){

        @Override
        public Class<?> loadClass(String name) throws ClassNotFoundException {
            Class<?> clazz = this.getParent().loadClass(name);
            if (LogFactory.class.isAssignableFrom(clazz)) {
                return CUSTOM_LOG_FACTORY_CLASS;
            }
            return clazz;
        }

        @Override
        public Enumeration<URL> getResources(String name) throws IOException {
            if ("commons-logging.properties".equals(name)) {
                return Collections.enumeration(Collections.<URL>emptyList());
            }
            return super.getResources(name);
        }

        @Override
        public URL getResource(String name) {
            if (LOG_PROPERTIES_FILE_PATH.equals(name)) {
                return JwtCredentialsProvider.class.getResource(LOG_PROPERTIES_FILE_NAME);
            }
            return super.getResource(name);
        }
    };

    /**
     * Produces the JWT to send to STS.
     *
     * @param jwt the value of the {@code webIdentityToken} parameter (may be null)
     * @return the token to exchange for AWS credentials
     * @throws IOException if the token cannot be obtained or required parameters are missing
     */
    protected abstract String processJwt(String var1) throws IOException;

    /**
     * Applies one connection parameter. Keys are matched case-insensitively; unknown keys are
     * ignored.
     *
     * @throws NumberFormatException if {@code duration} is not an integer
     */
    @Override
    public void addParameter(String key, String value) {
        if (RedshiftLogger.isEnable()) {
            this.m_log.logDebug("key: {0}", key);
        }
        if (KEY_ROLE_ARN.equalsIgnoreCase(key)) {
            this.m_roleArn = value;
        } else if (KEY_WEB_IDENTITY_TOKEN.equalsIgnoreCase(key)) {
            this.m_jwt = value;
        } else if (KEY_ROLE_SESSION_NAME.equalsIgnoreCase(key)) {
            this.m_roleSessionName = value;
        } else if (KEY_DURATION.equalsIgnoreCase(key)) {
            this.m_duration = Integer.parseInt(value);
        } else if (RedshiftProperty.DB_USER.getName().equalsIgnoreCase(key)) {
            // Not stored: m_dbUser is set from the JWT claims in refresh() instead.
        } else if (RedshiftProperty.AWS_REGION.getName().equalsIgnoreCase(key)) {
            this.m_region = value;
        } else if (RedshiftProperty.STS_ENDPOINT_URL.getName().equalsIgnoreCase(key)) {
            this.m_stsEndpoint = value;
        } else if (RedshiftProperty.IAM_DISABLE_CACHE.getName().equalsIgnoreCase(key)) {
            this.m_disableCache = Boolean.valueOf(value);
        }
    }

    @Override
    public void setLogger(RedshiftLogger log) {
        this.m_log = log;
    }

    /**
     * Returns this plugin's sub-type code. The value is always {@code 2}; its meaning is
     * presumably defined by IPlugin callers. TODO confirm.
     */
    @Override
    public int getSubType() {
        return 2;
    }

    /**
     * Returns STS credentials for the configured role and token.
     *
     * <p>When caching is enabled, unexpired cached credentials are reused (marked
     * {@code refresh=false}). Otherwise, {@link #refresh()} is called while holding this
     * instance's lock.
     *
     * @throws SdkClientException if credentials could not be obtained
     */
    public CredentialsHolder getCredentials() {
        CredentialsHolder credentials = null;
        if (!this.m_disableCache) {
            credentials = m_cache.get(this.getCacheKey());
        }
        if (credentials == null || credentials.isExpired()) {
            if (RedshiftLogger.isEnable()) {
                this.m_log.logInfo("JWT getCredentials NOT from cache");
            }
            synchronized (this) {
                this.refresh();
                if (this.m_disableCache) {
                    // Hand the uncached result to this caller exactly once.
                    credentials = this.m_lastRefreshCredentials;
                    this.m_lastRefreshCredentials = null;
                }
            }
        } else {
            credentials.setRefresh(false);
            if (RedshiftLogger.isEnable()) {
                this.m_log.logInfo("JWT getCredentials from cache");
            }
        }
        if (!this.m_disableCache) {
            // Re-read so a refresh (ours or another thread's) is picked up.
            credentials = m_cache.get(this.getCacheKey());
        }
        if (credentials == null) {
            throw new SdkClientException("Unable to load AWS credentials from JWT");
        }
        return credentials;
    }

    /**
     * Exchanges the JWT for temporary credentials with STS {@code AssumeRoleWithWebIdentity}.
     *
     * <p>Sets {@link #m_dbUser} from the token's claims. The result goes into the shared cache,
     * or into {@link #m_lastRefreshCredentials} when caching is disabled. The thread context class
     * loader is swapped to {@link #CONTEXT_CLASS_LOADER} for the duration and always restored.
     *
     * @throws SdkClientException wrapping any failure
     */
    public void refresh() {
        Thread currentThread = Thread.currentThread();
        ClassLoader originalClassLoader = currentThread.getContextClassLoader();
        currentThread.setContextClassLoader(CONTEXT_CLASS_LOADER);
        try {
            String jwt = this.processJwt(this.m_jwt);
            if (RedshiftLogger.isEnable()) {
                // The token is a bearer credential; never write its value to the log.
                this.m_log.logDebug("JWT : [redacted]");
            }
            // Derive the user from the token actually sent to STS, not from the raw parameter.
            this.m_dbUser = this.deriveDatabaseUser(this.decodeJwt(jwt));

            AssumeRoleWithWebIdentityRequest jwtRequest = new AssumeRoleWithWebIdentityRequest();
            jwtRequest.setWebIdentityToken(jwt);
            jwtRequest.setRoleArn(this.m_roleArn);
            jwtRequest.setRoleSessionName(this.m_roleSessionName);
            if (this.m_duration > 0) {
                jwtRequest.setDurationSeconds(Integer.valueOf(this.m_duration));
            }

            // The request carries the web identity token. The STS client itself is built with
            // anonymous AWS credentials.
            AWSCredentialsProvider anonymousProvider = new AWSStaticCredentialsProvider(new AnonymousAWSCredentials());
            AWSSecurityTokenServiceClientBuilder builder = AWSSecurityTokenServiceClientBuilder.standard();
            AWSSecurityTokenService stsSvc = RequestUtils.buildSts(this.m_stsEndpoint, this.m_region, builder, anonymousProvider, this.m_log);
            AssumeRoleWithWebIdentityResult result = stsSvc.assumeRoleWithWebIdentity(jwtRequest);

            Credentials cred = result.getCredentials();
            Date expiration = cred.getExpiration();
            AWSCredentials sessionCredentials = new BasicSessionCredentials(cred.getAccessKeyId(), cred.getSecretAccessKey(), cred.getSessionToken());
            CredentialsHolder credentials = CredentialsHolder.newInstance(sessionCredentials, expiration);
            credentials.setMetadata(this.readMetadata());
            credentials.setRefresh(true);
            if (this.m_disableCache) {
                this.m_lastRefreshCredentials = credentials;
            } else {
                m_cache.put(this.getCacheKey(), credentials);
            }
        }
        catch (Exception e) {
            if (RedshiftLogger.isEnable()) {
                this.m_log.logError(e);
            }
            throw new SdkClientException("JWT error: " + e.getMessage(), e);
        }
        finally {
            currentThread.setContextClassLoader(originalClassLoader);
        }
    }

    /** Extra cache-key material for subclasses; empty by default. */
    @Override
    public String getPluginSpecificCacheKey() {
        return "";
    }

    /**
     * Returns the processed JWT (the value of {@link #processJwt(String)}). The thread context
     * class loader is swapped to {@link #CONTEXT_CLASS_LOADER} while it runs.
     *
     * @throws SdkClientException wrapping any failure
     */
    @Override
    public String getIdpToken() {
        Thread currentThread = Thread.currentThread();
        ClassLoader originalClassLoader = currentThread.getContextClassLoader();
        currentThread.setContextClassLoader(CONTEXT_CLASS_LOADER);
        try {
            String jwt = this.processJwt(this.m_jwt);
            if (RedshiftLogger.isEnable()) {
                // The token is a bearer credential; never write its value to the log.
                this.m_log.logDebug("JWT : [redacted]");
            }
            return jwt;
        }
        catch (Exception e) {
            if (RedshiftLogger.isEnable()) {
                this.m_log.logError(e);
            }
            throw new SdkClientException("JWT error: " + e.getMessage(), e);
        }
        finally {
            currentThread.setContextClassLoader(originalClassLoader);
        }
    }

    /** Records the group-federation flag (not read within this class). */
    @Override
    public void setGroupFederation(boolean groupFederation) {
        this.m_groupFederation = groupFederation;
    }

    /**
     * Builds the cache key from role ARN, token, session name, duration and the plugin-specific
     * key.
     *
     * <p>NOTE(review): the parts are concatenated without separators. For example, session "s1"
     * with duration 0 produces the same key as session "s" with duration 10.
     */
    @Override
    public String getCacheKey() {
        String pluginSpecificKey = this.getPluginSpecificCacheKey();
        return this.m_roleArn + this.m_jwt + this.m_roleSessionName + this.m_duration + pluginSpecificKey;
    }

    /**
     * Verifies that {@code roleArn} and {@code webIdentityToken} were supplied.
     *
     * @throws IOException if either is null or empty
     */
    protected void checkRequiredParameters() throws IOException {
        if (StringUtils.isNullOrEmpty(this.m_roleArn)) {
            throw new IOException("Missing required property: " + KEY_ROLE_ARN);
        }
        if (StringUtils.isNullOrEmpty(this.m_jwt)) {
            throw new IOException("Missing required property: " + KEY_WEB_IDENTITY_TOKEN);
        }
    }

    /**
     * Splits a compact JWT into its three dot-separated parts.
     *
     * <p>The header and payload are Base64-decoded as UTF-8. The signature is returned as-is.
     *
     * @param jwt the token, may be null
     * @return {@code {header, payload, signature}}, or null if {@code jwt} is null or does not
     *     have exactly three parts
     */
    protected String[] decodeJwt(String jwt) {
        if (jwt == null) {
            return null;
        }
        String[] headerPayloadSig = jwt.split("\\.");
        if (headerPayloadSig.length != 3) {
            return null;
        }
        String header = new String(Base64.decodeBase64(headerPayloadSig[0]), StandardCharsets.UTF_8);
        String payload = new String(Base64.decodeBase64(headerPayloadSig[1]), StandardCharsets.UTF_8);
        String signature = headerPayloadSig[2];
        if (RedshiftLogger.isEnable()) {
            // Signature omitted: together with header and payload it reconstructs the bearer token.
            // The decoded JSON is passed as arguments, never as part of the format pattern.
            this.m_log.logDebug("Decoded JWT : Header: {0} payload: {1}", header, payload);
        }
        return new String[]{header, payload, signature};
    }

    /**
     * Picks the database user from the JWT payload.
     *
     * <p>Uses the first non-empty textual value among the claims {@code DbUser}, {@code upn},
     * {@code preferred_username} and {@code email}, searched with {@link JsonNode#findValue}.
     *
     * @param decodedJwt the result of {@link #decodeJwt(String)}
     * @return the database user name
     * @throws SdkClientException if the token was not decoded or no claim yields a user
     */
    protected String deriveDatabaseUser(String[] decodedJwt) {
        if (decodedJwt == null || decodedJwt.length != 3) {
            throw new SdkClientException("JWT decoding error");
        }
        JsonNode entityJson = Jackson.jsonNodeOf(decodedJwt[1]);
        for (String claim : DATABASE_USER_CLAIMS) {
            JsonNode userTokenField = entityJson.findValue(claim);
            String databaseUser = userTokenField == null ? null : userTokenField.textValue();
            if (!StringUtils.isNullOrEmpty(databaseUser)) {
                if (RedshiftLogger.isEnable()) {
                    this.m_log.logDebug("JWT claim: {0} as database user: {1}", claim, databaseUser);
                }
                return databaseUser;
            }
        }
        throw new SdkClientException("No database user claim found in JWT");
    }

    /** Builds IAM metadata carrying the derived database user, with auto-create enabled. */
    private CredentialsHolder.IamMetadata readMetadata() {
        CredentialsHolder.IamMetadata metadata = new CredentialsHolder.IamMetadata();
        metadata.setDbUser(this.m_dbUser);
        metadata.setAutoCreate(true);
        return metadata;
    }
}

