/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.algorithms.text_embedding;

import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorFactory;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.model.BaseModelConfig;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
import org.opensearch.ml.engine.algorithms.TextEmbeddingModel;
import org.opensearch.ml.engine.algorithms.text_embedding.HuggingfaceTextEmbeddingTranslatorFactory;
import org.opensearch.ml.engine.algorithms.text_embedding.ONNXSentenceTransformerTextEmbeddingTranslator;
import org.opensearch.ml.engine.algorithms.text_embedding.SentenceTransformerTextEmbeddingTranslator;
import org.opensearch.ml.engine.annotation.Function;

@Function(value=FunctionName.TEXT_EMBEDDING)
/**
 * Text-embedding model registered under {@link FunctionName#TEXT_EMBEDDING}.
 *
 * <p>This class only decides which DJL {@link Translator} or {@link TranslatorFactory} to use. The choice is
 * made from the engine name and the {@link TextEmbeddingModelConfig} of the model being loaded.
 */
public class TextEmbeddingDenseModel
extends TextEmbeddingModel {
    @Generated
    private static final Logger log = LogManager.getLogger(TextEmbeddingDenseModel.class);
    // NOTE(review): not referenced in this class; presumably the name of the sentence-embedding output — confirm with callers.
    public static final String SENTENCE_EMBEDDING = "sentence_embedding";

    /**
     * Returns the DJL translator to use for this model, or {@code null} when none applies.
     *
     * <p>For the {@code "OnnxRuntime"} engine, an {@link ONNXSentenceTransformerTextEmbeddingTranslator} is returned
     * whatever the framework type. It is built from the config's pooling mode, normalization flag and model type.
     * Otherwise, a {@link SentenceTransformerTextEmbeddingTranslator} is returned for sentence-transformers models.
     *
     * @param engine the DJL engine name (compared case-sensitively; may be {@code null})
     * @param modelConfig the model configuration; must be a {@link TextEmbeddingModelConfig}
     * @return the translator, or {@code null} if this engine/framework combination uses no translator
     * @throws ClassCastException if {@code modelConfig} is not a {@link TextEmbeddingModelConfig}
     */
    @Override
    public Translator<Input, Output> getTranslator(String engine, MLModelConfig modelConfig) {
        TextEmbeddingModelConfig config = asTextEmbeddingConfig(modelConfig);
        BaseModelConfig.FrameworkType frameworkType = config.getFrameworkType();
        if ("OnnxRuntime".equals(engine)) {
            return new ONNXSentenceTransformerTextEmbeddingTranslator(
                config.getPoolingMode(), config.isNormalizeResult(), config.getModelType());
        }
        if (frameworkType == BaseModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) {
            return new SentenceTransformerTextEmbeddingTranslator();
        }
        return null;
    }

    /**
     * Returns the DJL translator factory to use for this model, or {@code null} when none applies.
     *
     * <p>A {@link HuggingfaceTextEmbeddingTranslatorFactory} is returned only for the {@code "PyTorch"} engine, and
     * only when the framework type is not {@code SENTENCE_TRANSFORMERS}. The factory's Neuron flag is set when the
     * framework type's enum constant name ends with {@code "_NEURON"}.
     *
     * @param engine the DJL engine name (compared case-sensitively; may be {@code null})
     * @param modelConfig the model configuration; must be a {@link TextEmbeddingModelConfig}
     * @return the translator factory, or {@code null} if this engine/framework combination uses no factory
     * @throws ClassCastException if {@code modelConfig} is not a {@link TextEmbeddingModelConfig}
     */
    @Override
    public TranslatorFactory getTranslatorFactory(String engine, MLModelConfig modelConfig) {
        TextEmbeddingModelConfig config = asTextEmbeddingConfig(modelConfig);
        BaseModelConfig.FrameworkType frameworkType = config.getFrameworkType();
        if (!"PyTorch".equals(engine) || frameworkType == BaseModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) {
            return null;
        }
        // A null framework type passes the guard above. Treat it as non-Neuron instead of dereferencing it (NPE).
        boolean neuron = frameworkType != null && frameworkType.name().endsWith("_NEURON");
        return new HuggingfaceTextEmbeddingTranslatorFactory(
            config.getPoolingMode(), config.isNormalizeResult(), config.getModelType(), neuron);
    }

    /**
     * Narrows the generic model config to the text-embedding config that both lookup methods require.
     *
     * @throws ClassCastException if {@code modelConfig} is not a {@link TextEmbeddingModelConfig}
     */
    private static TextEmbeddingModelConfig asTextEmbeddingConfig(MLModelConfig modelConfig) {
        return (TextEmbeddingModelConfig) modelConfig;
    }
}

