/*
 * Decompiled with CFR 0.152.
 */
package io.lucenia.ml.common.engine.algorithms.sparse_encoding;

import ai.djl.modality.Output;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.translate.TranslatorContext;
import io.skylite.ml.common.algorithms.SentenceTransformerTranslator;
import io.skylite.ml.common.output.model.ModelTensor;
import io.skylite.ml.common.output.model.ModelTensors;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Translator whose output step converts every array of the model's result
 * list into a map from decoded token text to that token's float value.
 *
 * <p>Decoding relies on the {@code tokenizer} field inherited from
 * {@link SentenceTransformerTranslator}.
 */
public class SparseEncodingTranslator
extends SentenceTransformerTranslator {
    /**
     * Builds one {@link ModelTensor} per array in {@code list}, bundles them
     * into a {@link ModelTensors} and appends its byte form to a new
     * {@link Output}.
     *
     * @param ctx  translator context; not read by this method
     * @param list arrays produced by the model
     * @return an {@link Output} constructed with {@code 200} and {@code "OK"}
     *         (presumably status code and message) carrying the serialized tensors
     */
    public Output processOutput(TranslatorContext ctx, NDList list) {
        Output result = new Output(200, "OK");
        List<ModelTensor> tensors = new ArrayList<>();
        for (NDArray entry : list) {
            tensors.add(this.buildResponseTensor(entry));
        }
        result.add(new ModelTensors(tensors).toBytes());
        return result;
    }

    /**
     * Wraps the token/weight map of a single array into a named tensor.
     *
     * @param entry one array from the model output
     * @return a tensor named after the array whose data map holds, under the
     *         key {@code "response"}, a one-element list with the token/weight map
     */
    private ModelTensor buildResponseTensor(NDArray entry) {
        String tensorName = entry.getName();
        Map<String, Float> weightsByToken = this.convertOutput(entry);
        Map<String, List<Map<String, Float>>> payload =
            Map.of("response", Collections.singletonList(weightsByToken));
        return ModelTensor.builder()
            .name(tensorName)
            .dataAsMap(payload)
            .build();
    }

    /**
     * Maps each non-zero position of {@code array} to its value, keyed by the
     * text the tokenizer decodes for that position.
     *
     * <p>Positions that decode to an empty string are left out. When two
     * positions decode to the same text, the later one overwrites the earlier.
     *
     * @param array array to convert; assumed to be indexable by a single
     *              index per non-zero position (TODO confirm against the model output shape)
     * @return mutable map from decoded token text to the value at that position
     */
    private Map<String, Float> convertOutput(NDArray array) {
        Map<String, Float> weightsByToken = new HashMap<>();
        NDArray nonZero = array.nonzero();
        long[] positions = nonZero.squeeze().toLongArray();
        for (int i = 0; i < positions.length; i++) {
            long position = positions[i];
            // Second argument is passed as true; presumably "skip special tokens" — verify against tokenizer API.
            String token = this.tokenizer.decode(new long[]{position}, true);
            if (!token.isEmpty()) {
                weightsByToken.put(token, array.getFloat(position));
            }
        }
        return weightsByToken;
    }
}

