/*
 * Decompiled with CFR 0.152.
 */
package io.lucenia.ml.common.remote;

import io.skylite.SkyliteStatusException;
import io.skylite.common.action.ActionListener;
import io.skylite.common.collect.Tuple;
import io.skylite.core.rest.RestStatus;
import io.skylite.core.script.ScriptService;
import io.skylite.ml.common.connector.Connector;
import io.skylite.ml.common.connector.ConnectorAction;
import io.skylite.ml.common.exception.MLException;
import io.skylite.ml.common.model.MLGuard;
import io.skylite.ml.common.output.model.ModelTensors;
import io.skylite.ml.common.remote.ConnectorUtils;
import io.skylite.ml.common.remote.ExecutionContext;
import io.skylite.ml.common.remote.RemoteConnectorThrottlingException;
import java.io.ByteArrayOutputStream;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import software.amazon.awssdk.http.SdkHttpResponse;
import software.amazon.awssdk.http.async.SdkAsyncHttpResponseHandler;

/**
 * Asynchronous HTTP response handler for remote model calls made through the AWS SDK
 * async HTTP client.
 *
 * <p>The handler records the HTTP status code when headers arrive, buffers the response
 * body as it is streamed, and finally reports the outcome to the supplied
 * {@link ActionListener}: either a {@code Tuple} of the execution sequence and the
 * parsed {@link ModelTensors}, or a failure.
 *
 * <p>Once a failure has been delivered to the listener, later failures and any later
 * body processing for the same exchange are suppressed, so a throttling error detected
 * in the headers is not followed by a second, generic failure for the same response.
 */
public class MLAwsSdkAsyncHttpResponseHandler
implements SdkAsyncHttpResponseHandler {
    private static final Logger log = LogManager.getLogger(MLAwsSdkAsyncHttpResponseHandler.class);
    /** Name of the response header inspected for a throttling error type. */
    public static final String ERROR_HEADER = "x-error-type";
    /** Status code taken from the response headers; {@code null} until {@link #onHeaders} runs. */
    private Integer statusCode;
    /** Decoded response body; filled when the body stream completes. */
    private final StringBuilder responseBody = new StringBuilder();
    /** Raw body bytes collected chunk by chunk; decoded in one step on completion. */
    private final ByteArrayOutputStream rawBody = new ByteArrayOutputStream();
    /** Set once a failure has been delivered to {@link #actionListener}. */
    private final AtomicBoolean failureReported = new AtomicBoolean(false);
    private final ExecutionContext executionContext;
    private final ActionListener<Tuple<Integer, ModelTensors>> actionListener;
    private final Map<String, String> parameters;
    private final Connector connector;
    private final String action;
    private final ScriptService scriptService;
    private final MLGuard mlGuard;

    /**
     * Creates a handler for a single remote call.
     *
     * @param executionContext context whose sequence is returned alongside the result
     * @param actionListener listener that receives the parsed result or the failure
     * @param parameters parameters forwarded to {@code ConnectorUtils.processOutput}
     * @param connector connector forwarded to {@code ConnectorUtils.processOutput}
     * @param scriptService script service forwarded to {@code ConnectorUtils.processOutput}
     * @param mlGuard guard forwarded to {@code ConnectorUtils.processOutput}
     * @param action name of the connector action being executed
     */
    public MLAwsSdkAsyncHttpResponseHandler(ExecutionContext executionContext, ActionListener<Tuple<Integer, ModelTensors>> actionListener, Map<String, String> parameters, Connector connector, ScriptService scriptService, MLGuard mlGuard, String action) {
        this.executionContext = executionContext;
        this.actionListener = actionListener;
        this.parameters = parameters;
        this.connector = connector;
        this.scriptService = scriptService;
        this.mlGuard = mlGuard;
        this.action = action;
    }

    /**
     * Records the status code and, for a non-2xx status, checks the headers for a
     * throttling error.
     *
     * @param response the response head received from the remote service
     */
    public void onHeaders(SdkHttpResponse response) {
        this.statusCode = response.statusCode();
        log.debug("Received response with status code: {}", this.statusCode);
        if (!isSuccess(this.statusCode)) {
            log.error("Received error from remote service with status code {}", this.statusCode);
            this.handleThrottlingInHeader(response);
        }
    }

    /**
     * Subscribes to the body stream, requesting all chunks at once, and processes the
     * complete body when the stream finishes.
     *
     * @param publisher publisher of the response body chunks
     */
    public void onStream(Publisher<ByteBuffer> publisher) {
        publisher.subscribe(new Subscriber<ByteBuffer>(){

            public void onSubscribe(Subscription subscription) {
                subscription.request(Long.MAX_VALUE);
            }

            public void onNext(ByteBuffer byteBuffer) {
                // Keep the raw bytes: a multi-byte UTF-8 sequence may be split across two
                // chunks, so decoding each chunk separately could corrupt characters.
                byte[] bytes = new byte[byteBuffer.remaining()];
                byteBuffer.get(bytes);
                MLAwsSdkAsyncHttpResponseHandler.this.rawBody.write(bytes, 0, bytes.length);
            }

            public void onError(Throwable throwable) {
                log.error("Error while reading response body", throwable);
                MLAwsSdkAsyncHttpResponseHandler.this.handleFailure(toException(throwable));
            }

            public void onComplete() {
                MLAwsSdkAsyncHttpResponseHandler outer = MLAwsSdkAsyncHttpResponseHandler.this;
                // Decode the whole body in one step now that every byte has arrived.
                outer.responseBody.append(outer.rawBody.toString(StandardCharsets.UTF_8));
                outer.processResponse(outer.responseBody.toString());
            }
        });
    }

    /**
     * Reports an error raised by the HTTP client to the listener.
     *
     * @param error the error raised while making the request or receiving the response
     */
    public void onError(Throwable error) {
        log.error("Received error from remote service: {}", error.getMessage());
        this.handleFailure(toException(error));
    }

    /**
     * @return the status code from the response headers, or {@code null} if no headers
     *         have been received yet
     */
    public Integer getStatusCode() {
        return this.statusCode;
    }

    /**
     * @return the buffer holding the decoded response body; it is filled when the body
     *         stream completes
     */
    public StringBuilder getResponseBody() {
        return this.responseBody;
    }

    /**
     * Reports a communication failure to the listener as a {@link SkyliteStatusException}.
     * The status is derived from the recorded status code, or is
     * {@code INTERNAL_SERVER_ERROR} when no status code was received. Nothing is
     * delivered if a failure has already been reported.
     *
     * @param ex the failure to report; only its message is forwarded
     */
    public void handleFailure(Exception ex) {
        log.error("Received error from remote service: {}", ex.getMessage());
        RestStatus status = this.statusCode == null ? RestStatus.INTERNAL_SERVER_ERROR : RestStatus.fromCode(this.statusCode);
        String errorMessage = "Error communicating with remote model: " + ex.getMessage();
        this.notifyFailure(new SkyliteStatusException(errorMessage, status, new Object[0]));
    }

    /**
     * Fails the listener with a {@link RemoteConnectorThrottlingException} when any
     * {@link #ERROR_HEADER} value starts with {@code "ThrottlingException"}.
     *
     * @param response the response head whose headers are inspected
     */
    protected void handleThrottlingInHeader(SdkHttpResponse response) {
        Map<String, List<String>> headers = response.headers();
        if (headers == null || headers.isEmpty()) {
            log.debug("No headers in response");
            return;
        }
        List<String> errorHeaders = headers.get(ERROR_HEADER);
        if (errorHeaders == null || errorHeaders.isEmpty()) {
            return;
        }
        boolean containsThrottlingException = errorHeaders.stream().anyMatch(value -> value.startsWith("ThrottlingException"));
        if (containsThrottlingException) {
            log.error("Remote server returned throttling error with code: {}", this.statusCode);
            this.notifyFailure(new RemoteConnectorThrottlingException("Error from remote service: The request was denied due to remote server throttling. To change the retry policy and behavior, please update the connector client_config.", RestStatus.fromCode(this.statusCode), new Object[0]));
        }
    }

    /**
     * Turns the complete response body into a listener notification.
     *
     * <p>An empty body is a failure unless the action is {@code CANCEL_BATCH_PREDICT};
     * a non-2xx status is a failure carrying the body; a cancel action succeeds with
     * tensors that hold only the status code; anything else is parsed with
     * {@code ConnectorUtils.processOutput}. Nothing happens if a failure has already
     * been reported for this exchange.
     *
     * @param body the decoded response body, possibly {@code null} or blank
     */
    protected void processResponse(String body) {
        if (this.failureReported.get()) {
            // A failure (for example throttling seen in the headers) already completed
            // the listener; do not notify it a second time for the same response.
            log.debug("Skipping response processing because a failure was already reported");
            return;
        }
        // Constant-first comparison so a null action cannot throw here.
        boolean cancelBatchPredict = ConnectorAction.ActionType.CANCEL_BATCH_PREDICT.toString().equals(this.action);
        if ((body == null || body.isBlank()) && !cancelBatchPredict) {
            log.error("Remote model response body is empty!");
            this.notifyFailure(new SkyliteStatusException("No response from model", RestStatus.BAD_REQUEST, new Object[0]));
            return;
        }
        if (this.statusCode == null) {
            // No headers were received, so there is no status to judge the body by;
            // report a failure rather than failing on unboxing null.
            log.error("Remote service response has no status code");
            this.notifyFailure(new SkyliteStatusException("Error from remote service: no status code received", RestStatus.INTERNAL_SERVER_ERROR, new Object[0]));
            return;
        }
        if (!isSuccess(this.statusCode)) {
            log.error("Remote service returned error code: {} with body: {}", this.statusCode, body);
            this.notifyFailure(new SkyliteStatusException("Error from remote service: " + body, RestStatus.fromCode(this.statusCode), new Object[0]));
            return;
        }
        if (cancelBatchPredict) {
            ModelTensors tensors = ModelTensors.builder().statusCode(this.statusCode).build();
            tensors.setStatusCode(this.statusCode);
            this.actionListener.onResponse(new Tuple<>(this.executionContext.getSequence(), tensors));
            return;
        }
        try {
            ModelTensors tensors = ConnectorUtils.processOutput(this.action, body, this.connector, this.scriptService, this.parameters, this.mlGuard);
            tensors.setStatusCode(this.statusCode);
            this.actionListener.onResponse(new Tuple<>(this.executionContext.getSequence(), tensors));
        }
        catch (Exception e) {
            log.error("Failed to process response body: {}", body);
            this.notifyFailure(new MLException("Fail to execute " + this.action + " in MLAwsSdkAsyncHttpResponseHandler", e));
        }
    }

    /** Delivers {@code failure} to the listener unless a failure was already delivered. */
    private void notifyFailure(Exception failure) {
        if (this.failureReported.compareAndSet(false, true)) {
            this.actionListener.onFailure(failure);
        } else {
            log.debug("Failure already reported to listener; dropping: {}", failure.getMessage());
        }
    }

    /** Returns {@code true} when {@code code} is in the 200-299 range. */
    private static boolean isSuccess(int code) {
        return code >= 200 && code <= 299;
    }

    /** Returns {@code error} itself when it is an {@link Exception}, else wraps it in an {@link MLException}. */
    private static Exception toException(Throwable error) {
        return error instanceof Exception ? (Exception)error : new MLException(error.getMessage(), error);
    }
}

