/*
 * Decompiled with CFR 0.152.
 */
package io.lucenia.ml.common.engine.algorithms.question_answering;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.TranslatorContext;
import io.skylite.ml.common.algorithms.SentenceTransformerTranslator;
import io.skylite.ml.common.output.model.ModelTensor;
import io.skylite.ml.common.output.model.ModelTensors;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class QuestionAnsweringTranslator
extends SentenceTransformerTranslator {
    private List<String> tokens;

    public NDList processInput(TranslatorContext ctx, Input input) {
        NDManager manager = ctx.getNDManager();
        String question = input.getAsString(0);
        String context = input.getAsString(1);
        NDList ndList = new NDList();
        Encoding encodings = this.tokenizer.encode(question, context);
        this.tokens = Arrays.asList(encodings.getTokens());
        ctx.setAttachment("encoding", (Object)encodings);
        long[] indices = encodings.getIds();
        long[] attentionMask = encodings.getAttentionMask();
        NDArray indicesArray = manager.create(indices);
        indicesArray.setName("input_ids");
        NDArray attentionMaskArray = manager.create(attentionMask);
        attentionMaskArray.setName("attention_mask");
        ndList.add((Object)indicesArray);
        ndList.add((Object)attentionMaskArray);
        return ndList;
    }

    public Output processOutput(TranslatorContext ctx, NDList list) {
        int endIdx;
        Output output = new Output(200, "OK");
        ArrayList<ModelTensor> outputs = new ArrayList<ModelTensor>();
        NDArray startLogits = (NDArray)list.get(0);
        NDArray endLogits = (NDArray)list.get(1);
        int startIdx = (int)startLogits.argMax().getLong(new long[0]);
        if (startIdx >= (endIdx = (int)endLogits.argMax().getLong(new long[0]))) {
            int tmp = startIdx;
            startIdx = endIdx;
            endIdx = tmp;
        }
        String answer = this.tokenizer.buildSentence(this.tokens.subList(startIdx, endIdx + 1));
        outputs.add(new ModelTensor(null, answer));
        ModelTensors modelTensorOutput = new ModelTensors(outputs);
        output.add(modelTensorOutput.toBytes());
        return output;
    }
}

