/*
 * Decompiled with CFR 0.152.
 */
package io.skylite.ml.common;

import io.skylite.SkyliteException;
import io.skylite.core.common.io.stream.StreamInput;
import io.skylite.ml.common.FunctionName;
import io.skylite.ml.common.MLClassMappingProvider;
import io.skylite.ml.common.annotation.Connector;
import io.skylite.ml.common.annotation.ExecuteInput;
import io.skylite.ml.common.annotation.ExecuteOutput;
import io.skylite.ml.common.annotation.InputDataSet;
import io.skylite.ml.common.annotation.MLAlgoOutput;
import io.skylite.ml.common.annotation.MLAlgoParameter;
import io.skylite.ml.common.annotation.MLInput;
import io.skylite.ml.common.dataset.MLInputDataType;
import io.skylite.ml.common.exception.MLException;
import io.skylite.ml.common.output.MLOutput;
import io.skylite.ml.common.output.MLOutputType;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.lang.reflect.Constructor;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.ServiceLoader;
import java.util.Set;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.reflections.Reflections;
import org.reflections.scanners.Scanner;

/**
 * Registry that maps ML-commons enum keys (and connector names) to implementation classes and
 * instantiates those classes reflectively through their public constructors.
 *
 * <p>During class initialization, every {@link MLClassMappingProvider} found through
 * {@link ServiceLoader} contributes package namespaces per category. Each namespace is scanned with
 * {@link Reflections} for classes carrying the category's annotation ({@link MLAlgoParameter},
 * {@link MLAlgoOutput}, {@link InputDataSet}, {@link ExecuteInput}, {@link ExecuteOutput},
 * {@link MLInput}, {@link Connector}). Each class is registered under the key(s) named by that
 * annotation. When two classes claim the same key, the one registered last wins.
 *
 * <p>NOTE(review): the registries are plain {@link HashMap}s. They are filled during class
 * initialization, but {@link #loadClassMapping(Map)} is public and mutates them without
 * synchronization. Calling it while other threads read the registries is therefore unsafe.
 */
public class MLCommonsClassLoader {
    private static final Logger log = LogManager.getLogger(MLCommonsClassLoader.class);

    // Category keys expected in the namespace map handed to loadClassMapping.
    private static final String ML_ALGO_PARAMETER = "ml_algo_parameter";
    private static final String ML_OUTPUT = "ml_output";
    private static final String ML_INPUT_DATASET = "ml_input_dataset";
    private static final String ML_EXECUTE_INPUT = "ml_execute_input";
    private static final String ML_EXECUTE_OUTPUT = "ml_execute_output";
    private static final String ML_INPUT = "ml_input";
    private static final String ML_CONNECTOR = "ml_connector";

    /** Algorithm parameters, algorithm outputs and input data sets, keyed by their enum value. */
    private static final Map<Enum<?>, Class<?>> parameterClassMap = new HashMap<>();
    /** {@link ExecuteInput} classes keyed by {@link FunctionName}. */
    private static final Map<Enum<?>, Class<?>> executeInputClassMap = new HashMap<>();
    /** {@link ExecuteOutput} classes keyed by {@link FunctionName}. */
    private static final Map<Enum<?>, Class<?>> executeOutputClassMap = new HashMap<>();
    /** {@link MLInput} classes keyed by {@link FunctionName}. */
    private static final Map<Enum<?>, Class<?>> mlInputClassMap = new HashMap<>();
    /** {@link Connector} classes keyed by the non-empty {@link Connector#value()}. */
    private static final Map<String, Class<?>> connectorClassMap = new HashMap<>();

    /**
     * Scans the given namespaces and registers every annotated class found.
     *
     * <p>Recognized keys are {@code ml_algo_parameter}, {@code ml_output},
     * {@code ml_input_dataset}, {@code ml_execute_input}, {@code ml_execute_output},
     * {@code ml_input} and {@code ml_connector}. A key that is absent, or mapped to {@code null},
     * is treated as having no namespaces. The thread context class loader is set to this class's
     * loader for the duration of the scan and restored afterwards.
     *
     * @param namespaces package prefixes to scan, grouped by category key
     */
    public static void loadClassMapping(Map<String, LinkedHashSet<String>> namespaces) {
        Thread current = Thread.currentThread();
        ClassLoader originalClassLoader = current.getContextClassLoader();
        try {
            // NOTE(review): presumably Reflections resolves classes through the context class
            // loader, so it is pinned to the loader that loaded this class; confirm.
            current.setContextClassLoader(MLCommonsClassLoader.class.getClassLoader());
            loadMLAlgoParameterClassMapping(namespacesFor(namespaces, ML_ALGO_PARAMETER));
            loadMLOutputClassMapping(namespacesFor(namespaces, ML_OUTPUT));
            loadMLInputDataSetClassMapping(namespacesFor(namespaces, ML_INPUT_DATASET));
            loadExecuteInputClassMapping(namespacesFor(namespaces, ML_EXECUTE_INPUT));
            loadExecuteOutputClassMapping(namespacesFor(namespaces, ML_EXECUTE_OUTPUT));
            loadMLInputClassMapping(namespacesFor(namespaces, ML_INPUT));
            loadConnectorClassMapping(namespacesFor(namespaces, ML_CONNECTOR));
        } finally {
            current.setContextClassLoader(originalClassLoader);
        }
    }

    /** Returns the namespaces registered under {@code key}, or an empty set when there are none. */
    private static LinkedHashSet<String> namespacesFor(Map<String, LinkedHashSet<String>> namespaces, String key) {
        LinkedHashSet<String> found = namespaces.get(key);
        return found == null ? new LinkedHashSet<>() : found;
    }

    /** Registers {@link Connector} classes under their non-null, non-empty connector name. */
    private static void loadConnectorClassMapping(LinkedHashSet<String> namespaces) {
        for (String namespace : namespaces) {
            Reflections reflections = new Reflections(namespace, new Scanner[0]);
            for (Class<?> clazz : reflections.getTypesAnnotatedWith(Connector.class)) {
                Connector connector = clazz.getAnnotation(Connector.class);
                if (connector == null) {
                    continue;
                }
                String name = connector.value();
                if (name != null && !name.isEmpty()) {
                    connectorClassMap.put(name, clazz);
                }
            }
        }
    }

    /**
     * Registers {@link MLAlgoParameter} classes per algorithm. Within the same namespaces, also
     * registers {@link MLAlgoOutput} classes per output type, all in {@code parameterClassMap}.
     */
    private static void loadMLAlgoParameterClassMapping(LinkedHashSet<String> namespaces) {
        for (String namespace : namespaces) {
            Reflections reflections = new Reflections(namespace, new Scanner[0]);
            for (Class<?> clazz : reflections.getTypesAnnotatedWith(MLAlgoParameter.class)) {
                MLAlgoParameter mlAlgoParameter = clazz.getAnnotation(MLAlgoParameter.class);
                if (mlAlgoParameter != null) {
                    registerForFunctions(parameterClassMap, mlAlgoParameter.algorithms(), clazz);
                }
            }
            registerAlgoOutputs(reflections);
        }
    }

    /** Registers {@link MLAlgoOutput} classes per output type in {@code parameterClassMap}. */
    private static void loadMLOutputClassMapping(LinkedHashSet<String> namespaces) {
        for (String namespace : namespaces) {
            registerAlgoOutputs(new Reflections(namespace, new Scanner[0]));
        }
    }

    /** Registers every {@link MLAlgoOutput} class found by {@code reflections} under its output type. */
    private static void registerAlgoOutputs(Reflections reflections) {
        for (Class<?> clazz : reflections.getTypesAnnotatedWith(MLAlgoOutput.class)) {
            // getAnnotation can return null for types Reflections reports without the annotation
            // being present on the class itself, so guard before dereferencing.
            MLAlgoOutput mlAlgoOutput = clazz.getAnnotation(MLAlgoOutput.class);
            if (mlAlgoOutput == null) {
                continue;
            }
            MLOutputType mlOutputType = mlAlgoOutput.value();
            if (mlOutputType != null) {
                parameterClassMap.put(mlOutputType, clazz);
            }
        }
    }

    /** Registers {@link InputDataSet} classes per input data type in {@code parameterClassMap}. */
    private static void loadMLInputDataSetClassMapping(LinkedHashSet<String> namespaces) {
        for (String namespace : namespaces) {
            Reflections reflections = new Reflections(namespace, new Scanner[0]);
            for (Class<?> clazz : reflections.getTypesAnnotatedWith(InputDataSet.class)) {
                InputDataSet inputDataSet = clazz.getAnnotation(InputDataSet.class);
                if (inputDataSet == null) {
                    continue;
                }
                MLInputDataType value = inputDataSet.value();
                if (value != null) {
                    parameterClassMap.put(value, clazz);
                }
            }
        }
    }

    /** Registers {@link ExecuteInput} classes per algorithm in {@code executeInputClassMap}. */
    private static void loadExecuteInputClassMapping(LinkedHashSet<String> namespaces) {
        for (String namespace : namespaces) {
            Reflections reflections = new Reflections(namespace, new Scanner[0]);
            for (Class<?> clazz : reflections.getTypesAnnotatedWith(ExecuteInput.class)) {
                ExecuteInput executeInput = clazz.getAnnotation(ExecuteInput.class);
                if (executeInput != null) {
                    registerForFunctions(executeInputClassMap, executeInput.algorithms(), clazz);
                }
            }
        }
    }

    /** Registers {@link ExecuteOutput} classes per algorithm in {@code executeOutputClassMap}. */
    private static void loadExecuteOutputClassMapping(LinkedHashSet<String> namespaces) {
        for (String namespace : namespaces) {
            Reflections reflections = new Reflections(namespace, new Scanner[0]);
            for (Class<?> clazz : reflections.getTypesAnnotatedWith(ExecuteOutput.class)) {
                ExecuteOutput executeOutput = clazz.getAnnotation(ExecuteOutput.class);
                if (executeOutput != null) {
                    registerForFunctions(executeOutputClassMap, executeOutput.algorithms(), clazz);
                }
            }
        }
    }

    /** Registers {@link MLInput} classes per function name in {@code mlInputClassMap}. */
    private static void loadMLInputClassMapping(LinkedHashSet<String> namespaces) {
        for (String namespace : namespaces) {
            Reflections reflections = new Reflections(namespace, new Scanner[0]);
            for (Class<?> clazz : reflections.getTypesAnnotatedWith(MLInput.class)) {
                MLInput mlInput = clazz.getAnnotation(MLInput.class);
                if (mlInput != null) {
                    registerForFunctions(mlInputClassMap, mlInput.functionNames(), clazz);
                }
            }
        }
    }

    /** Maps every entry of {@code functionNames} to {@code clazz}; a null array registers nothing. */
    private static void registerForFunctions(Map<Enum<?>, Class<?>> target, FunctionName[] functionNames, Class<?> clazz) {
        if (functionNames == null) {
            return;
        }
        for (FunctionName name : functionNames) {
            target.put(name, clazz);
        }
    }

    /**
     * Instantiates the class registered for {@code type} among algorithm parameters, outputs and
     * input data sets.
     *
     * @param type registry key
     * @param in single constructor argument
     * @param constructorParamClass declared type of the public constructor's single parameter
     * @return the new instance, or {@code null} if construction failed for a reason other than
     *     those rethrown (the failure is logged)
     * @throws IllegalArgumentException if no class is registered for {@code type}, or if the
     *     constructor threw one
     * @throws MLException if the constructor threw one
     */
    public static <T extends Enum<T>, S, I> S initMLInstance(T type, I in, Class<?> constructorParamClass) {
        return init(parameterClassMap, type, in, constructorParamClass);
    }

    /**
     * Instantiates the {@link ExecuteInput} class registered for {@code type}. If that lookup or
     * construction throws, the {@link MLInput} registry is tried instead. A {@code null} result
     * (logged construction failure) does not trigger the fallback.
     */
    public static <T extends Enum<T>, S, I> S initExecuteInputInstance(T type, I in, Class<?> constructorParamClass) {
        try {
            return init(executeInputClassMap, type, in, constructorParamClass);
        } catch (RuntimeException e) {
            return init(mlInputClassMap, type, in, constructorParamClass);
        }
    }

    /**
     * Instantiates the {@link ExecuteOutput} class registered for {@code type}. If that throws and
     * {@code in} is a {@link StreamInput}, falls back to {@link MLOutput#fromStream}; otherwise the
     * original exception is rethrown.
     *
     * @throws UncheckedIOException if the stream fallback fails with an {@link IOException}
     */
    @SuppressWarnings("unchecked")
    public static <T extends Enum<T>, S, I> S initExecuteOutputInstance(T type, I in, Class<?> constructorParamClass) {
        try {
            return init(executeOutputClassMap, type, in, constructorParamClass);
        } catch (RuntimeException e) {
            if (in instanceof StreamInput) {
                try {
                    return (S) MLOutput.fromStream((StreamInput) in);
                } catch (IOException ex) {
                    throw new UncheckedIOException(ex);
                }
            }
            throw e;
        }
    }

    /** Looks up {@code type} and invokes its public single-argument constructor with {@code in}. */
    @SuppressWarnings("unchecked")
    private static <T, S, I> S init(Map<T, Class<?>> map, T type, I in, Class<?> constructorParamClass) {
        Class<?> clazz = lookup(map, type);
        try {
            Constructor<?> constructor = clazz.getConstructor(constructorParamClass);
            return (S) constructor.newInstance(in);
        } catch (Exception e) {
            rethrowDomainCause(e);
            log.error("Failed to init instance for type {}", type, e);
            return null;
        }
    }

    /** Returns whether an {@link MLInput} class is registered for {@code functionName}. */
    public static boolean canInitMLInput(FunctionName functionName) {
        return mlInputClassMap.containsKey(functionName);
    }

    /**
     * Instantiates the {@link Connector} class registered under {@code name} through the public
     * constructor with the given parameter types.
     *
     * @return the new instance, or {@code null} if construction failed for a reason other than
     *     those rethrown (the failure is logged)
     * @throws IllegalArgumentException if no connector is registered under {@code name}, or if
     *     the constructor threw one
     * @throws MLException if the constructor threw one
     */
    public static <S> S initConnector(String name, Object[] initArgs, Class<?> ... constructorParameterTypes) {
        return initWithArgs(connectorClassMap, name, initArgs, constructorParameterTypes);
    }

    /**
     * Instantiates the {@link MLInput} class registered for {@code type} through the public
     * constructor with the given parameter types. Same failure contract as
     * {@link #initConnector}.
     */
    public static <T extends Enum<T>, S> S initMLInput(T type, Object[] initArgs, Class<?> ... constructorParameterTypes) {
        return initWithArgs(mlInputClassMap, type, initArgs, constructorParameterTypes);
    }

    /** Looks up {@code type} and invokes the public constructor matching the given parameter types. */
    @SuppressWarnings("unchecked")
    private static <T, S> S initWithArgs(Map<T, Class<?>> map, T type, Object[] initArgs, Class<?>... constructorParameterTypes) {
        Class<?> clazz = lookup(map, type);
        try {
            Constructor<?> constructor = clazz.getConstructor(constructorParameterTypes);
            return (S) constructor.newInstance(initArgs);
        } catch (Exception e) {
            rethrowDomainCause(e);
            log.error("Failed to init instance for type {}", type, e);
            return null;
        }
    }

    /** Returns the class registered for {@code type} or throws if there is none. */
    private static <T> Class<?> lookup(Map<T, Class<?>> map, T type) {
        Class<?> clazz = map.get(type);
        if (clazz == null) {
            throw new IllegalArgumentException("Can't find class for type " + type);
        }
        return clazz;
    }

    /**
     * Rethrows the cause of a reflective failure when it is an {@link MLException} or an
     * {@link IllegalArgumentException} (e.g. thrown by the target constructor and wrapped by
     * reflection); returns normally otherwise.
     */
    private static void rethrowDomainCause(Exception e) {
        Throwable cause = e.getCause();
        if (cause instanceof MLException) {
            throw (MLException) cause;
        }
        if (cause instanceof IllegalArgumentException) {
            throw (IllegalArgumentException) cause;
        }
    }

    /**
     * Does nothing itself. Calling any static method of this class triggers its class
     * initialization, which loads the providers and the class mappings.
     */
    public static void bootstrapForTesting() {
        // Intentionally empty: invoking this method is enough to run the static initializer.
    }

    // Must stay after the registry field initializers: static initializers run in textual order.
    static {
        Map<String, LinkedHashSet<String>> namespaceMapping = Map.of(
                ML_ALGO_PARAMETER, new LinkedHashSet<>(),
                ML_OUTPUT, new LinkedHashSet<>(),
                ML_INPUT_DATASET, new LinkedHashSet<>(),
                ML_EXECUTE_INPUT, new LinkedHashSet<>(),
                ML_EXECUTE_OUTPUT, new LinkedHashSet<>(),
                ML_INPUT, new LinkedHashSet<>(),
                ML_CONNECTOR, new LinkedHashSet<>());
        try {
            ServiceLoader.load(MLClassMappingProvider.class, MLClassMappingProvider.class.getClassLoader())
                    .forEach(provider -> {
                        namespaceMapping.get(ML_ALGO_PARAMETER).addAll(provider.getMLAlgoParameterNamespaces());
                        namespaceMapping.get(ML_OUTPUT).addAll(provider.getMLOutputNamespaces());
                        namespaceMapping.get(ML_INPUT).addAll(provider.getMLInputNamespaces());
                        namespaceMapping.get(ML_EXECUTE_OUTPUT).addAll(provider.getMLExecuteOutputNamespaces());
                        namespaceMapping.get(ML_EXECUTE_INPUT).addAll(provider.getMLExecuteInputNamespaces());
                        namespaceMapping.get(ML_INPUT_DATASET).addAll(provider.getMLInputDatasetNamespaces());
                        namespaceMapping.get(ML_CONNECTOR).addAll(provider.getMLConnectorNamespaces());
                    });
        } catch (Exception e) {
            throw new SkyliteException("Error loading MLClassMappingProviders ", e, new Object[0]);
        }
        try {
            // The explicit cast selects the PrivilegedExceptionAction overload; a bare lambda is
            // ambiguous between it and PrivilegedAction.
            AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
                loadClassMapping(namespaceMapping);
                return null;
            });
        } catch (PrivilegedActionException e) {
            throw new RuntimeException("Can't load class mapping in ML commons", e);
        }
    }
}

