/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.common.sgd;

import ai.onnx.proto.OnnxMl;
import com.oracle.labs.mlrg.olcut.provenance.Provenancable;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.PriorityQueue;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.ONNXExportable;
import org.tribuo.Output;
import org.tribuo.common.sgd.AbstractSGDModel;
import org.tribuo.common.sgd.FMParameters;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.Matrix;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.onnx.ONNXMathUtils;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.util.onnx.ONNXContext;
import org.tribuo.util.onnx.ONNXInitializer;
import org.tribuo.util.onnx.ONNXNode;
import org.tribuo.util.onnx.ONNXOperator;
import org.tribuo.util.onnx.ONNXOperators;
import org.tribuo.util.onnx.ONNXPlaceholder;
import org.tribuo.util.onnx.ONNXRef;

/**
 * Base class for factorization machine (FM) models which store their
 * parameters in an {@link FMParameters} instance.
 * <p>
 * As used throughout this class, the tensor array returned by
 * {@code modelParameters.get()} is laid out as:
 * <ul>
 *   <li>index 0 - the bias vector (a {@link DenseVector}),</li>
 *   <li>index 1 - the linear weight matrix (a {@link DenseMatrix}) indexed as
 *       {@code [output dimension, feature]},</li>
 *   <li>index 2 and above - the factor tensors, read one per output dimension
 *       by {@link #writeONNXGraph(ONNXRef)}.</li>
 * </ul>
 *
 * @param <T> The output type.
 */
public abstract class AbstractFMModel<T extends Output<T>>
extends AbstractSGDModel<T> {
    private static final long serialVersionUID = 1L;

    /**
     * Constructs a factorization machine model, forwarding all arguments to the
     * {@link AbstractSGDModel} constructor.
     *
     * @param name The model name.
     * @param provenance The model provenance.
     * @param featureIDMap The feature domain.
     * @param outputIDInfo The output domain.
     * @param parameters The model parameters.
     * @param generatesProbabilities Does this model generate probabilities?
     */
    protected AbstractFMModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, FMParameters parameters, boolean generatesProbabilities) {
        // NOTE(review): the meaning of the trailing 'false' flag is defined by
        // AbstractSGDModel, which is not visible here - confirm against the superclass.
        super(name, provenance, featureIDMap, outputIDInfo, parameters, generatesProbabilities, false);
    }

    /**
     * Gets the top {@code n} entries of the linear part of the model for each output
     * dimension, ranked by absolute weight (largest first). The bias competes with the
     * features under the name {@code "BIAS"}. The factor tensors are not considered.
     *
     * @param n The number of entries to return per output dimension. A negative value
     *          returns every feature plus the bias; zero returns an empty list for
     *          every output dimension.
     * @return A map from the output dimension name to its ranked (name, weight) pairs.
     */
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
        Tensor[] params = this.modelParameters.get();
        DenseVector biases = (DenseVector) params[0];
        DenseMatrix baseWeights = (DenseMatrix) params[1];
        int numClasses = baseWeights.getDimension1Size();
        int numFeatures = baseWeights.getDimension2Size();
        int maxFeatures = n < 0 ? this.featureIDMap.size() + 1 : n;
        // PriorityQueue rejects an initial capacity below 1, and there is no point in
        // pre-allocating more slots than there are candidates (features + bias).
        int queueCapacity = Math.max(1, Math.min(maxFeatures, numFeatures + 1));
        Comparator<Pair<String, Double>> comparator = Comparator.comparingDouble(p -> Math.abs(p.getB()));

        Map<String, List<Pair<String, Double>>> map = new HashMap<>();
        for (int i = 0; i < numClasses; ++i) {
            // Min-heap on |weight|: the head is always the weakest entry currently kept.
            PriorityQueue<Pair<String, Double>> q = new PriorityQueue<>(queueCapacity, comparator);
            for (int j = 0; j < numFeatures; ++j) {
                Pair<String, Double> curr = new Pair<>(this.featureIDMap.get(j).getName(), baseWeights.get(i, j));
                offerBounded(q, comparator, maxFeatures, curr);
            }
            offerBounded(q, comparator, maxFeatures, new Pair<>("BIAS", biases.get(i)));

            // Polling yields ascending |weight|, so reverse to get the largest first.
            List<Pair<String, Double>> ranked = new ArrayList<>(q.size());
            while (!q.isEmpty()) {
                ranked.add(q.poll());
            }
            Collections.reverse(ranked);
            map.put(this.getDimensionName(i), ranked);
        }
        return map;
    }

    /**
     * Offers the candidate to a size-bounded min-heap, evicting the current head only
     * when the candidate ranks strictly above it. Does nothing if the limit is not positive.
     *
     * @param queue The heap holding the entries kept so far.
     * @param order The ordering the heap was built with.
     * @param limit The maximum number of entries the heap may hold.
     * @param candidate The entry to consider.
     * @param <E> The element type.
     */
    private static <E> void offerBounded(PriorityQueue<E> queue, Comparator<? super E> order, int limit, E candidate) {
        if (limit <= 0) {
            return;
        }
        if (queue.size() < limit) {
            queue.offer(candidate);
        } else if (order.compare(candidate, queue.peek()) > 0) {
            queue.poll();
            queue.offer(candidate);
        }
    }

    /**
     * Returns a copy of the linear weight matrix (parameter index 1).
     *
     * @return A copy of the linear weights.
     */
    public DenseMatrix getLinearWeightsCopy() {
        return ((DenseMatrix) this.modelParameters.get()[1]).copy();
    }

    /**
     * Returns a copy of the bias vector (parameter index 0).
     *
     * @return A copy of the biases.
     */
    public DenseVector getBiasesCopy() {
        return ((DenseVector) this.modelParameters.get()[0]).copy();
    }

    /**
     * Returns copies of the factor tensors, i.e. every parameter after the
     * biases and the linear weights, in their stored order.
     *
     * @return A freshly allocated array containing a copy of each factor tensor.
     */
    public Tensor[] getFactorsCopy() {
        Tensor[] params = this.modelParameters.get();
        // The first two slots hold the biases and the linear weights.
        Tensor[] paramCopy = new Tensor[params.length - 2];
        for (int i = 0; i < paramCopy.length; ++i) {
            paramCopy[i] = params[i + 2].copy();
        }
        return paramCopy;
    }

    /**
     * This implementation never produces an excuse.
     *
     * @param example The example to explain (unused).
     * @return An empty optional.
     */
    public Optional<Excuse<T>> getExcuse(Example<T> example) {
        return Optional.empty();
    }

    /**
     * Gets the name of the indexed output dimension, used as the key in the map
     * returned by {@link #getTopFeatures(int)}.
     *
     * @param index The output dimension index.
     * @return The name of that dimension.
     */
    protected abstract String getDimensionName(int index);

    /**
     * Applied to the final node built by {@link #writeONNXGraph(ONNXRef)} (linear term
     * plus interaction term) to produce the node that method returns.
     *
     * @param input The node holding the summed linear and interaction terms.
     * @return The model's output node.
     */
    protected abstract ONNXNode onnxOutput(ONNXNode input);

    /**
     * Gets the name to set on the ONNX context when exporting this model.
     *
     * @return The ONNX model name.
     */
    protected abstract String onnxModelName();

    /**
     * Writes this model into the ONNX graph which owns the supplied input.
     * <p>
     * The graph is the sum of a GEMM of the input with the linear weights and biases,
     * and, for each output dimension {@code i} with factor matrix {@code V_i}, the value
     * {@code reduceSum(pow(gemm(x, V_i), 2) - gemm(pow(x, 2), pow(V_i, 2)), axis=1) / 2}.
     * The per-dimension values are concatenated along axis 1 before the addition.
     *
     * @param input The input node.
     * @return The result of {@link #onnxOutput(ONNXNode)} applied to the summed terms.
     */
    public ONNXNode writeONNXGraph(ONNXRef<?> input) {
        ONNXContext onnx = input.onnxContext();
        Tensor[] modelParams = this.modelParameters.get();

        ONNXInitializer twoConst = onnx.constant("two_const", 2.0f);
        ONNXInitializer sumAxes = onnx.array("sum_over_embedding_axes", new long[]{1L});

        // NOTE(review): the trailing 'true' passed to floatMatrix is presumably a
        // transpose flag - confirm against ONNXMathUtils.
        ONNXInitializer weights = ONNXMathUtils.floatMatrix(onnx, "fm_linear_weights", (Matrix) modelParams[1], true);
        ONNXInitializer bias = ONNXMathUtils.floatVector(onnx, "fm_biases", (SGDVector) modelParams[0]);

        // Linear term.
        ONNXNode gemm = input.apply(ONNXOperators.GEMM, Arrays.asList(weights, bias));

        // Element-wise squared input, shared by every output dimension.
        ONNXNode inputSquared = input.apply(ONNXOperators.POW, twoConst);

        List<ONNXNode> embeddingOutputs = new ArrayList<>();
        for (int i = 0; i < this.outputIDInfo.size(); ++i) {
            ONNXInitializer embWeight = ONNXMathUtils.floatMatrix(onnx, "fm_embedding_" + i, (Matrix) modelParams[i + 2], true);
            // Square of the embedded input.
            ONNXNode featureEmbedding = input.apply(ONNXOperators.GEMM, embWeight);
            ONNXNode embeddingSquared = featureEmbedding.apply(ONNXOperators.POW, twoConst);
            // Squared input through the element-wise squared embedding.
            ONNXNode embWeightSquared = embWeight.apply(ONNXOperators.POW, twoConst);
            ONNXNode inputByEmbeddingSquared = inputSquared.apply(ONNXOperators.GEMM, embWeightSquared);
            ONNXNode subtract = embeddingSquared.apply(ONNXOperators.SUB, inputByEmbeddingSquared);
            // Sum over axis 1 and halve.
            embeddingOutputs.add(subtract.apply(ONNXOperators.REDUCE_SUM, sumAxes).apply(ONNXOperators.DIV, twoConst));
        }

        ONNXNode concat = onnx.operation(ONNXOperators.CONCAT, embeddingOutputs, "fm_concat", Collections.singletonMap("axis", 1));
        return this.onnxOutput(gemm.apply(ONNXOperators.ADD, concat));
    }

    /**
     * Exports this model as an ONNX model protobuf. The graph has a float input named
     * {@code "input"} sized by the feature domain and a float output named
     * {@code "output"} sized by the output domain.
     *
     * @param domain The domain passed through to {@code ONNXExportable.buildModel}.
     * @param modelVersion The model version passed through to {@code ONNXExportable.buildModel}.
     * @return The ONNX model protobuf.
     */
    public OnnxMl.ModelProto exportONNXModel(String domain, long modelVersion) {
        ONNXContext onnx = new ONNXContext();
        onnx.setName(this.onnxModelName());
        ONNXPlaceholder input = onnx.floatInput("input", this.featureIDMap.size());
        ONNXPlaceholder output = onnx.floatOutput("output", this.outputIDInfo.size());
        this.writeONNXGraph(input).assignTo(output);
        return ONNXExportable.buildModel(onnx, domain, modelVersion, this);
    }
}

