/*
 * Decompiled with CFR 0.152.
 */
package io.lucenia.ml.common.search.pipelines.generative.client;

import io.skylite.common.action.ActionFuture;
import io.skylite.common.action.ActionListener;
import io.skylite.core.action.ActionListenerHelper;
import io.skylite.core.action.ActionRequest;
import io.skylite.core.action.ActionResponse;
import io.skylite.core.action.ActionType;
import io.skylite.core.action.support.PlainActionFuture;
import io.skylite.core.client.Client;
import io.skylite.ml.common.input.MLInput;
import io.skylite.ml.common.output.MLOutput;
import io.skylite.ml.common.transport.MLTaskResponse;
import io.skylite.ml.common.transport.prediction.MLPredictionTaskAction;
import io.skylite.ml.common.transport.prediction.MLPredictionTaskRequest;
import java.util.function.Function;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

public class MachineLearningInternalClient {
    private static final Logger log = LogManager.getLogger(MachineLearningInternalClient.class);
    private final Client client;

    public MachineLearningInternalClient(Client client) {
        this.client = client;
    }

    public ActionFuture<MLOutput> predict(String modelId, MLInput mlInput) {
        PlainActionFuture actionFuture = PlainActionFuture.newFuture();
        this.predict(modelId, mlInput, (ActionListener<MLOutput>)actionFuture);
        return actionFuture;
    }

    public void predict(String modelId, MLInput mlInput, ActionListener<MLOutput> listener) {
        this.validateMLInput(mlInput, true);
        MLPredictionTaskRequest predictionRequest = MLPredictionTaskRequest.builder().mlInput(mlInput).modelId(modelId).dispatchTask(true).build();
        this.client.execute((ActionType)MLPredictionTaskAction.INSTANCE, (ActionRequest)predictionRequest, this.getMlPredictionTaskResponseActionListener(listener));
    }

    private ActionListener<MLTaskResponse> getMlPredictionTaskResponseActionListener(ActionListener<MLOutput> listener) {
        ActionListener internalListener = ActionListenerHelper.wrap(predictionResponse -> listener.onResponse((Object)predictionResponse.getOutput()), arg_0 -> listener.onFailure(arg_0));
        ActionListener<MLTaskResponse> actionListener = this.wrapActionListener(internalListener, res -> {
            MLTaskResponse predictionResponse = MLTaskResponse.fromActionResponse((ActionResponse)res);
            return predictionResponse;
        });
        return actionListener;
    }

    private <T extends ActionResponse> ActionListener<T> wrapActionListener(ActionListener<T> listener, Function<ActionResponse, T> recreate) {
        ActionListener actionListener = ActionListenerHelper.wrap(r -> listener.onResponse((Object)((ActionResponse)recreate.apply((ActionResponse)r))), e -> listener.onFailure(e));
        return actionListener;
    }

    private void validateMLInput(MLInput mlInput, boolean requireInput) {
        if (mlInput == null) {
            throw new IllegalArgumentException("ML Input can't be null");
        }
        if (requireInput && mlInput.getInputDataset() == null) {
            throw new IllegalArgumentException("input data set can't be null");
        }
    }
}

