/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.classification.sgd.kernel;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.classification.sgd.protos.KernelSVMModelProto;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.math.kernel.Kernel;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.protos.KernelProto;
import org.tribuo.math.protos.TensorProto;
import org.tribuo.protos.core.ModelDataProto;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.provenance.ModelProvenance;

/**
 * A kernel SVM classification model.
 * <p>
 * An example is scored by computing the kernel similarity between its sparse feature vector
 * and every stored support vector, then multiplying that similarity vector by a weight matrix
 * which has one row per label and one column per support vector. The label with the highest
 * score is the predicted label; scores are not probabilities.
 */
public class KernelSVMModel
extends Model<Label> {
    private static final long serialVersionUID = 2L;

    /** Protobuf serialization version written by {@link #serialize()}. */
    public static final int CURRENT_VERSION = 0;

    private final Kernel kernel;
    private final SparseVector[] supportVectors;
    private final DenseMatrix weights;

    /**
     * Constructs a kernel SVM model. The arrays and matrix are stored as-is (not copied).
     *
     * @param name The model name.
     * @param description The model provenance.
     * @param featureIDMap The feature domain.
     * @param labelIDMap The label domain.
     * @param kernel The kernel used to compare examples with support vectors.
     * @param supportVectors The support vectors.
     * @param weights The weight matrix, expected to be (number of labels) x (number of support vectors).
     */
    KernelSVMModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> labelIDMap, Kernel kernel, SparseVector[] supportVectors, DenseMatrix weights) {
        super(name, description, featureIDMap, labelIDMap, false);
        this.kernel = kernel;
        this.supportVectors = supportVectors;
        this.weights = weights;
    }

    /**
     * Deserialization factory.
     *
     * @param version The serialized object version.
     * @param className The class name (not checked here).
     * @param message The serialized data, expected to pack a {@code KernelSVMModelProto}.
     * @return The deserialized model.
     * @throws InvalidProtocolBufferException If the message does not contain a {@code KernelSVMModelProto}.
     * @throws IllegalArgumentException If the version is unsupported.
     * @throws IllegalStateException If the protobuf contents are inconsistent (wrong output type,
     *                               wrong tensor types, or mismatched sizes).
     */
    @SuppressWarnings("unchecked") // The output domain is verified to contain Labels before the cast below.
    public static KernelSVMModel deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
        }
        KernelSVMModelProto proto = message.unpack(KernelSVMModelProto.class);
        ModelDataCarrier carrier = ModelDataCarrier.deserialize(proto.getMetadata());

        // Verify the domain holds Labels. Guard against an empty domain so this reports a clear
        // error rather than dereferencing a missing output.
        ImmutableOutputInfo<?> outputDomain = carrier.outputDomain();
        Object firstOutput = outputDomain.size() > 0 ? outputDomain.getOutput(0) : null;
        if (firstOutput == null || !firstOutput.getClass().equals(Label.class)) {
            String found = firstOutput == null ? "an empty output domain" : firstOutput.getClass().toString();
            throw new IllegalStateException("Invalid protobuf, output domain is not a label domain, found " + found);
        }

        // Support vectors are one element longer than the feature domain; presumably the extra
        // element is the bias added by createSparseVector(..., true) in predict.
        int featureSize = carrier.featureDomain().size() + 1;
        List<TensorProto> supportProtos = proto.getSupportVectorsList();
        SparseVector[] supportVectors = new SparseVector[supportProtos.size()];
        for (int i = 0; i < supportProtos.size(); ++i) {
            Tensor tensor = Tensor.deserialize(supportProtos.get(i));
            if (!(tensor instanceof SparseVector)) {
                throw new IllegalStateException("Invalid protobuf, support vector must be a sparse vector, found " + tensor.getClass());
            }
            SparseVector vec = (SparseVector) tensor;
            if (vec.size() != featureSize) {
                throw new IllegalStateException("Invalid protobuf, support vector size must equal feature domain size, found " + vec.size() + ", expected " + featureSize);
            }
            supportVectors[i] = vec;
        }

        // Weights must be (number of labels) x (number of support vectors).
        Tensor weightTensor = Tensor.deserialize(proto.getWeights());
        if (!(weightTensor instanceof DenseMatrix)) {
            throw new IllegalStateException("Invalid protobuf, weights must be a dense matrix, found " + weightTensor.getClass());
        }
        DenseMatrix weights = (DenseMatrix) weightTensor;
        if (weights.getDimension1Size() != outputDomain.size()) {
            throw new IllegalStateException("Invalid protobuf, weights not the right size, expected " + outputDomain.size() + ", found " + weights.getDimension1Size());
        }
        if (weights.getDimension2Size() != supportVectors.length) {
            throw new IllegalStateException("Invalid protobuf, weights not the right size, expected " + supportVectors.length + ", found " + weights.getDimension2Size());
        }

        Kernel kernel = Kernel.deserialize(proto.getKernel());
        return new KernelSVMModel(carrier.name(), carrier.provenance(), carrier.featureDomain(), (ImmutableOutputInfo<Label>) outputDomain, kernel, supportVectors, weights);
    }

    /**
     * Returns the number of support vectors stored in this model.
     *
     * @return The number of support vectors.
     */
    public int getNumberOfSupportVectors() {
        return supportVectors.length;
    }

    /**
     * Scores the example against every label and returns the highest-scoring one.
     * <p>
     * Ties are broken in favour of the label with the lowest id. NaN scores are never selected.
     *
     * @param example The example to classify.
     * @return A prediction containing the best label and the scores for all labels.
     * @throws IllegalArgumentException If the example contains no features known to this model.
     * @throws IllegalStateException If no label has a comparable (non-NaN) score.
     */
    @Override
    public Prediction<Label> predict(Example<Label> example) {
        // The trailing 'true' requests an extra element (presumably a bias term), so a vector
        // with exactly one active element holds none of the example's own features.
        SparseVector features = SparseVector.createSparseVector(example, featureIDMap, true);
        if (features.numActiveElements() == 1) {
            throw new IllegalArgumentException("No features found in Example " + example.toString());
        }

        double[] similarities = new double[supportVectors.length];
        for (int i = 0; i < similarities.length; ++i) {
            similarities[i] = kernel.similarity(features, supportVectors[i]);
        }
        // One score per label: weights (labels x supportVectors) times the similarity vector.
        DenseVector labelScores = weights.leftMultiply(DenseVector.createDenseVector(similarities));

        Label maxLabel = null;
        double maxScore = Double.NEGATIVE_INFINITY;
        Map<String, Label> predMap = new LinkedHashMap<>();
        for (int i = 0; i < labelScores.size(); ++i) {
            String labelName = outputIDInfo.getOutput(i).getLabel();
            Label label = new Label(labelName, labelScores.get(i));
            predMap.put(labelName, label);
            double score = label.getScore();
            // Accept the first non-NaN score unconditionally. With only a strict '>' against
            // -Infinity, an all -Infinity score vector would leave maxLabel null.
            boolean firstComparable = maxLabel == null && !Double.isNaN(score);
            if (firstComparable || score > maxScore) {
                maxScore = score;
                maxLabel = label;
            }
        }
        if (maxLabel == null) {
            // Only reachable when every label score is NaN (or there are no labels).
            throw new IllegalStateException("No comparable label score (all NaN or no labels) for Example " + example.toString());
        }
        return new Prediction<>(maxLabel, predMap, features.numActiveElements(), example, generatesProbabilities);
    }

    /**
     * Kernel models do not expose per-feature weights, so this always returns an empty map.
     *
     * @param n Ignored.
     * @return An empty map.
     */
    @Override
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
        return Collections.emptyMap();
    }

    /**
     * Excuses are not supported by this model.
     *
     * @param example Ignored.
     * @return {@link Optional#empty()}.
     */
    @Override
    public Optional<Excuse<Label>> getExcuse(Example<Label> example) {
        return Optional.empty();
    }

    /**
     * Serializes this model, including its kernel, weights and support vectors, into a {@link ModelProto}.
     *
     * @return The serialized model.
     */
    @Override
    public ModelProto serialize() {
        KernelSVMModelProto.Builder modelBuilder = KernelSVMModelProto.newBuilder();
        modelBuilder.setMetadata(createDataCarrier().serialize());
        modelBuilder.setKernel(kernel.serialize());
        modelBuilder.setWeights(weights.serialize());
        for (SparseVector v : supportVectors) {
            modelBuilder.addSupportVectors(v.serialize());
        }
        ModelProto.Builder builder = ModelProto.newBuilder();
        builder.setVersion(CURRENT_VERSION);
        builder.setClassName(KernelSVMModel.class.getName());
        builder.setSerializedData(Any.pack(modelBuilder.build()));
        return builder.build();
    }

    /**
     * Copies this model under a new name and provenance.
     * <p>
     * The support vectors and weight matrix are copied. The kernel instance is shared with the
     * copy rather than cloned; this presumes kernels are not mutated after construction
     * (TODO confirm).
     *
     * @param newName The name of the copy.
     * @param newProvenance The provenance of the copy.
     * @return A copy of this model.
     */
    @Override
    protected KernelSVMModel copy(String newName, ModelProvenance newProvenance) {
        SparseVector[] vectorCopies = new SparseVector[supportVectors.length];
        for (int i = 0; i < vectorCopies.length; ++i) {
            vectorCopies[i] = supportVectors[i].copy();
        }
        return new KernelSVMModel(newName, newProvenance, featureIDMap, outputIDInfo, kernel, vectorCopies, new DenseMatrix(weights));
    }
}

