/*
 * Decompiled with CFR 0.152.
 */
package io.lucenia.ml.common.engine.algorithms.text_embedding;

import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorFactory;
import io.lucenia.ml.common.engine.algorithms.text_embedding.HuggingfaceTextEmbeddingTranslatorFactory;
import io.lucenia.ml.common.engine.algorithms.text_embedding.ONNXSentenceTransformerTextEmbeddingTranslator;
import io.lucenia.ml.common.engine.algorithms.text_embedding.SentenceTransformerTextEmbeddingTranslator;
import io.skylite.ml.common.FunctionName;
import io.skylite.ml.common.algorithms.TextEmbeddingModel;
import io.skylite.ml.common.annotation.Function;
import io.skylite.ml.common.model.MLModelConfig;
import io.skylite.ml.common.model.TextEmbeddingModelConfig;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Dense text-embedding model that picks a DJL {@link Translator} or {@link TranslatorFactory}
 * based on the inference engine name and the model's {@link TextEmbeddingModelConfig}.
 */
@Function(value=FunctionName.TEXT_EMBEDDING)
public class TextEmbeddingDenseModel
extends TextEmbeddingModel {
    // Logger is named after this class (the decompiled original used the parent class).
    private static final Logger log = LogManager.getLogger(TextEmbeddingDenseModel.class);
    public static final String SENTENCE_EMBEDDING = "sentence_embedding";

    /** Engine name for which {@link #getTranslator} returns the ONNX translator. */
    private static final String ONNX_RUNTIME_ENGINE = "OnnxRuntime";
    /** Engine name for which {@link #getTranslatorFactory} may return a Hugging Face factory. */
    private static final String PYTORCH_ENGINE = "PyTorch";
    /** Framework-type name suffix for which the factory is built with {@code neuron = true}. */
    private static final String NEURON_SUFFIX = "_NEURON";

    /**
     * Returns a translator for the given engine and model configuration.
     *
     * @param engine the inference engine name; {@code "OnnxRuntime"} selects the ONNX translator
     * @param modelConfig the model configuration; expected to be a {@link TextEmbeddingModelConfig}
     * @return an {@link ONNXSentenceTransformerTextEmbeddingTranslator} when {@code engine} is
     *     {@code "OnnxRuntime"}; otherwise a {@link SentenceTransformerTextEmbeddingTranslator} when
     *     the framework type is {@code SENTENCE_TRANSFORMERS}; otherwise {@code null}
     * @throws ClassCastException if {@code modelConfig} is not a {@link TextEmbeddingModelConfig}
     */
    public Translator<Input, Output> getTranslator(String engine, MLModelConfig modelConfig) {
        TextEmbeddingModelConfig textEmbeddingModelConfig = (TextEmbeddingModelConfig) modelConfig;
        TextEmbeddingModelConfig.FrameworkType transformersType = textEmbeddingModelConfig.getFrameworkType();
        if (ONNX_RUNTIME_ENGINE.equals(engine)) {
            // The ONNX engine uses the same translator regardless of framework type.
            return new ONNXSentenceTransformerTextEmbeddingTranslator(
                    textEmbeddingModelConfig.getPoolingMode(),
                    textEmbeddingModelConfig.isNormalizeResult(),
                    textEmbeddingModelConfig.getModelType());
        }
        if (transformersType == TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) {
            return new SentenceTransformerTextEmbeddingTranslator();
        }
        return null;
    }

    /**
     * Returns a translator factory for the given engine and model configuration.
     *
     * <p>A {@link HuggingfaceTextEmbeddingTranslatorFactory} is returned only for the
     * {@code "PyTorch"} engine with a framework type other than {@code SENTENCE_TRANSFORMERS}. The
     * factory's {@code neuron} flag is set when the framework type's enum name ends in
     * {@code "_NEURON"}. A {@code null} framework type is treated as non-neuron instead of throwing.
     *
     * @param engine the inference engine name
     * @param modelConfig the model configuration; expected to be a {@link TextEmbeddingModelConfig}
     * @return the Hugging Face translator factory, or {@code null} when none applies
     * @throws ClassCastException if {@code modelConfig} is not a {@link TextEmbeddingModelConfig}
     */
    public TranslatorFactory getTranslatorFactory(String engine, MLModelConfig modelConfig) {
        TextEmbeddingModelConfig textEmbeddingModelConfig = (TextEmbeddingModelConfig) modelConfig;
        TextEmbeddingModelConfig.FrameworkType transformersType = textEmbeddingModelConfig.getFrameworkType();
        if (!PYTORCH_ENGINE.equals(engine)
                || transformersType == TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) {
            return null;
        }
        // A null framework type reaches this point; guard it so name() cannot throw NPE.
        boolean neuron = transformersType != null && transformersType.name().endsWith(NEURON_SUFFIX);
        return new HuggingfaceTextEmbeddingTranslatorFactory(
                textEmbeddingModelConfig.getPoolingMode(),
                textEmbeddingModelConfig.isNormalizeResult(),
                textEmbeddingModelConfig.getModelType(),
                neuron);
    }
}

