/*
 * Decompiled with CFR 0.152.
 */
package io.skylite.ml.common.remote;

import io.skylite.common.TokenBucket;
import io.skylite.common.action.ActionListener;
import io.skylite.common.collect.Tuple;
import io.skylite.common.unit.TimeValue;
import io.skylite.core.action.ActionListenerHelper;
import io.skylite.core.action.GroupedActionListener;
import io.skylite.core.action.bulk.BackoffPolicy;
import io.skylite.core.client.ReleasableSkyliteClient;
import io.skylite.core.cluster.service.ClusterService;
import io.skylite.core.script.ScriptService;
import io.skylite.core.xcontent.NamedXContentRegistry;
import io.skylite.ml.common.FunctionName;
import io.skylite.ml.common.connector.Connector;
import io.skylite.ml.common.connector.ConnectorAction;
import io.skylite.ml.common.connector.ConnectorClientConfig;
import io.skylite.ml.common.connector.MLPreProcessFunction;
import io.skylite.ml.common.dataset.TextDocsInputDataSet;
import io.skylite.ml.common.input.MLInput;
import io.skylite.ml.common.model.MLGuard;
import io.skylite.ml.common.output.model.ModelTensorOutput;
import io.skylite.ml.common.output.model.ModelTensors;
import io.skylite.ml.common.remote.ExecutionContext;
import io.skylite.ml.common.transport.MLTaskResponse;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.logging.log4j.Logger;

/**
 * Executes connector actions against a remote model.
 *
 * <p>{@link #executeAction} is the entry point. Text-document inputs are split into chunks,
 * each chunk is dispatched through {@link #preparePayloadAndInvoke}, and the per-chunk
 * {@link ModelTensors} are reassembled in submission order into a single {@link MLTaskResponse}.
 * Implementations supply the connector, rate limiters, guard, logger and client configuration
 * through the accessor methods, and must implement {@link #invokeRemoteService}, the only
 * abstract invocation method declared here.
 */
public interface RemoteConnectorExecutor {
    /** Parameter key {@code skip_validating_missing_parameters}; not read within this interface. */
    String SKIP_VALIDATE_MISSING_PARAMETERS = "skip_validating_missing_parameters";

    /**
     * Executor name {@code lucenia_ml_predict_remote}.
     *
     * <p>NOTE(review): presumably the thread pool used for retried remote calls. It is not
     * referenced within this interface, so confirm against the implementations.
     */
    String RETRY_EXECUTOR = "lucenia_ml_predict_remote";

    /**
     * Runs {@code action} for {@code mlInput} and reports the combined result to
     * {@code actionListener}.
     *
     * <p>For a {@link TextDocsInputDataSet}, the documents are split according to
     * {@link #calculateChunkSize}. Each chunk is sent as its own {@code TEXT_EMBEDDING} input,
     * with an {@link ExecutionContext} that carries the chunk's sequence number (0, 1, 2, ...).
     * Any other dataset is sent as one request with sequence number 0.
     *
     * <p>Per-request results are collected by a {@link GroupedActionListener} sized to the number
     * of requests. The final {@link ModelTensorOutput} lists the tensors ordered by sequence
     * number. An empty document list completes immediately with an empty output; without this
     * check no request would ever be dispatched and the listener would never be notified.
     *
     * <p>Exceptions thrown synchronously while planning or dispatching are passed to
     * {@code actionListener.onFailure} and are not rethrown.
     *
     * @param action         name of the connector action to run
     * @param mlInput        model input to send
     * @param actionListener notified with the assembled response or with the failure
     */
    default void executeAction(String action, MLInput mlInput, ActionListener<MLTaskResponse> actionListener) {
        // Receives every (sequence number, tensors) pair and stores each result at its sequence
        // index, which restores the order in which the requests were submitted.
        ActionListener<Collection<Tuple<Integer, ModelTensors>>> tensorActionListener = ActionListenerHelper.wrap(results -> {
            ModelTensors[] modelTensors = new ModelTensors[results.size()];
            for (Tuple<Integer, ModelTensors> sequenceNoAndModelTensors : results) {
                modelTensors[sequenceNoAndModelTensors.v1()] = sequenceNoAndModelTensors.v2();
            }
            actionListener.onResponse(new MLTaskResponse(new ModelTensorOutput(Arrays.asList(modelTensors))));
        }, actionListener::onFailure);
        try {
            if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) {
                TextDocsInputDataSet textDocsInputDataSet = (TextDocsInputDataSet) mlInput.getInputDataset();
                List<String> docs = textDocsInputDataSet.getDocs();
                // v1 = number of requests, v2 = maximum documents per request.
                // Computed first so connector misconfiguration is still reported for empty input.
                Tuple<Integer, Integer> calculatedChunkSize = calculateChunkSize(action, textDocsInputDataSet);
                if (docs.isEmpty()) {
                    // Nothing would be dispatched, so a grouped listener would never complete.
                    tensorActionListener.onResponse(List.of());
                    return;
                }
                int requestCount = calculatedChunkSize.v1();
                int docsPerRequest = calculatedChunkSize.v2();
                GroupedActionListener<Tuple<Integer, ModelTensors>> groupedActionListener =
                    new GroupedActionListener<>(tensorActionListener, requestCount);
                int sequence = 0;
                for (int processedDocs = 0; processedDocs < docs.size(); processedDocs += docsPerRequest) {
                    List<String> textDocs = docs.subList(processedDocs, Math.min(processedDocs + docsPerRequest, docs.size()));
                    MLInput chunkInput = MLInput.builder()
                        .algorithm(FunctionName.TEXT_EMBEDDING)
                        .inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build())
                        .build();
                    preparePayloadAndInvoke(action, chunkInput, new ExecutionContext(sequence++), groupedActionListener);
                }
            } else {
                preparePayloadAndInvoke(action, mlInput, new ExecutionContext(0), new GroupedActionListener<>(tensorActionListener, 1));
            }
        } catch (Exception e) {
            actionListener.onFailure(e);
        }
    }

    /**
     * Decides how the documents of {@code textDocsInputDataSet} are split into requests.
     *
     * <p>The rules are applied in this order:
     * <ol>
     *   <li>If the connector parameters contain {@code input_docs_processed_step_size}, that
     *       value is the number of documents per request, and the request count is
     *       {@code ceil(docs / stepSize)}.</li>
     *   <li>Otherwise the connector must define {@code action}. If the action has no pre-process
     *       function, all documents go in one request.</li>
     *   <li>If the pre-process function is {@code connector.pre_process.bedrock.embedding}, or is
     *       not known to {@link MLPreProcessFunction}, every document gets its own request.</li>
     *   <li>Any other pre-process function sends all documents in one request.</li>
     * </ol>
     *
     * @param action              connector action whose pre-process function is consulted
     * @param textDocsInputDataSet documents to split
     * @return a tuple of (number of requests, maximum documents per request)
     * @throws IllegalArgumentException if the step size is not a positive integer, or if the
     *                                  connector has no {@code action}
     */
    private Tuple<Integer, Integer> calculateChunkSize(String action, TextDocsInputDataSet textDocsInputDataSet) {
        int textDocsLength = textDocsInputDataSet.getDocs().size();
        Map<String, String> parameters = getConnector().getParameters();
        if (parameters != null && parameters.containsKey("input_docs_processed_step_size")) {
            int stepSize;
            try {
                stepSize = Integer.parseInt(parameters.get("input_docs_processed_step_size"));
            } catch (NumberFormatException e) {
                // Same message as the non-positive case; keep the parse error as the cause.
                throw new IllegalArgumentException("Invalid parameter: input_docs_processed_step_size. It must be positive integer.", e);
            }
            if (stepSize <= 0) {
                throw new IllegalArgumentException("Invalid parameter: input_docs_processed_step_size. It must be positive integer.");
            }
            int requestCount = textDocsLength / stepSize;
            if (textDocsLength % stepSize != 0) {
                requestCount++; // trailing partial chunk
            }
            return Tuple.tuple(requestCount, stepSize);
        }
        Optional<ConnectorAction> connectorAction = getConnector().findAction(action);
        if (connectorAction.isEmpty()) {
            throw new IllegalArgumentException("no " + action + " action found");
        }
        String preProcessFunction = connectorAction.get().getPreProcessFunction();
        if (preProcessFunction == null) {
            return Tuple.tuple(1, textDocsLength);
        }
        if ("connector.pre_process.bedrock.embedding".equals(preProcessFunction) || !MLPreProcessFunction.contains(preProcessFunction)) {
            // One document per request.
            // NOTE(review): presumably these pre-process functions only handle a single
            // document; confirm the reason.
            return Tuple.tuple(textDocsLength, 1);
        }
        return Tuple.tuple(1, textDocsLength);
    }

    /** Sets the script service. The default implementation ignores the value. */
    default void setScriptService(ScriptService scriptService) {
    }

    /** @return the script service used by this executor */
    ScriptService getScriptService();

    /** @return the connector whose parameters and actions drive execution */
    Connector getConnector();

    /** @return the executor-wide rate limiter */
    TokenBucket getRateLimiter();

    /** @return per-user rate limiters; NOTE(review): keys presumably identify users, confirm */
    Map<String, TokenBucket> getUserRateLimiterMap();

    /** @return the ML guard associated with this executor */
    MLGuard getMlGuard();

    /** @return the logger used by this executor */
    Logger getLogger();

    /** @return the connector client configuration (retry settings, among others) */
    ConnectorClientConfig getConnectorClientConfig();

    /** Sets the client. The default implementation ignores the value. */
    default void setClient(ReleasableSkyliteClient client) {
    }

    /** Sets the private-IP flag holder. The default implementation ignores the value. */
    default void setConnectorPrivateIpEnabled(AtomicBoolean connectorPrivateIpEnabled) {
    }

    /** Sets the x-content registry. The default implementation ignores the value. */
    default void setXContentRegistry(NamedXContentRegistry xContentRegistry) {
    }

    /** Sets the cluster service. The default implementation ignores the value. */
    default void setClusterService(ClusterService clusterService) {
    }

    /** Sets the executor-wide rate limiter. The default implementation ignores the value. */
    default void setRateLimiter(TokenBucket rateLimiter) {
    }

    /** Sets the per-user rate limiters. The default implementation ignores the value. */
    default void setUserRateLimiterMap(Map<String, TokenBucket> userRateLimiterMap) {
    }

    /** Sets the ML guard. The default implementation ignores the value. */
    default void setMlGuard(MLGuard mlGuard) {
    }

    /**
     * Builds the request payload for {@code mlInput} and invokes the remote model.
     *
     * <p>The default implementation does nothing and never notifies {@code actionListener}.
     * Because {@link #executeAction} relies on this method to complete its grouped listener,
     * implementations must override it.
     *
     * @param action           connector action name
     * @param mlInput          input for this request (one chunk when called from executeAction)
     * @param executionContext carries the request's sequence number
     * @param actionListener   expected to receive (sequence number, tensors) for this request
     */
    default void preparePayloadAndInvoke(String action, MLInput mlInput, ExecutionContext executionContext, ActionListener<Tuple<Integer, ModelTensors>> actionListener) {
    }

    /**
     * Builds the retry backoff policy described by {@code connectorClientConfig}.
     *
     * <ul>
     *   <li>{@code EXPONENTIAL_EQUAL_JITTER}: exponential equal-jitter backoff from the base
     *       delay, given the retry timeout converted to milliseconds.</li>
     *   <li>{@code EXPONENTIAL_FULL_JITTER}: exponential full-jitter backoff from the base
     *       delay.</li>
     *   <li>Any other policy: constant backoff of the base delay, with {@link Integer#MAX_VALUE}
     *       retries.</li>
     * </ul>
     *
     * @param connectorClientConfig source of the policy, base delay (ms) and timeout (s)
     * @return the backoff policy to use for retries
     */
    default BackoffPolicy getRetryBackoffPolicy(ConnectorClientConfig connectorClientConfig) {
        switch (connectorClientConfig.getRetryBackoffPolicy()) {
            case EXPONENTIAL_EQUAL_JITTER:
                // Multiply in long arithmetic: seconds * 1000 overflows int above ~24.8 days.
                return BackoffPolicy.exponentialEqualJitterBackoff(
                    connectorClientConfig.getRetryBackoffMillis().longValue(),
                    connectorClientConfig.getRetryTimeoutSeconds() * 1000L
                );
            case EXPONENTIAL_FULL_JITTER:
                return BackoffPolicy.exponentialFullJitterBackoff(connectorClientConfig.getRetryBackoffMillis().longValue());
            default:
                return BackoffPolicy.constantBackoff(
                    TimeValue.timeValueMillis(connectorClientConfig.getRetryBackoffMillis().longValue()),
                    Integer.MAX_VALUE
                );
        }
    }

    /**
     * Invokes the remote service with retries.
     *
     * <p>The default implementation does nothing and never notifies {@code actionListener}.
     * NOTE(review): this looks like a stub; implementations that need retries must override it.
     */
    default void invokeRemoteServiceWithRetry(String action, MLInput mlInput, Map<String, String> parameters, String payload, ExecutionContext executionContext, ActionListener<Tuple<Integer, ModelTensors>> actionListener) {
    }

    /**
     * Performs a single remote invocation.
     *
     * <p>NOTE(review): presumably implementations send {@code payload} to the connector endpoint
     * and report (sequence number, tensors) to {@code actionListener}; confirm against the
     * implementations.
     *
     * @param action           connector action name
     * @param mlInput          input for this request
     * @param parameters       resolved request parameters
     * @param payload          request body
     * @param executionContext carries the request's sequence number
     * @param actionListener   receives the result or the failure
     */
    void invokeRemoteService(String action, MLInput mlInput, Map<String, String> parameters, String payload, ExecutionContext executionContext, ActionListener<Tuple<Integer, ModelTensors>> actionListener);
}

