/*
 * Decompiled with CFR 0.152.
 */
package io.lucenia.ml.common.remote;

import io.skylite.SkyliteExceptionsHelper;
import io.skylite.SkyliteStatusException;
import io.skylite.common.TokenBucket;
import io.skylite.common.action.ActionListener;
import io.skylite.common.collect.Tuple;
import io.skylite.common.unit.TimeValue;
import io.skylite.core.action.RetryableAction;
import io.skylite.core.client.Client;
import io.skylite.core.client.ReleasableSkyliteClient;
import io.skylite.core.rest.RestStatus;
import io.skylite.core.script.ScriptService;
import io.skylite.core.security.auth.User;
import io.skylite.ml.common.connector.Connector;
import io.skylite.ml.common.connector.ConnectorClientConfig;
import io.skylite.ml.common.dataset.MLInputDataset;
import io.skylite.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import io.skylite.ml.common.input.MLInput;
import io.skylite.ml.common.model.MLGuard;
import io.skylite.ml.common.output.model.ModelTensors;
import io.skylite.ml.common.remote.ConnectorUtils;
import io.skylite.ml.common.remote.ExecutionContext;
import io.skylite.ml.common.remote.RemoteConnectorExecutor;
import io.skylite.ml.common.remote.RemoteConnectorThrottlingException;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;

/**
 * Base {@link RemoteConnectorExecutor} that holds the client and the connector's client
 * configuration, and implements the shared flow used to call a remote model:
 * resolve parameters, build and validate the payload, enforce rate limits and the input
 * guardrail, then invoke the remote service (optionally with retries).
 *
 * <p>Methods such as {@code getConnector()}, {@code getLogger()}, {@code invokeRemoteService(...)},
 * {@code getRateLimiter()}, {@code getUserRateLimiterMap()}, {@code getMlGuard()},
 * {@code getScriptService()} and {@code getRetryBackoffPolicy(...)} are used here but not defined
 * in this class; they come from {@link RemoteConnectorExecutor} or a concrete subclass.
 */
public abstract class AbstractConnectorExecutor
implements RemoteConnectorExecutor {
    /** Executor name passed to {@link RetryableAction} for scheduling retries. */
    private static final String RETRY_EXECUTOR = "lucenia_ml_predict_remote";
    /** Thread-context transient key read to obtain the serialized calling user. */
    private static final String USER_INFO_TRANSIENT_KEY = "_opendistro_security_user_info";
    /** Parameter that, when {@code "true"}, disables {@code Connector#validatePayload}. */
    private static final String SKIP_VALIDATING_MISSING_PARAMETERS = "skip_validating_missing_parameters";

    protected Client client;
    private ConnectorClientConfig connectorClientConfig;

    /**
     * Captures the connector's client configuration, falling back to a default-constructed
     * {@link ConnectorClientConfig} when the connector does not define one.
     *
     * @param connector the connector this executor serves
     */
    public void initialize(Connector connector) {
        ConnectorClientConfig configured = connector.getConnectorClientConfig();
        this.connectorClientConfig = configured != null ? configured : new ConnectorClientConfig();
    }

    public Client getClient() {
        return this.client;
    }

    public void setClient(ReleasableSkyliteClient client) {
        this.client = (Client)client;
    }

    public ConnectorClientConfig getConnectorClientConfig() {
        return this.connectorClientConfig;
    }

    public void setConnectorClientConfig(ConnectorClientConfig connectorClientConfig) {
        this.connectorClientConfig = connectorClientConfig;
    }

    /**
     * Invokes the remote service through a {@link RetryableAction}.
     *
     * <p>A failure is retried only when its unwrapped cause is a
     * {@link RemoteConnectorThrottlingException}. The number of retries is capped by
     * {@code getMaxRetryTimes()}, where {@code -1} removes the count limit. Backoff, overall
     * timeout and backoff policy all come from the connector client configuration.
     *
     * @param action           connector action name
     * @param mlInput          the original ML input
     * @param parameters       fully resolved request parameters
     * @param payload          the request payload to send
     * @param executionContext per-request execution context
     * @param actionListener   receives the final result, or the last failure once retries stop
     */
    public void invokeRemoteServiceWithRetry(final String action, final MLInput mlInput, final Map<String, String> parameters, final String payload, final ExecutionContext executionContext, ActionListener<Tuple<Integer, ModelTensors>> actionListener) {
        final ConnectorClientConfig config = this.getConnectorClientConfig();
        RetryableAction<Tuple<Integer, ModelTensors>> invokeRemoteModelAction = new RetryableAction<Tuple<Integer, ModelTensors>>(
            this.getLogger(),
            this.getClient().threadPool(),
            TimeValue.timeValueMillis(config.getRetryBackoffMillis().longValue()),
            TimeValue.timeValueSeconds(config.getRetryTimeoutSeconds().longValue()),
            actionListener,
            this.getRetryBackoffPolicy(config),
            RETRY_EXECUTOR
        ) {
            // Number of times shouldRetry has been consulted; used both as the retry counter and in the log message.
            private int retryTimes = 0;

            @Override
            public void tryAction(ActionListener<Tuple<Integer, ModelTensors>> listener) {
                AbstractConnectorExecutor.this.invokeRemoteService(action, mlInput, parameters, payload, executionContext, listener);
            }

            @Override
            public boolean shouldRetry(Exception e) {
                Throwable cause = SkyliteExceptionsHelper.unwrapCause(e);
                Integer maxRetryTimes = AbstractConnectorExecutor.this.getConnectorClientConfig().getMaxRetryTimes();
                boolean shouldRetry = cause instanceof RemoteConnectorThrottlingException;
                // Always count the attempt, even when the count limit is disabled, so the log stays accurate.
                this.retryTimes++;
                if (maxRetryTimes != -1 && this.retryTimes > maxRetryTimes) {
                    shouldRetry = false;
                }
                if (shouldRetry) {
                    AbstractConnectorExecutor.this.getLogger().debug(String.format(Locale.ROOT, "The %d-th retry for invoke remote model", this.retryTimes), e);
                }
                return shouldRetry;
            }
        };
        invokeRemoteModelAction.run();
    }

    /**
     * Builds the request for {@code action} and sends it to the remote service.
     *
     * @param action           connector action name
     * @param mlInput          the ML input to send
     * @param executionContext per-request execution context
     * @param actionListener   receives the remote result or failure
     * @throws SkyliteStatusException with {@code TOO_MANY_REQUESTS} when the model-level or
     *                                user-level rate limiter rejects the request
     * @throws IllegalArgumentException when the ML guard rejects the input payload
     */
    public void preparePayloadAndInvoke(String action, MLInput mlInput, ExecutionContext executionContext, ActionListener<Tuple<Integer, ModelTensors>> actionListener) {
        Connector connector = this.getConnector();
        Map<String, String> parameters = this.resolveParameters(action, mlInput, connector);
        String payload = (String)connector.createPayload(action, parameters);
        if (!Boolean.parseBoolean(parameters.getOrDefault(SKIP_VALIDATING_MISSING_PARAMETERS, "false"))) {
            connector.validatePayload(payload);
        }
        this.enforceRateLimits();
        if (this.getMlGuard() != null && !this.getMlGuard().validate(payload, MLGuard.Type.INPUT, parameters)) {
            this.getLogger().error("guardrails triggered for user input");
            throw new IllegalArgumentException("guardrails triggered for user input");
        }
        if (this.getConnectorClientConfig().getMaxRetryTimes() != 0) {
            this.invokeRemoteServiceWithRetry(action, mlInput, parameters, payload, executionContext, actionListener);
        } else {
            this.invokeRemoteService(action, mlInput, parameters, payload, executionContext, actionListener);
        }
    }

    /**
     * Merges connector defaults, caller-supplied input parameters and processed-input parameters.
     * Caller-supplied input parameters always win.
     */
    private Map<String, String> resolveParameters(String action, MLInput mlInput, Connector connector) {
        Map<String, String> parameters = new HashMap<String, String>();
        if (connector.getParameters() != null) {
            parameters.putAll(connector.getParameters());
        }
        Map<String, String> inputParameters = new HashMap<String, String>();
        MLInputDataset inputDataset = mlInput.getInputDataset();
        if (inputDataset instanceof RemoteInferenceInputDataSet) {
            RemoteInferenceInputDataSet remoteDataset = (RemoteInferenceInputDataSet)inputDataset;
            if (remoteDataset.getParameters() != null) {
                ConnectorUtils.escapeRemoteInferenceInputData(remoteDataset);
                // Read the parameters again after escaping, matching the original call order.
                inputParameters.putAll(remoteDataset.getParameters());
            }
        }
        parameters.putAll(inputParameters);
        RemoteInferenceInputDataSet processedInput = ConnectorUtils.processInput(action, mlInput, connector, parameters, this.getScriptService());
        if (processedInput.getParameters() != null) {
            parameters.putAll(processedInput.getParameters());
        }
        // Re-apply the input parameters so they override anything processInput produced.
        parameters.putAll(inputParameters);
        return parameters;
    }

    /** Rejects the request when either the model-level or the calling user's token bucket is exhausted. */
    private void enforceRateLimits() {
        String userStr = (String)this.getClient().threadPool().getThreadContext().getTransient(USER_INFO_TRANSIENT_KEY);
        User user = User.parse(userStr);
        if (this.getRateLimiter() != null && !this.getRateLimiter().request()) {
            this.getLogger().error("Request is throttled at model level.");
            throw new SkyliteStatusException("Request is throttled at model level.", RestStatus.TOO_MANY_REQUESTS);
        }
        if (user != null && this.getUserRateLimiterMap() != null) {
            TokenBucket userRateLimiter = (TokenBucket)this.getUserRateLimiterMap().get(user.getName());
            if (userRateLimiter != null && !userRateLimiter.request()) {
                this.getLogger().error("Request is throttled at user level.");
                throw new SkyliteStatusException("Request is throttled at user level. If you think there's an issue, please contact your cluster admin.", RestStatus.TOO_MANY_REQUESTS);
            }
        }
    }
}

