/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.ensemble;

import ai.onnx.proto.OnnxMl;
import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;
import com.oracle.labs.mlrg.olcut.provenance.ListProvenance;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.ONNXExportable;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.ensemble.EnsembleCombiner;
import org.tribuo.ensemble.EnsembleExcuse;
import org.tribuo.ensemble.EnsembleModel;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.protos.core.EnsembleCombinerProto;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.protos.core.WeightedEnsembleModelProto;
import org.tribuo.provenance.EnsembleModelProvenance;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TimestampedTrainerProvenance;
import org.tribuo.util.Util;
import org.tribuo.util.onnx.ONNXContext;
import org.tribuo.util.onnx.ONNXInitializer;
import org.tribuo.util.onnx.ONNXNode;
import org.tribuo.util.onnx.ONNXOperator;
import org.tribuo.util.onnx.ONNXOperators;
import org.tribuo.util.onnx.ONNXPlaceholder;
import org.tribuo.util.onnx.ONNXRef;

/**
 * An ensemble model which combines the predictions of its members using an {@link EnsembleCombiner}
 * and a per-member weight vector.
 * <p>
 * The weights are aligned positionally with the member models. If every member implements
 * {@link ONNXExportable} the whole ensemble can be exported as a single ONNX graph.
 *
 * @param <T> The output type.
 */
public final class WeightedEnsembleModel<T extends Output<T>>
extends EnsembleModel<T>
implements ONNXExportable {
    private static final long serialVersionUID = 1L;

    /**
     * Protobuf serialization version.
     */
    public static final int CURRENT_VERSION = 0;

    /**
     * The per-member weights, aligned positionally with the member models.
     */
    protected final float[] weights;

    /**
     * The function which combines the member predictions (and weights) into the ensemble prediction.
     */
    protected final EnsembleCombiner<T> combiner;

    /**
     * Builds an ensemble which weights every member equally, i.e. each weight is {@code 1 / newModels.size()}.
     *
     * @param name The model name.
     * @param provenance The ensemble provenance.
     * @param featureIDMap The feature domain.
     * @param outputIDInfo The output domain.
     * @param newModels The member models.
     * @param combiner The combination function.
     */
    public WeightedEnsembleModel(String name, EnsembleModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, List<Model<T>> newModels, EnsembleCombiner<T> combiner) {
        this(name, provenance, featureIDMap, outputIDInfo, newModels, combiner, Util.generateUniformVector(newModels.size(), 1.0f / (float)newModels.size()));
    }

    /**
     * Builds an ensemble with the supplied per-member weights.
     * <p>
     * The weight array is defensively copied.
     *
     * @param name The model name.
     * @param provenance The ensemble provenance.
     * @param featureIDMap The feature domain.
     * @param outputIDInfo The output domain.
     * @param newModels The member models.
     * @param combiner The combination function.
     * @param weights The member weights, aligned positionally with {@code newModels}.
     */
    public WeightedEnsembleModel(String name, EnsembleModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, List<Model<T>> newModels, EnsembleCombiner<T> combiner, float[] weights) {
        super(name, provenance, featureIDMap, outputIDInfo, newModels);
        this.weights = Arrays.copyOf(weights, weights.length);
        this.combiner = combiner;
    }

    /**
     * Deserialization factory.
     *
     * @param version The serialized object version.
     * @param className The class name (not read by this method).
     * @param message The serialized data, expected to wrap a {@link WeightedEnsembleModelProto}.
     * @return The deserialized ensemble.
     * @throws InvalidProtocolBufferException If the message could not be unpacked as a {@link WeightedEnsembleModelProto}.
     * @throws IllegalArgumentException If the version is not supported.
     * @throws IllegalStateException If the protobuf is internally inconsistent (wrong provenance type, empty
     *                               output domain, type mismatches, no models, or model/weight count mismatch).
     */
    // Raw construction is required because the output type is only known at runtime; type agreement between
    // the output domain, the combiner and every member is checked explicitly before construction.
    @SuppressWarnings({"unchecked", "rawtypes"})
    public static WeightedEnsembleModel<?> deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
        }
        WeightedEnsembleModelProto proto = message.unpack(WeightedEnsembleModelProto.class);
        ModelDataCarrier<?> carrier = ModelDataCarrier.deserialize(proto.getMetadata());
        ModelProvenance prov = carrier.provenance();
        if (!(prov instanceof EnsembleModelProvenance)) {
            throw new IllegalStateException("Invalid protobuf, the provenance was not an EnsembleModelProvenance. Found " + prov);
        }
        EnsembleModelProvenance ensembleProvenance = (EnsembleModelProvenance) prov;

        ImmutableOutputInfo<?> outputDomain = carrier.outputDomain();
        if (outputDomain.size() == 0) {
            // Without at least one output there is no way to recover the output class used by the checks below.
            throw new IllegalStateException("Invalid protobuf, the output domain was empty");
        }
        Class<?> outputClass = outputDomain.getOutput(0).getClass();

        EnsembleCombiner<?> combiner = EnsembleCombiner.deserialize(proto.getCombiner());
        if (!outputClass.equals(combiner.getTypeWitness())) {
            throw new IllegalStateException("Invalid protobuf, combiner and output domain have a type mismatch, expected " + outputClass + " found " + combiner.getTypeWitness());
        }

        int numModels = proto.getModelsCount();
        if (numModels == 0) {
            throw new IllegalStateException("Invalid protobuf, no models were found in the ensemble");
        }
        if (numModels != proto.getWeightsCount()) {
            throw new IllegalStateException("Invalid protobuf, different numbers of models and weights were found, " + numModels + " models, " + proto.getWeightsCount() + " weights");
        }

        List<Model<?>> models = new ArrayList<>(numModels);
        for (ModelProto p : proto.getModelsList()) {
            Model<?> model = Model.deserialize(p);
            if (!model.validate(outputClass)) {
                throw new IllegalStateException("Invalid protobuf, output type of model '" + model.toString() + "' did not match expected " + outputClass);
            }
            models.add(model);
        }

        float[] weights = Util.toPrimitiveFloat(proto.getWeightsList());
        return new WeightedEnsembleModel(carrier.name(), ensembleProvenance, carrier.featureDomain(), outputDomain, models, combiner, weights);
    }

    /**
     * Predicts with every member, then combines the member predictions using the combiner and the member weights.
     *
     * @param example The example to predict.
     * @return The combined prediction.
     */
    @Override
    public Prediction<T> predict(Example<T> example) {
        List<Prediction<T>> predictions = new ArrayList<>(this.models.size());
        for (Model<T> model : this.models) {
            predictions.add(model.predict(example));
        }
        return this.combiner.combine(this.outputIDInfo, predictions, this.weights);
    }

    /**
     * Builds an ensemble excuse from the member excuses.
     * <p>
     * For each label, every feature's score is the sum over members of that member's score multiplied by the
     * member's weight. Within each label the features are sorted by descending score. Members which do not
     * produce an excuse are skipped.
     *
     * @param example The example to explain.
     * @return The ensemble excuse, or {@link Optional#empty()} if no member produced any scores.
     */
    @Override
    public Optional<Excuse<T>> getExcuse(Example<T> example) {
        // label -> (feature name -> weighted sum of member scores)
        Map<String, Map<String, Double>> weightedScores = new HashMap<>();
        Prediction<T> prediction = this.predict(example);
        List<Excuse<T>> excuses = new ArrayList<>();
        for (int i = 0; i < this.models.size(); ++i) {
            Optional<Excuse<T>> excuse = this.models.get(i).getExcuse(example);
            if (!excuse.isPresent()) {
                continue;
            }
            excuses.add(excuse.get());
            double weight = this.weights[i];
            for (Map.Entry<String, List<Pair<String, Double>>> e : excuse.get().getScores().entrySet()) {
                Map<String, Double> featureScores = weightedScores.computeIfAbsent(e.getKey(), k -> new HashMap<>());
                for (Pair<String, Double> p : e.getValue()) {
                    featureScores.merge(p.getA(), p.getB() * weight, Double::sum);
                }
            }
        }
        if (weightedScores.isEmpty()) {
            return Optional.empty();
        }
        Map<String, List<Pair<String, Double>>> outputMap = new HashMap<>();
        for (Map.Entry<String, Map<String, Double>> label : weightedScores.entrySet()) {
            List<Pair<String, Double>> list = new ArrayList<>(label.getValue().size());
            for (Map.Entry<String, Double> entry : label.getValue().entrySet()) {
                list.add(new Pair<>(entry.getKey(), entry.getValue()));
            }
            // Highest score first.
            list.sort((o1, o2) -> o2.getB().compareTo(o1.getB()));
            outputMap.put(label.getKey(), list);
        }
        return Optional.of(new EnsembleExcuse<T>(example, prediction, outputMap, excuses));
    }

    /**
     * Copies this ensemble with new provenance and member models.
     * <p>
     * If {@code newModels} has the same number of members as this ensemble, this ensemble's weights are kept.
     * Otherwise the weights can't be matched to the new members, so the copy weights every member uniformly.
     *
     * @param name The name of the copy.
     * @param newProvenance The provenance of the copy.
     * @param newModels The members of the copy.
     * @return A copy of this ensemble.
     */
    @Override
    protected EnsembleModel<T> copy(String name, EnsembleModelProvenance newProvenance, List<Model<T>> newModels) {
        if (newModels.size() == this.models.size()) {
            // The constructor defensively copies the weight array, so passing the field is safe.
            return new WeightedEnsembleModel<T>(name, newProvenance, this.featureIDMap, this.outputIDInfo, newModels, this.combiner, this.weights);
        }
        return new WeightedEnsembleModel<T>(name, newProvenance, this.featureIDMap, this.outputIDInfo, newModels, this.combiner);
    }

    /**
     * Creates an ensemble from existing models, weighting every member equally.
     *
     * @param name The ensemble name.
     * @param models The member models (at least two, with equal output and feature domains).
     * @param combiner The combination function.
     * @param <T> The output type.
     * @return A uniformly weighted ensemble.
     * @throws IllegalArgumentException If fewer than two models are supplied or their domains differ.
     */
    public static <T extends Output<T>> WeightedEnsembleModel<T> createEnsembleFromExistingModels(String name, List<Model<T>> models, EnsembleCombiner<T> combiner) {
        return WeightedEnsembleModel.createEnsembleFromExistingModels(name, models, combiner, Util.generateUniformVector(models.size(), 1.0f / (float)models.size()));
    }

    /**
     * Creates an ensemble from existing models with the supplied weights.
     * <p>
     * The ensemble takes its output and feature domains from the first model. The dataset provenance also comes
     * from the first model, and the trainer provenance is a fresh {@link TimestampedTrainerProvenance}.
     *
     * @param name The ensemble name.
     * @param models The member models (at least two, with equal output and feature domains).
     * @param combiner The combination function.
     * @param weights One weight per model.
     * @param <T> The output type.
     * @return A weighted ensemble.
     * @throws IllegalArgumentException If fewer than two models are supplied, the weight count differs from the
     *                                  model count, or the models' domains differ.
     */
    public static <T extends Output<T>> WeightedEnsembleModel<T> createEnsembleFromExistingModels(String name, List<Model<T>> models, EnsembleCombiner<T> combiner, float[] weights) {
        if (models.size() < 2) {
            throw new IllegalArgumentException("Must supply at least 2 models, found " + models.size());
        }
        if (weights.length != models.size()) {
            throw new IllegalArgumentException("Must supply one weight per model, models.size() = " + models.size() + ", weights.length = " + weights.length);
        }
        Model<T> first = models.get(0);
        ImmutableOutputInfo<T> outputInfo = first.getOutputIDInfo();
        ImmutableFeatureMap featureMap = first.getFeatureIDMap();
        Set<T> firstOutputDomain = outputInfo.getDomain();
        for (int i = 1; i < models.size(); ++i) {
            Model<T> other = models.get(i);
            if (!other.getOutputIDInfo().getDomain().equals(firstOutputDomain)) {
                throw new IllegalArgumentException("Model output domains are not equal.");
            }
            if (!other.getFeatureIDMap().domainEquals(featureMap)) {
                throw new IllegalArgumentException("Model feature domains are not equal.");
            }
        }
        List<Model<T>> modelList = new ArrayList<>(models);
        TrainerProvenance trainerProvenance = new TimestampedTrainerProvenance();
        ListProvenance<? extends ModelProvenance> memberProvenances = (ListProvenance<? extends ModelProvenance>)ListProvenance.createListProvenance(models);
        EnsembleModelProvenance provenance = new EnsembleModelProvenance(WeightedEnsembleModel.class.getName(), OffsetDateTime.now(), first.getProvenance().getDatasetProvenance(), trainerProvenance, memberProvenances);
        return new WeightedEnsembleModel<T>(name, provenance, featureMap, outputInfo, modelList, combiner, weights);
    }

    /**
     * Exports this ensemble as a standalone ONNX model.
     * <p>
     * The model has a float input of width {@code featureIDMap.size()} and a float output of width
     * {@code outputIDInfo.size()}.
     *
     * @param domain The ONNX domain.
     * @param modelVersion The model version.
     * @return The ONNX model proto.
     * @throws IllegalStateException If any member is not {@link ONNXExportable}.
     */
    @Override
    public OnnxMl.ModelProto exportONNXModel(String domain, long modelVersion) {
        ONNXContext onnx = new ONNXContext();
        onnx.setName("WeightedEnsembleModel");
        ONNXPlaceholder input = onnx.floatInput(this.featureIDMap.size());
        ONNXPlaceholder output = onnx.floatOutput(this.outputIDInfo.size());
        this.writeONNXGraph(input).assignTo(output);
        return ONNXExportable.buildModel(onnx, domain, modelVersion, this);
    }

    /**
     * Writes the ensemble into the ONNX graph attached to {@code input}.
     * <p>
     * Each member's graph output is unsqueezed on axis 2. If the member's output IDs differ from the ensemble's,
     * the output is also re-ordered with a Gather on axis 1. The members are then concatenated on axis 2 and
     * passed, with the weights, to the combiner's ONNX export.
     * <p>
     * Assumes each member emits a {@code [batch, numOutputs]} tensor, so the concatenation is
     * {@code [batch, numOutputs, numMembers]}. TODO confirm against the member exporters.
     *
     * @param input The input to the ensemble graph.
     * @return The combiner's output node.
     * @throws IllegalStateException If a member is not {@link ONNXExportable}, or its output domain does not
     *                               match the ensemble's.
     */
    @Override
    public ONNXNode writeONNXGraph(ONNXRef<?> input) {
        ONNXContext onnx = input.onnxContext();
        ONNXInitializer unsqueezeAxes = onnx.array("unsqueeze_ensemble_output", new long[]{2L});
        List<ONNXNode> unsqueezedMembers = new ArrayList<>(this.models.size());
        for (Model<T> model : this.models) {
            if (!(model instanceof ONNXExportable)) {
                throw new IllegalStateException("Ensemble member '" + model.toString() + "' is not ONNXExportable.");
            }
            ONNXNode memberOutput = ((ONNXExportable) model).writeONNXGraph(input);
            ONNXNode unsqueezedOutput = memberOutput.apply(ONNXOperators.UNSQUEEZE, unsqueezeAxes);
            ImmutableOutputInfo<T> memberInfo = model.getOutputIDInfo();
            if (memberInfo.domainAndIDEquals(this.outputIDInfo)) {
                unsqueezedMembers.add(unsqueezedOutput);
            } else {
                ONNXInitializer indices = onnx.array("ensemble_output_gather_indices", this.computeOutputRemapping(memberInfo));
                unsqueezedMembers.add(unsqueezedOutput.apply(ONNXOperators.GATHER, indices, Collections.singletonMap("axis", 1)));
            }
        }
        ONNXInitializer ensembleWeights = onnx.array("ensemble_weights", this.weights);
        ONNXNode concat = onnx.operation(ONNXOperators.CONCAT, unsqueezedMembers, "ensemble_concat", Collections.singletonMap("axis", 2));
        return this.combiner.exportCombiner(concat, ensembleWeights);
    }

    /**
     * Builds the Gather indices which re-order a member's outputs into the ensemble's output ID order.
     * <p>
     * Entry {@code [ensembleId]} holds the member ID of the same output, so gathering with these indices places
     * each member output at its ensemble position.
     *
     * @param memberInfo The member's output domain.
     * @return The gather indices, one per ensemble output.
     * @throws IllegalStateException If the member's output domain does not have the same outputs as the ensemble's.
     */
    private int[] computeOutputRemapping(ImmutableOutputInfo<T> memberInfo) {
        int numOutputs = this.outputIDInfo.size();
        if (memberInfo.size() != numOutputs) {
            throw new IllegalStateException("Ensemble member output domain has " + memberInfo.size() + " outputs, expected " + numOutputs);
        }
        int[] remapping = new int[numOutputs];
        for (int memberId = 0; memberId < numOutputs; memberId++) {
            T output = memberInfo.getOutput(memberId);
            int ensembleId = this.outputIDInfo.getID(output);
            // NOTE(review): getID is presumed to return a negative ID for unknown outputs; the range check covers that.
            if (ensembleId < 0 || ensembleId >= numOutputs) {
                throw new IllegalStateException("Ensemble member output '" + output + "' is not in the ensemble's output domain");
            }
            remapping[ensembleId] = memberId;
        }
        return remapping;
    }

    /**
     * Serializes this ensemble to a {@link ModelProto} wrapping a {@link WeightedEnsembleModelProto}.
     * <p>
     * The proto contains the metadata, every member model, the weights and the combiner, and is tagged with
     * {@link #CURRENT_VERSION}.
     *
     * @return The serialized model.
     */
    @Override
    public ModelProto serialize() {
        ModelDataCarrier<T> carrier = this.createDataCarrier();
        WeightedEnsembleModelProto.Builder modelBuilder = WeightedEnsembleModelProto.newBuilder();
        modelBuilder.setMetadata(carrier.serialize());
        for (Model<T> m : this.models) {
            modelBuilder.addModels(m.serialize());
        }
        modelBuilder.addAllWeights(Util.toBoxedFloats(this.weights));
        modelBuilder.setCombiner((EnsembleCombinerProto)this.combiner.serialize());
        ModelProto.Builder builder = ModelProto.newBuilder();
        builder.setSerializedData(Any.pack(modelBuilder.build()));
        builder.setClassName(WeightedEnsembleModel.class.getName());
        builder.setVersion(CURRENT_VERSION);
        return builder.build();
    }
}

