/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.index.engine;

import io.skylite.Version;
import io.skylite.common.TriFunction;
import io.skylite.common.ValidationException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNLibraryIndexingContext;
import org.opensearch.knn.index.engine.KNNLibraryIndexingContextImpl;
import org.opensearch.knn.index.engine.KNNMethodConfigContext;
import org.opensearch.knn.index.engine.MethodComponentContext;
import org.opensearch.knn.index.engine.Parameter;
import org.opensearch.knn.index.engine.validation.ParameterValidator;
import org.opensearch.knn.index.mapper.CompressionLevel;
import org.opensearch.knn.index.mapper.Mode;
import org.opensearch.knn.index.util.IndexHyperParametersUtil;

/**
 * Describes one configurable component of a k-NN method (a top-level method or a nested component supplied
 * as a parameter value): its name, the parameters it accepts, the vector data types it supports, whether it
 * requires training, and optional hooks for building an indexing context and estimating memory overhead.
 *
 * <p>Instances are created through {@link Builder}.
 */
public class MethodComponent {
    /** Component name as supplied to {@link Builder#builder(String)}. */
    private final String name;
    /** Parameters this component accepts, keyed by the name they were registered under. */
    private final Map<String, Parameter<?>> parameters;
    /** Optional custom indexing-context factory; when {@code null} a default parameter map is produced. */
    private final TriFunction<MethodComponent, MethodComponentContext, KNNMethodConfigContext, KNNLibraryIndexingContext> knnLibraryIndexingContextGenerator;
    /** Estimates this component's own overhead in KB from (component, context, dimension). */
    private final TriFunction<MethodComponent, MethodComponentContext, Integer, Long> overheadInKBEstimator;
    /** Whether this component always requires training, independent of the parameters provided. */
    private final boolean requiresTraining;
    /** Vector data types accepted by {@link #validate}. */
    private final Set<VectorDataType> supportedVectorDataTypes;

    private MethodComponent(Builder builder) {
        this.name = builder.name;
        // Copy the builder's collections so that reusing the builder after build() cannot mutate this instance.
        this.parameters = new HashMap<>(builder.parameters);
        this.knnLibraryIndexingContextGenerator = builder.knnLibraryIndexingContextGenerator;
        this.overheadInKBEstimator = builder.overheadInKBEstimator;
        this.requiresTraining = builder.requiresTraining;
        this.supportedVectorDataTypes = new HashSet<>(builder.supportedDataTypes);
    }

    /**
     * Builds the library indexing context for this component.
     *
     * <p>If a custom generator was registered it is used. Otherwise the context's parameters are a map with
     * {@code "name"} set to the context name and {@code "parameters"} set to the user-provided parameters
     * merged with defaults (see {@link #getParameterMapWithDefaultsAdded}).
     *
     * @param methodComponentContext user-provided configuration for this component
     * @param knnMethodConfigContext index-level configuration (version, compression level, ...)
     * @return the indexing context
     */
    public KNNLibraryIndexingContext getKNNLibraryIndexingContext(
        MethodComponentContext methodComponentContext,
        KNNMethodConfigContext knnMethodConfigContext
    ) {
        if (this.knnLibraryIndexingContextGenerator != null) {
            return this.knnLibraryIndexingContextGenerator.apply(this, methodComponentContext, knnMethodConfigContext);
        }
        Map<String, Object> parameterMap = new HashMap<>();
        parameterMap.put("name", methodComponentContext.getName());
        parameterMap.put("parameters", getParameterMapWithDefaultsAdded(methodComponentContext, this, knnMethodConfigContext));
        return KNNLibraryIndexingContextImpl.builder().parameters(parameterMap).build();
    }

    /**
     * Validates a user-provided configuration against this component.
     *
     * <p>Checks that the configured vector data type is supported, then delegates parameter checks to
     * {@link ParameterValidator#validateParameters}. Errors from both checks are accumulated.
     *
     * @param methodComponentContext user-provided configuration for this component
     * @param knnMethodConfigContext index-level configuration supplying the vector data type
     * @return a {@link ValidationException} holding every error found, or {@code null} if the configuration is valid
     */
    public ValidationException validate(MethodComponentContext methodComponentContext, KNNMethodConfigContext knnMethodConfigContext) {
        ValidationException validationException = null;

        VectorDataType vectorDataType = knnMethodConfigContext.getVectorDataType();
        if (!this.supportedVectorDataTypes.contains(vectorDataType)) {
            validationException = new ValidationException();
            validationException.addValidationError(
                String.format(Locale.ROOT, "Method \"%s\" is not supported for vector data type \"%s\".", this.name, vectorDataType)
            );
        }

        ValidationException parameterValidationException = ParameterValidator.validateParameters(
            this.parameters,
            methodComponentContext.getParameters(),
            knnMethodConfigContext
        );
        if (parameterValidationException != null) {
            if (validationException == null) {
                validationException = new ValidationException();
            }
            validationException.addValidationErrors(parameterValidationException.validationErrors());
        }
        return validationException;
    }

    /**
     * Returns whether this component, or any nested component supplied as a parameter value, requires training.
     *
     * @param methodComponentContext user-provided configuration for this component
     * @return {@code true} if this component is flagged as requiring training or any nested component
     *     (recursively) requires training
     */
    public boolean isTrainingRequired(MethodComponentContext methodComponentContext) {
        if (this.requiresTraining) {
            return true;
        }
        Map<String, Object> providedParameters = methodComponentContext.getParameters();
        if (providedParameters == null || providedParameters.isEmpty()) {
            return false;
        }
        for (Map.Entry<String, Object> providedParameter : providedParameters.entrySet()) {
            Parameter<?> parameter = this.parameters.get(providedParameter.getKey());
            Object providedValue = providedParameter.getValue();
            // Only parameters whose value is itself a method component can contribute a training requirement.
            if (!(parameter instanceof Parameter.MethodComponentContextParameter) || !(providedValue instanceof MethodComponentContext)) {
                continue;
            }
            MethodComponentContext nestedContext = (MethodComponentContext) providedValue;
            // NOTE(review): the lookup result is not null-checked; presumably validate() rejects unknown names first.
            MethodComponent nestedComponent = ((Parameter.MethodComponentContextParameter) parameter).getMethodComponent(
                nestedContext.getName()
            );
            if (nestedComponent.isTrainingRequired(nestedContext)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Estimates the overhead in KB of this component plus every nested component supplied as a parameter value.
     *
     * @param methodComponentContext user-provided configuration for this component
     * @param dimension vector dimension passed to each estimator
     * @return the summed estimate in KB
     * @throws ArithmeticException if the summed estimate does not fit in an {@code int}
     */
    public int estimateOverheadInKB(MethodComponentContext methodComponentContext, int dimension) {
        long size = this.overheadInKBEstimator.apply(this, methodComponentContext, dimension);
        Map<String, Object> providedParameters = methodComponentContext.getParameters();
        if (providedParameters == null || providedParameters.isEmpty()) {
            return Math.toIntExact(size);
        }
        for (Map.Entry<String, Object> providedParameter : providedParameters.entrySet()) {
            Parameter<?> parameter = this.parameters.get(providedParameter.getKey());
            Object providedValue = providedParameter.getValue();
            if (!(parameter instanceof Parameter.MethodComponentContextParameter) || !(providedValue instanceof MethodComponentContext)) {
                continue;
            }
            MethodComponentContext nestedContext = (MethodComponentContext) providedValue;
            MethodComponent nestedComponent = ((Parameter.MethodComponentContextParameter) parameter).getMethodComponent(
                nestedContext.getName()
            );
            size += nestedComponent.estimateOverheadInKB(nestedContext, dimension);
        }
        return Math.toIntExact(size);
    }

    /**
     * Merges the user-provided parameters with defaults for every parameter the component declares.
     *
     * <p>For each declared parameter:
     * <ul>
     *   <li>If the user provided it, the user's value is used as-is.</li>
     *   <li>{@code ef_search} and {@code ef_construction} take their defaults from {@link IndexHyperParametersUtil}.
     *       Binary-quantization values are used when the compression level is x32, x16 or x8. Otherwise
     *       version-dependent HNSW values are used.</li>
     *   <li>Any other parameter takes its {@link Parameter#getDefaultValue()}, and is omitted when that is
     *       {@code null}.</li>
     * </ul>
     *
     * @param methodComponentContext user-provided configuration; its parameter map may be {@code null}
     * @param methodComponent component whose declared parameters are iterated
     * @param knnMethodConfigContext index-level configuration supplying the creation version and compression level
     * @return a new mutable map of parameter name to value
     */
    public static Map<String, Object> getParameterMapWithDefaultsAdded(
        MethodComponentContext methodComponentContext,
        MethodComponent methodComponent,
        KNNMethodConfigContext knnMethodConfigContext
    ) {
        Map<String, Object> parametersWithDefaultsMap = new HashMap<>();
        Map<String, Object> userProvidedParametersMap = methodComponentContext.getParameters();
        Version<?> indexCreationVersion = knnMethodConfigContext.getVersionCreated();
        CompressionLevel compressionLevel = knnMethodConfigContext.getCompressionLevel();
        // Only the compression level decides this; the configured mode is not consulted here.
        boolean useBinaryQuantizationDefaults = compressionLevel == CompressionLevel.x32
            || compressionLevel == CompressionLevel.x16
            || compressionLevel == CompressionLevel.x8;

        for (Parameter<?> parameter : methodComponent.getParameters().values()) {
            String parameterName = parameter.getName();
            if (userProvidedParametersMap != null && userProvidedParametersMap.containsKey(parameterName)) {
                parametersWithDefaultsMap.put(parameterName, userProvidedParametersMap.get(parameterName));
            } else if ("ef_search".equals(parameterName)) {
                if (useBinaryQuantizationDefaults) {
                    parametersWithDefaultsMap.put(parameterName, IndexHyperParametersUtil.getBinaryQuantizationEFSearchValue());
                } else {
                    parametersWithDefaultsMap.put(parameterName, IndexHyperParametersUtil.getHNSWEFSearchValue(indexCreationVersion));
                }
            } else if ("ef_construction".equals(parameterName)) {
                if (useBinaryQuantizationDefaults) {
                    parametersWithDefaultsMap.put(parameterName, IndexHyperParametersUtil.getBinaryQuantizationEFConstructionValue());
                } else {
                    parametersWithDefaultsMap.put(parameterName, IndexHyperParametersUtil.getHNSWEFConstructionValue(indexCreationVersion));
                }
            } else {
                Object defaultValue = parameter.getDefaultValue();
                if (defaultValue != null) {
                    parametersWithDefaultsMap.put(parameterName, defaultValue);
                }
            }
        }
        return parametersWithDefaultsMap;
    }

    /** @return the component name */
    public String getName() {
        return this.name;
    }

    /** @return the declared parameters, keyed by name */
    public Map<String, Parameter<?>> getParameters() {
        return this.parameters;
    }

    /** Fluent builder for {@link MethodComponent}. */
    public static class Builder {
        private final String name;
        private final Map<String, Parameter<?>> parameters;
        private TriFunction<MethodComponent, MethodComponentContext, KNNMethodConfigContext, KNNLibraryIndexingContext> knnLibraryIndexingContextGenerator;
        private TriFunction<MethodComponent, MethodComponentContext, Integer, Long> overheadInKBEstimator;
        private boolean requiresTraining;
        private final Set<VectorDataType> supportedDataTypes;

        /**
         * Starts a builder for a component with the given name.
         *
         * @param name component name
         * @return a new builder
         */
        public static Builder builder(String name) {
            return new Builder(name);
        }

        private Builder(String name) {
            this.name = name;
            this.parameters = new HashMap<>();
            // By default a component contributes no overhead of its own.
            this.overheadInKBEstimator = (mc, mcc, d) -> 0L;
            this.supportedDataTypes = new HashSet<>();
        }

        /** Registers a parameter under {@code parameterName}, replacing any previous one with that name. */
        public Builder addParameter(String parameterName, Parameter<?> parameter) {
            this.parameters.put(parameterName, parameter);
            return this;
        }

        /** Sets a custom indexing-context generator, replacing the default parameter-map context. */
        public Builder setKnnLibraryIndexingContextGenerator(
            TriFunction<MethodComponent, MethodComponentContext, KNNMethodConfigContext, KNNLibraryIndexingContext> knnLibraryIndexingContextGenerator
        ) {
            this.knnLibraryIndexingContextGenerator = knnLibraryIndexingContextGenerator;
            return this;
        }

        /** Marks whether the component always requires training. */
        public Builder setRequiresTraining(boolean requiresTraining) {
            this.requiresTraining = requiresTraining;
            return this;
        }

        /** Sets the estimator for this component's own overhead in KB. */
        public Builder setOverheadInKBEstimator(TriFunction<MethodComponent, MethodComponentContext, Integer, Long> overheadInKBEstimator) {
            this.overheadInKBEstimator = overheadInKBEstimator;
            return this;
        }

        /** Adds the given vector data types to the supported set. */
        public Builder addSupportedDataTypes(Set<VectorDataType> dataTypeSet) {
            this.supportedDataTypes.addAll(dataTypeSet);
            return this;
        }

        /** @return a new {@link MethodComponent} holding copies of this builder's collections */
        public MethodComponent build() {
            return new MethodComponent(this);
        }
    }
}

