/*
 * Decompiled with CFR 0.152.
 */
package io.lucenia.ml.common.model;

import io.skylite.common.action.ActionListener;
import io.skylite.core.action.ActionListenerHelper;
import io.skylite.core.action.ActionRequest;
import io.skylite.core.action.ActionResponse;
import io.skylite.core.action.ActionType;
import io.skylite.core.action.LatchedActionListener;
import io.skylite.core.client.Client;
import io.skylite.core.client.ReleasableSkyliteClient;
import io.skylite.core.common.Strings;
import io.skylite.core.common.io.stream.StreamInput;
import io.skylite.core.common.io.stream.StreamOutput;
import io.skylite.core.xcontent.NamedXContentRegistry;
import io.skylite.core.xcontent.ToXContent;
import io.skylite.core.xcontent.XContentBuilder;
import io.skylite.core.xcontent.XContentParser;
import io.skylite.core.xcontent.XContentParserUtils;
import io.skylite.ml.common.FunctionName;
import io.skylite.ml.common.dataset.MLInputDataset;
import io.skylite.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import io.skylite.ml.common.input.remote.RemoteInferenceMLInput;
import io.skylite.ml.common.model.Guardrail;
import io.skylite.ml.common.model.GuardrailFactory;
import io.skylite.ml.common.output.model.ModelTensor;
import io.skylite.ml.common.output.model.ModelTensorOutput;
import io.skylite.ml.common.output.model.ModelTensors;
import io.skylite.ml.common.transport.MLTaskResponse;
import io.skylite.ml.common.transport.prediction.MLPredictionTaskAction;
import io.skylite.ml.common.transport.prediction.MLPredictionTaskRequest;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

public class ModelGuardrail
implements Guardrail {
    private static final Logger log = LogManager.getLogger(ModelGuardrail.class);
    public static final String NAME = "model";
    public static final String MODEL_ID_FIELD = "model_id";
    public static final String RESPONSE_FILTER_FIELD = "response_filter";
    public static final String RESPONSE_VALIDATION_REGEX_FIELD = "response_validation_regex";
    private String modelId;
    private String responseFilter;
    private String responseAccept;
    private NamedXContentRegistry xContentRegistry;
    private Client client;
    private Pattern regexAcceptPattern;

    public ModelGuardrail(String modelId, String responseFilter, String responseAccept) {
        this.modelId = modelId;
        this.responseFilter = responseFilter;
        this.responseAccept = responseAccept;
    }

    public ModelGuardrail(Map<String, Object> params) {
        this((String)params.get(MODEL_ID_FIELD), (String)params.get(RESPONSE_FILTER_FIELD), (String)params.get(RESPONSE_VALIDATION_REGEX_FIELD));
    }

    public ModelGuardrail(StreamInput input) throws IOException {
        this.modelId = input.readString();
        this.responseFilter = input.readString();
        this.responseAccept = input.readString();
    }

    public String typeName() {
        return NAME;
    }

    public void writeTo(StreamOutput out) throws IOException {
        out.writeString(this.modelId);
        out.writeString(this.responseFilter);
        out.writeString(this.responseAccept);
    }

    private Boolean validateAcceptRegex(String input) {
        Matcher matcher = this.regexAcceptPattern.matcher(input);
        return matcher.matches();
    }

    public Boolean validate(String in, Map<String, String> parameters) {
        String input;
        String string = input = parameters != null ? parameters.get("question") : null;
        if (input == null || input.isEmpty()) {
            log.info("Guardrail request is empty.");
            return true;
        }
        log.info("Guardrail request: {}", (Object)input);
        AtomicBoolean isAccepted = new AtomicBoolean(true);
        ActionListener internalListener = ActionListenerHelper.wrap(predictionResponse -> {
            ModelTensorOutput output = (ModelTensorOutput)predictionResponse.getOutput();
            ModelTensor tensor = (ModelTensor)((ModelTensors)output.getMlModelOutputs().get(0)).getMlModelTensors().get(0);
            String guardrailResponse = Strings.toJson(tensor.getDataAsMap().get("response"));
            log.info("Guardrail response: {}", (Object)guardrailResponse);
            if (!this.validateAcceptRegex(guardrailResponse).booleanValue()) {
                isAccepted.set(false);
            }
        }, e -> log.error("[ModelGuardrail] Failed to get prediction response.", (Throwable)e));
        ActionListener<MLTaskResponse> actionListener = this.wrapActionListener(internalListener, res -> {
            MLTaskResponse predictionResponse = MLTaskResponse.fromActionResponse((ActionResponse)res);
            return predictionResponse;
        });
        CountDownLatch latch = new CountDownLatch(1);
        HashMap<String, String> guardrailModelParams = new HashMap<String, String>();
        guardrailModelParams.put("question", input);
        if (this.responseFilter != null && !this.responseFilter.isEmpty()) {
            guardrailModelParams.put(RESPONSE_FILTER_FIELD, this.responseFilter);
        }
        log.info("Guardrail resFilter: {}", (Object)this.responseFilter);
        MLPredictionTaskRequest request = new MLPredictionTaskRequest(this.modelId, RemoteInferenceMLInput.builder().algorithm(FunctionName.REMOTE).inputDataset((MLInputDataset)RemoteInferenceInputDataSet.builder().parameters(guardrailModelParams).build()).build());
        this.client.execute((ActionType)MLPredictionTaskAction.INSTANCE, (ActionRequest)request, (ActionListener)new LatchedActionListener(actionListener, latch));
        try {
            latch.await(5L, TimeUnit.SECONDS);
        }
        catch (InterruptedException e2) {
            log.error("[ModelGuardrail] Validation was timeout.", (Throwable)e2);
        }
        return isAccepted.get();
    }

    public void init(NamedXContentRegistry xContentRegistry, ReleasableSkyliteClient client) {
        this.xContentRegistry = xContentRegistry;
        assert (client instanceof Client);
        this.client = (Client)client;
        this.regexAcceptPattern = Pattern.compile(this.responseAccept);
    }

    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        if (this.modelId != null) {
            builder.field(MODEL_ID_FIELD, this.modelId);
        }
        if (this.responseFilter != null) {
            builder.field(RESPONSE_FILTER_FIELD, this.responseFilter);
        }
        if (this.responseAccept != null) {
            builder.field(RESPONSE_VALIDATION_REGEX_FIELD, this.responseAccept);
        }
        builder.endObject();
        return builder;
    }

    public static ModelGuardrail parse(XContentParser parser) throws IOException {
        String modelId = null;
        String responseFilter = null;
        String responseAccept = null;
        XContentParserUtils.ensureExpectedToken((XContentParser.Token)XContentParser.Token.START_OBJECT, (XContentParser.Token)parser.currentToken(), (XContentParser)parser);
        block10: while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
            String fieldName = parser.currentName();
            parser.nextToken();
            switch (fieldName) {
                case "model_id": {
                    modelId = parser.text();
                    continue block10;
                }
                case "response_filter": {
                    responseFilter = parser.text();
                    continue block10;
                }
                case "response_validation_regex": {
                    responseAccept = parser.text();
                    continue block10;
                }
            }
            parser.skipChildren();
        }
        return ModelGuardrail.builder().modelId(modelId).responseFilter(responseFilter).responseAccept(responseAccept).build();
    }

    private <T extends ActionResponse> ActionListener<T> wrapActionListener(ActionListener<T> listener, Function<ActionResponse, T> recreate) {
        ActionListener actionListener = ActionListenerHelper.wrap(r -> listener.onResponse((Object)((ActionResponse)recreate.apply((ActionResponse)r))), e -> listener.onFailure(e));
        return actionListener;
    }

    public String getModelId() {
        return this.modelId;
    }

    public String getResponseFilter() {
        return this.responseFilter;
    }

    public String getResponseAccept() {
        return this.responseAccept;
    }

    public NamedXContentRegistry getxContentRegistry() {
        return this.xContentRegistry;
    }

    public ReleasableSkyliteClient getClient() {
        return this.client;
    }

    public Pattern getRegexAcceptPattern() {
        return this.regexAcceptPattern;
    }

    public static Builder builder() {
        return new Builder();
    }

    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        ModelGuardrail that = (ModelGuardrail)o;
        return this.modelId.equals(that.modelId) && Objects.equals(this.responseFilter, that.responseFilter) && Objects.equals(this.responseAccept, that.responseAccept) && Objects.equals(this.xContentRegistry, that.xContentRegistry) && Objects.equals(this.client, that.client) && Objects.equals(this.regexAcceptPattern, that.regexAcceptPattern);
    }

    public int hashCode() {
        int result = this.modelId.hashCode();
        result = 31 * result + (this.responseFilter != null ? this.responseFilter.hashCode() : 0);
        result = 31 * result + (this.responseAccept != null ? this.responseAccept.hashCode() : 0);
        result = 31 * result + (this.xContentRegistry != null ? this.xContentRegistry.hashCode() : 0);
        result = 31 * result + (this.client != null ? this.client.hashCode() : 0);
        result = 31 * result + (this.regexAcceptPattern != null ? this.regexAcceptPattern.hashCode() : 0);
        return result;
    }

    public static class Builder {
        private String modelId;
        private String responseFilter;
        private String responseAccept;

        public Builder modelId(String modelId) {
            this.modelId = modelId;
            return this;
        }

        public Builder responseFilter(String responseFilter) {
            this.responseFilter = responseFilter;
            return this;
        }

        public Builder responseAccept(String responseAccept) {
            this.responseAccept = responseAccept;
            return this;
        }

        public ModelGuardrail build() {
            return new ModelGuardrail(this.modelId, this.responseFilter, this.responseAccept);
        }
    }

    public static class ModelGuardrailFactory
    implements GuardrailFactory {
        public String typeName() {
            return ModelGuardrail.NAME;
        }

        public Guardrail newInstance(Map<String, Object> parameters) {
            return new ModelGuardrail(parameters);
        }

        public Guardrail newInstance(StreamInput input) throws IOException {
            return null;
        }
    }
}

