/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.classification.ensemble;

import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import com.oracle.labs.mlrg.olcut.config.Configurable;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import org.tribuo.Example;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.ensemble.EnsembleCombiner;
import org.tribuo.protos.core.EnsembleCombinerProto;
import org.tribuo.util.onnx.ONNXInitializer;
import org.tribuo.util.onnx.ONNXNode;
import org.tribuo.util.onnx.ONNXOperator;
import org.tribuo.util.onnx.ONNXOperators;
import org.tribuo.util.onnx.ONNXRef;

/**
 * An ensemble combiner that adds up every member's full per-label score distribution,
 * optionally scaling each member's scores by a weight, and then normalises the totals.
 * <p>
 * For each label the combined score is {@code sum_m(w_m * score_m(label)) / Z}, where {@code Z}
 * is the sum of all weighted scores over every label and every member. The label with the
 * largest combined score is the predicted output; ties go to the label with the lowest id in
 * the supplied output info.
 */
public final class FullyWeightedVotingCombiner implements EnsembleCombiner<Label> {
    private static final long serialVersionUID = 1L;

    /**
     * Protobuf serialization version written by {@link #serialize()}.
     */
    public static final int CURRENT_VERSION = 0;

    /**
     * Constructs a fully weighted voting combiner. The combiner has no state.
     */
    public FullyWeightedVotingCombiner() {}

    /**
     * Deserialization factory.
     *
     * @param version   The serialized object version; must be between 0 and {@link #CURRENT_VERSION}.
     * @param className The class name (not used by this method).
     * @param message   The serialized payload; must be empty since this class has no state.
     * @return A new combiner.
     * @throws IllegalArgumentException If the version is unsupported or the payload is non-empty.
     */
    public static FullyWeightedVotingCombiner deserializeFromProto(int version, String className, Any message) {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
        }
        // ByteString is a value type: check the content, not identity with the EMPTY singleton,
        // otherwise an equal-but-distinct empty ByteString would be rejected.
        if (!message.getValue().isEmpty()) {
            throw new IllegalArgumentException("Invalid proto");
        }
        return new FullyWeightedVotingCombiner();
    }

    /**
     * Serializes this combiner. Only the class name and version are recorded; there is no payload.
     *
     * @return The combiner proto.
     */
    @Override
    public EnsembleCombinerProto serialize() {
        EnsembleCombinerProto.Builder combinerProto = EnsembleCombinerProto.newBuilder();
        combinerProto.setClassName(this.getClass().getName());
        combinerProto.setVersion(CURRENT_VERSION);
        return combinerProto.build();
    }

    /**
     * Combines the member predictions, giving every member the same weight {@code 1/n}.
     *
     * @param outputInfo  The label domain; its ids index the combined score vector.
     * @param predictions The member predictions, all for the same example.
     * @return The combined prediction.
     * @throws IllegalArgumentException If {@code predictions} is empty.
     */
    @Override
    public Prediction<Label> combine(ImmutableOutputInfo<Label> outputInfo, List<Prediction<Label>> predictions) {
        int numPredictions = predictions.size();
        double uniformWeight = 1.0 / (double) numPredictions;
        double[] memberWeights = new double[numPredictions];
        for (int i = 0; i < numPredictions; i++) {
            memberWeights[i] = uniformWeight;
        }
        return combineWithWeights(outputInfo, predictions, memberWeights);
    }

    /**
     * Combines the member predictions, scaling member {@code i}'s scores by {@code weights[i]}.
     *
     * @param outputInfo  The label domain; its ids index the combined score vector.
     * @param predictions The member predictions, all for the same example.
     * @param weights     One weight per member, aligned with {@code predictions}.
     * @return The combined prediction.
     * @throws IllegalArgumentException If the lengths differ or {@code predictions} is empty.
     */
    @Override
    public Prediction<Label> combine(ImmutableOutputInfo<Label> outputInfo, List<Prediction<Label>> predictions, float[] weights) {
        if (predictions.size() != weights.length) {
            throw new IllegalArgumentException("predictions and weights must be the same length. predictions.size()=" + predictions.size() + ", weights.length=" + weights.length);
        }
        double[] memberWeights = new double[weights.length];
        for (int i = 0; i < weights.length; i++) {
            // Widen to double exactly as the arithmetic below would do implicitly.
            memberWeights[i] = weights[i];
        }
        return combineWithWeights(outputInfo, predictions, memberWeights);
    }

    /**
     * Shared implementation of both {@code combine} overloads: accumulates weighted scores per
     * label, normalises by their grand total and picks the arg-max label.
     * <p>
     * NOTE(review): if the weighted scores sum to zero (e.g. every weight is zero), the
     * normalised scores become NaN or infinite and the chosen label may be null. This matches
     * the historical behaviour; callers should avoid such weights.
     *
     * @param outputInfo  The label domain.
     * @param predictions The member predictions (non-empty).
     * @param weights     Per-member weights, same length as {@code predictions}.
     * @return The combined prediction, flagged via the final constructor argument as probabilistic.
     */
    private static Prediction<Label> combineWithWeights(ImmutableOutputInfo<Label> outputInfo,
                                                        List<Prediction<Label>> predictions,
                                                        double[] weights) {
        if (predictions.isEmpty()) {
            throw new IllegalArgumentException("Cannot combine an empty list of predictions");
        }
        int numUsed = 0;
        double sum = 0.0;
        double[] score = new double[outputInfo.size()];
        int memberIdx = 0;
        // Iterate rather than index so non-random-access lists stay linear time.
        for (Prediction<Label> p : predictions) {
            double memberWeight = weights[memberIdx++];
            // Report the largest active-feature count seen across the members.
            numUsed = Math.max(numUsed, p.getNumActiveFeatures());
            for (Label e : p.getOutputScores().values()) {
                double curScore = memberWeight * e.getScore();
                sum += curScore;
                score[outputInfo.getID(e)] += curScore;
            }
        }

        double maxScore = Double.NEGATIVE_INFINITY;
        Label maxLabel = null;
        LinkedHashMap<String, Label> predictionMap = new LinkedHashMap<>();
        for (int i = 0; i < score.length; i++) {
            String name = outputInfo.getOutput(i).getLabel();
            Label label = new Label(name, score[i] / sum);
            predictionMap.put(name, label);
            // Strict comparison keeps the first (lowest id) label when scores tie.
            if (label.getScore() > maxScore) {
                maxScore = label.getScore();
                maxLabel = label;
            }
        }
        Example<Label> example = predictions.get(0).getExample();
        return new Prediction<>(maxLabel, predictionMap, numUsed, example, true);
    }

    @Override
    public String toString() {
        return "FullyWeightedVotingCombiner()";
    }

    @Override
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this, "EnsembleCombiner");
    }

    @Override
    public Class<Label> getTypeWitness() {
        return Label.class;
    }

    /**
     * Exports the unweighted combination as an ONNX ReduceMean over axis 2 (keepdims = 0).
     * <p>
     * NOTE(review): assumes axis 2 of {@code input} is the ensemble-member axis, e.g.
     * [batch, numOutputs, numMembers]; confirm against the ensemble export code.
     *
     * @param input The stacked member outputs.
     * @return The node holding the averaged outputs.
     */
    @Override
    public ONNXNode exportCombiner(ONNXNode input) {
        HashMap<String, Object> attributes = new HashMap<>();
        attributes.put("axes", new int[]{2});
        attributes.put("keepdims", 0);
        return input.apply(ONNXOperators.REDUCE_MEAN, attributes);
    }

    /**
     * Exports the weighted combination as ONNX operations.
     * <p>
     * The weights are unsqueezed on axes 0 and 1 so they broadcast against {@code input}. The
     * input is multiplied by them and summed over axis 2 (keepdims = 0). The result is then
     * divided by the ReduceSum of the weights.
     * <p>
     * NOTE(review): assumes {@code weight} is a 1-D tensor with one entry per member and that
     * axis 2 of {@code input} is the member axis; confirm against the ensemble export code.
     *
     * @param input  The stacked member outputs.
     * @param weight The per-member weights.
     * @param <T>    The weight reference type.
     * @return The node holding the weighted average.
     */
    @Override
    public <T extends ONNXRef<?>> ONNXNode exportCombiner(ONNXNode input, T weight) {
        ONNXInitializer unsqueezeAxes = input.onnxContext().array("unsqueeze_ensemble_output", new long[]{0L, 1L});
        ONNXNode unsqueezed = weight.apply(ONNXOperators.UNSQUEEZE, unsqueezeAxes);
        ONNXNode mulByWeights = input.apply(ONNXOperators.MUL, unsqueezed);
        ONNXNode weightSum = weight.apply(ONNXOperators.REDUCE_SUM);
        ONNXInitializer sumAxes = input.onnxContext().array("sum_across_ensemble_axes", new long[]{2L});
        return mulByWeights.apply(ONNXOperators.REDUCE_SUM, sumAxes, Collections.singletonMap("keepdims", 0))
                .apply(ONNXOperators.DIV, weightSum);
    }

    /**
     * All instances are stateless and therefore equal.
     */
    @Override
    public boolean equals(Object o) {
        return o instanceof FullyWeightedVotingCombiner;
    }

    @Override
    public int hashCode() {
        return 31;
    }
}

