/*
 * Decompiled with CFR 0.152.
 */
package io.lucenia.ml.common.http;

import io.skylite.core.http.HttpRequest;
import io.skylite.core.http.HttpRequestSigner;
import io.skylite.core.http.MutableHttpRequest;
import io.skylite.core.rest.RestRequest;
import io.skylite.ml.common.connector.Connector;
import io.skylite.ml.common.connector.ConnectorAction;
import java.io.ByteArrayInputStream;
import java.net.URI;
import java.nio.charset.Charset;
import java.util.List;
import java.util.Map;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentials;
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
import software.amazon.awssdk.core.sync.RequestBody;
import software.amazon.awssdk.http.ContentStreamProvider;
import software.amazon.awssdk.http.SdkHttpFullRequest;
import software.amazon.awssdk.http.SdkHttpMethod;
import software.amazon.awssdk.http.SdkHttpRequest;
import software.amazon.awssdk.http.auth.aws.signer.AwsV4HttpSigner;
import software.amazon.awssdk.http.auth.spi.signer.SignRequest;
import software.amazon.awssdk.http.auth.spi.signer.SignedRequest;
import software.amazon.awssdk.identity.spi.Identity;

/**
 * {@link HttpRequestSigner} that signs HTTP requests with AWS Signature Version 4.
 *
 * <p>Signing parameters are supplied as a string map keyed by {@link #ACCESS_KEY},
 * {@link #SECRET_KEY}, {@link #REGION} and {@link #SERVICE_NAME} (all required, non-empty) and
 * {@link #SESSION_TOKEN} (optional). The class also exposes helpers to build and sign AWS SDK
 * {@link SdkHttpFullRequest}s directly.
 */
public class AwsSigV4HttpRequestSigner implements HttpRequestSigner {
    private static final Logger log = LogManager.getLogger(AwsSigV4HttpRequestSigner.class);
    public static final String API_NAME = "aws_sigv4";
    public static final String ACCESS_KEY = "access_key";
    public static final String SECRET_KEY = "secret_key";
    public static final String SESSION_TOKEN = "session_token";
    public static final String REGION = "region";
    public static final String SERVICE_NAME = "service_name";
    private static final AwsV4HttpSigner signer = AwsV4HttpSigner.create();

    /**
     * Returns the concrete class of this signer.
     *
     * @return {@code AwsSigV4HttpRequestSigner.class}
     */
    public Class<?> getClazz() {
        return AwsSigV4HttpRequestSigner.class;
    }

    /**
     * Returns the name under which this signer is registered.
     *
     * @return {@value #API_NAME}
     */
    public String getApiName() {
        return API_NAME;
    }

    /**
     * Signs {@code request} with SigV4.
     *
     * @param request     request to sign; its method, URI, headers and content are used for signing
     * @param keyedParams signing parameters (see the class documentation)
     * @return a new request whose URI and headers come from the signed request, and whose method,
     *         content and protocol version come from {@code request}
     * @throws IllegalArgumentException if {@code keyedParams} is null, a required parameter is
     *                                  missing or empty, or the HTTP method is unsupported
     */
    public HttpRequest sign(HttpRequest request, Map<String, String> keyedParams) {
        // Validate first so bad parameters fail before any request conversion work.
        this.validateParams(keyedParams);
        SdkHttpFullRequest sdkRequest = this.toSdkRequest(request);
        SdkHttpFullRequest signedSdkRequest = this.signWithValidatedParams(sdkRequest, keyedParams);
        return this.fromSdkRequest(signedSdkRequest, request);
    }

    /**
     * Signs an AWS SDK request using parameters taken from {@code keyedParams}.
     *
     * @param request     the SDK request to sign
     * @param keyedParams signing parameters (see the class documentation)
     * @return the signed SDK request
     * @throws IllegalArgumentException if {@code keyedParams} is null or a required parameter is
     *                                  missing or empty
     */
    public SdkHttpFullRequest signSdkRequest(SdkHttpFullRequest request, Map<String, String> keyedParams) {
        this.validateParams(keyedParams);
        return this.signWithValidatedParams(request, keyedParams);
    }

    /**
     * Signs an AWS SDK request with explicit credentials.
     *
     * @param request      the SDK request to sign; its content stream (if any) is used as the payload
     * @param accessKey    AWS access key id
     * @param secretKey    AWS secret access key
     * @param sessionToken session token, or {@code null}/empty when not using temporary credentials
     * @param serviceName  SigV4 signing name of the target service
     * @param region       SigV4 signing region
     * @return the signed SDK request
     */
    public SdkHttpFullRequest signSdkRequest(SdkHttpFullRequest request, String accessKey, String secretKey, String sessionToken, String serviceName, String region) {
        // An empty token is treated as "absent", consistent with requireParam treating empty
        // values as missing; otherwise session credentials would carry an empty token.
        AwsCredentials credentials = sessionToken == null || sessionToken.isEmpty()
            ? AwsBasicCredentials.create(accessKey, secretKey)
            : AwsSessionCredentials.create(accessKey, secretKey, sessionToken);
        ContentStreamProvider payload = request.contentStreamProvider().orElse(null);
        SignedRequest signedRequest = signer.sign(r -> r
            .identity(credentials)
            .request(request)
            .payload(payload)
            .putProperty(AwsV4HttpSigner.SERVICE_SIGNING_NAME, serviceName)
            .putProperty(AwsV4HttpSigner.REGION_NAME, region));
        // NOTE(review): the cast assumes the signer preserves the concrete SdkHttpFullRequest type
        // of the input request — confirm if the SDK version changes.
        return (SdkHttpFullRequest) signedRequest.request();
    }

    /**
     * Builds an unsigned SDK request for a connector action.
     *
     * <p>The payload is encoded with the {@code charset} parameter (default {@code UTF-8}). The
     * connector's decrypted headers are copied onto the request. {@code Content-Type} defaults to
     * {@code application/json} and {@code Content-Length} to the body length when not already set.
     *
     * @param action     connector action name
     * @param connector  connector supplying the endpoint and headers
     * @param parameters request parameters; also used to resolve the action endpoint
     * @param payload    request body, or {@code null} for an empty body
     * @param method     HTTP method
     * @return the built request
     * @throws IllegalArgumentException if a POST has an empty body (except for
     *                                  {@code CANCEL_BATCH_PREDICT}), the endpoint is not a valid
     *                                  URI with a host, or the charset is invalid/unsupported
     */
    public SdkHttpFullRequest buildSdkRequest(String action, Connector connector, Map<String, String> parameters, String payload, SdkHttpMethod method) {
        String charset = parameters.getOrDefault("charset", "UTF-8");
        RequestBody requestBody = payload != null ? RequestBody.fromString(payload, Charset.forName(charset)) : RequestBody.empty();
        long contentLength = requestBody.optionalContentLength().orElseThrow();
        // Constant-first comparison so a null action does not throw NPE.
        boolean isCancelBatchPredict = ConnectorAction.ActionType.CANCEL_BATCH_PREDICT.toString().equals(action);
        if (SdkHttpMethod.POST == method && contentLength == 0L && !isCancelBatchPredict) {
            log.error("Content length is 0. Aborting request to remote model");
            throw new IllegalArgumentException("Content length is 0. Aborting request to remote model");
        }
        String endpoint = connector.getActionEndpoint(action, parameters);
        URI uri;
        try {
            uri = URI.create(endpoint);
            if (uri.getHost() == null) {
                throw new IllegalArgumentException("Invalid URI. Please check if the endpoint is valid from connector.");
            }
        } catch (Exception e) {
            // Also wraps the missing-host IllegalArgumentException above (kept as the cause).
            throw new IllegalArgumentException("Encountered error when trying to create uri from endpoint in ml connector. Please update the endpoint in connection configuration: ", e);
        }
        SdkHttpFullRequest.Builder builder = SdkHttpFullRequest.builder()
            .method(method)
            .uri(uri)
            .contentStreamProvider(requestBody.contentStreamProvider());
        Map<String, String> headers = connector.getDecryptedHeaders();
        if (headers != null) {
            for (Map.Entry<String, String> header : headers.entrySet()) {
                builder.putHeader(header.getKey(), header.getValue());
            }
        }
        if (builder.matchingHeaders("Content-Type").isEmpty()) {
            builder.putHeader("Content-Type", "application/json");
        }
        if (builder.matchingHeaders("Content-Length").isEmpty()) {
            builder.putHeader("Content-Length", Long.toString(contentLength));
        }
        return builder.build();
    }

    /**
     * Checks that all required signing parameters are present and non-empty.
     *
     * @throws IllegalArgumentException if {@code params} is null or a required key is missing/empty
     */
    private void validateParams(Map<String, String> params) {
        if (params == null) {
            throw new IllegalArgumentException("Signing parameters cannot be null");
        }
        this.requireParam(params, ACCESS_KEY);
        this.requireParam(params, SECRET_KEY);
        this.requireParam(params, REGION);
        this.requireParam(params, SERVICE_NAME);
    }

    /** Throws if {@code key} maps to null or an empty string (an absent key maps to null). */
    private void requireParam(Map<String, String> params, String key) {
        String value = params.get(key);
        if (value == null || value.isEmpty()) {
            throw new IllegalArgumentException("Missing required parameter: " + key);
        }
    }

    /** Signs {@code request} using parameters from a map that has already passed validateParams. */
    private SdkHttpFullRequest signWithValidatedParams(SdkHttpFullRequest request, Map<String, String> keyedParams) {
        return this.signSdkRequest(
            request,
            keyedParams.get(ACCESS_KEY),
            keyedParams.get(SECRET_KEY),
            keyedParams.get(SESSION_TOKEN),
            keyedParams.get(SERVICE_NAME),
            keyedParams.get(REGION));
    }

    /** Converts a server-side {@link HttpRequest} into an SDK request, copying headers and body. */
    private SdkHttpFullRequest toSdkRequest(HttpRequest request) {
        URI uri = URI.create(request.uri());
        SdkHttpMethod method = this.toSdkMethod(request.method());
        SdkHttpFullRequest.Builder builder = SdkHttpFullRequest.builder().method(method).uri(uri);
        for (Map.Entry<String, List<String>> header : request.getHeaders().entrySet()) {
            for (String value : header.getValue()) {
                builder.appendHeader(header.getKey(), value);
            }
        }
        var content = request.content();
        if (content != null && content.length() > 0) {
            // Materialize the bytes once; the provider below returns a fresh stream per call.
            var bytesRef = content.toBytesRef();
            byte[] contentBytes = new byte[content.length()];
            System.arraycopy(bytesRef.bytes, bytesRef.offset, contentBytes, 0, contentBytes.length);
            ContentStreamProvider contentProvider = () -> new ByteArrayInputStream(contentBytes);
            builder.contentStreamProvider(contentProvider);
            if (builder.matchingHeaders("Content-Length").isEmpty()) {
                builder.putHeader("Content-Length", String.valueOf(contentBytes.length));
            }
        }
        return builder.build();
    }

    /**
     * Builds a new {@link HttpRequest} taking URI and headers from the signed SDK request and the
     * method, content and protocol version from the original request.
     */
    private HttpRequest fromSdkRequest(SdkHttpFullRequest sdkRequest, HttpRequest originalRequest) {
        MutableHttpRequest.Builder builder = MutableHttpRequest.builder()
            .method(originalRequest.method())
            .uri(sdkRequest.getUri())
            .content(originalRequest.content())
            .protocolVersion(originalRequest.protocolVersion());
        for (Map.Entry<String, List<String>> header : sdkRequest.headers().entrySet()) {
            for (String value : header.getValue()) {
                builder.addHeader(header.getKey(), value);
            }
        }
        return builder.build();
    }

    /**
     * Maps a REST method to its SDK equivalent.
     *
     * @throws IllegalArgumentException for methods without an SDK mapping
     */
    private SdkHttpMethod toSdkMethod(RestRequest.Method method) {
        return switch (method) {
            case GET -> SdkHttpMethod.GET;
            case POST -> SdkHttpMethod.POST;
            case PUT -> SdkHttpMethod.PUT;
            case DELETE -> SdkHttpMethod.DELETE;
            case HEAD -> SdkHttpMethod.HEAD;
            case OPTIONS -> SdkHttpMethod.OPTIONS;
            case PATCH -> SdkHttpMethod.PATCH;
            default -> throw new IllegalArgumentException("Unsupported HTTP method: " + String.valueOf(method));
        };
    }
}

