/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.index.engine;

import io.skylite.common.ValidationException;
import java.util.Locale;
import java.util.Map;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.engine.KNNLibrary;
import org.opensearch.knn.index.engine.KNNLibraryIndexingContext;
import org.opensearch.knn.index.engine.KNNLibrarySearchContext;
import org.opensearch.knn.index.engine.KNNMethod;
import org.opensearch.knn.index.engine.KNNMethodConfigContext;
import org.opensearch.knn.index.engine.KNNMethodContext;

/**
 * Base {@link KNNLibrary} implementation that dispatches per-method behavior through a map from
 * method name to {@link KNNMethod} and reports a fixed library version.
 */
public abstract class AbstractKNNLibrary implements KNNLibrary {
    /** Methods supported by this library, keyed by method name. */
    protected final Map<String, KNNMethod> methods;
    /** Version string returned by {@link #getVersion()}. */
    protected final String version;

    /**
     * Creates the library.
     *
     * @param methods map from method name to its {@link KNNMethod}; the reference is stored as-is (not copied)
     * @param version library version returned by {@link #getVersion()}
     */
    AbstractKNNLibrary(Map<String, KNNMethod> methods, String version) {
        this.methods = methods;
        this.version = version;
    }

    /**
     * Returns the search context of the named method.
     *
     * @param methodName name of a method registered with this library
     * @return the search context provided by that method
     * @throws IllegalArgumentException if no method with that name is registered
     */
    @Override
    public KNNLibrarySearchContext getKNNLibrarySearchContext(String methodName) {
        return this.getMethodOrThrow(methodName).getKNNLibrarySearchContext();
    }

    /**
     * Returns the indexing context of the method named by {@code knnMethodContext}.
     *
     * @param knnMethodContext method context whose component name selects the method
     * @param knnMethodConfigContext configuration passed through to the method
     * @return the indexing context produced by the selected method
     * @throws IllegalArgumentException if the named method is not registered
     */
    @Override
    public KNNLibraryIndexingContext getKNNLibraryIndexingContext(KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext) {
        String methodName = knnMethodContext.getMethodComponentContext().getName();
        return this.getMethodOrThrow(methodName).getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext);
    }

    /**
     * Validates a method context against this library.
     *
     * <p>An unknown method name short-circuits with a single error. Otherwise the dimension check
     * and the method's own validation are both run and their errors are merged.
     *
     * @param knnMethodContext method context to validate
     * @param knnMethodConfigContext configuration (dimension, vector data type) to validate against
     * @return the collected validation errors, or {@code null} if none were found
     */
    @Override
    public ValidationException validateMethod(KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext) {
        String methodName = knnMethodContext.getMethodComponentContext().getName();
        KNNMethod knnMethod = this.lookupMethod(methodName);
        if (knnMethod == null) {
            // Without a known method nothing else can be meaningfully validated.
            ValidationException unknownMethod = new ValidationException();
            unknownMethod.addValidationError(invalidMethodMessage(methodName));
            return unknownMethod;
        }

        ValidationException validationException = null;
        String dimensionError = this.validateDimension(knnMethodContext, knnMethodConfigContext);
        if (dimensionError != null) {
            validationException = new ValidationException();
            validationException.addValidationError(dimensionError);
        }

        // The outcome is not merged into validationException; presumably
        // validateVectorDataType throws on an unsupported combination — confirm.
        this.validateSpaceType(knnMethodContext, knnMethodConfigContext);

        ValidationException methodValidation = knnMethod.validate(knnMethodContext, knnMethodConfigContext);
        if (methodValidation != null) {
            if (validationException == null) {
                validationException = new ValidationException();
            }
            validationException.addValidationErrors(methodValidation.validationErrors());
        }
        return validationException;
    }

    /**
     * Checks that the configured vector data type is compatible with the context's space type.
     * Does nothing when {@code knnMethodContext} is {@code null}.
     */
    private void validateSpaceType(KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext) {
        if (knnMethodContext == null) {
            return;
        }
        knnMethodContext.getSpaceType().validateVectorDataType(knnMethodConfigContext.getVectorDataType());
    }

    /**
     * Checks the configured dimension against the engine's maximum and, for binary vectors,
     * that it is a multiple of 8.
     *
     * @return an error message, or {@code null} if the dimension is acceptable or there is no method context
     */
    private String validateDimension(KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext) {
        if (knnMethodContext == null) {
            return null;
        }
        int dimension = knnMethodConfigContext.getDimension();
        KNNEngine knnEngine = knnMethodContext.getKnnEngine();
        int maxDimension = KNNEngine.getMaxDimensionByEngine(knnEngine);
        if (dimension > maxDimension) {
            return String.format(Locale.ROOT, "Dimension value cannot be greater than %s for vector with engine: %s", maxDimension, knnEngine.getName());
        }
        // Binary vectors pack 8 dimensions per byte.
        if (VectorDataType.BINARY == knnMethodConfigContext.getVectorDataType() && dimension % 8 != 0) {
            return "Dimension should be multiply of 8 for binary vector data type";
        }
        return null;
    }

    /**
     * Reports whether the method named by {@code knnMethodContext} requires training.
     *
     * @throws IllegalArgumentException if the named method is not registered
     */
    @Override
    public boolean isTrainingRequired(KNNMethodContext knnMethodContext) {
        String methodName = knnMethodContext.getMethodComponentContext().getName();
        return this.getMethodOrThrow(methodName).isTrainingRequired(knnMethodContext);
    }

    /**
     * Looks up a registered method.
     *
     * <p>A {@code null} name is treated as unregistered instead of being passed to {@link Map#get},
     * because null-hostile maps (e.g. {@code Map.of}) throw {@link NullPointerException} for it.
     *
     * @return the method, or {@code null} if none is registered under {@code methodName}
     */
    private KNNMethod lookupMethod(String methodName) {
        return methodName == null ? null : this.methods.get(methodName);
    }

    /** Looks up a registered method, throwing {@link IllegalArgumentException} if it is absent. */
    private KNNMethod getMethodOrThrow(String methodName) {
        KNNMethod method = this.lookupMethod(methodName);
        if (method == null) {
            throw new IllegalArgumentException(invalidMethodMessage(methodName));
        }
        return method;
    }

    /** Builds the error message used for an unregistered method name. */
    private static String invalidMethodMessage(String methodName) {
        return String.format(Locale.ROOT, "Invalid method name: %s", methodName);
    }

    /** Returns the version string supplied at construction. */
    @Override
    public String getVersion() {
        return this.version;
    }
}

