/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.clustering.kmeans;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.IOException;
import java.io.InvalidObjectException;
import java.io.ObjectInputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.Feature;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.clustering.ClusterID;
import org.tribuo.clustering.kmeans.KMeansTrainer;
import org.tribuo.clustering.kmeans.protos.KMeansModelProto;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.math.distance.Distance;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.math.protos.DistanceProto;
import org.tribuo.math.protos.TensorProto;
import org.tribuo.protos.ProtoUtil;
import org.tribuo.protos.core.ModelDataProto;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.provenance.ModelProvenance;

/**
 * A K-Means clustering model.
 * <p>
 * Holds one {@link DenseVector} centroid per cluster and assigns each example to the
 * cluster whose centroid is nearest under the configured {@link Distance}.
 * The index of a centroid in the centroid array is used as the {@link ClusterID}.
 */
public class KMeansModel
extends Model<ClusterID> {
    private static final long serialVersionUID = 1L;

    /**
     * Protobuf serialization version written by {@link #serialize()} and the highest
     * version accepted by {@link #deserializeFromProto(int, String, Any)}.
     */
    public static final int CURRENT_VERSION = 0;

    /** One centroid per cluster; array index is the cluster id. */
    private final DenseVector[] centroidVectors;

    /**
     * Legacy distance setting from older Java-serialized models. Only consulted by
     * {@link #readObject(ObjectInputStream)} when {@link #dist} is absent.
     */
    @Deprecated
    private KMeansTrainer.Distance distanceType;

    /** Distance used to compare examples with centroids. Not final so readObject can fill it in. */
    private Distance dist;

    /**
     * Constructs a K-Means model. The centroid array is stored as-is (not copied).
     *
     * @param name            the model name
     * @param description     the model provenance
     * @param featureIDMap    the feature domain
     * @param outputIDInfo    the cluster output domain
     * @param centroidVectors the centroids, indexed by cluster id
     * @param dist            the distance function used at prediction time
     */
    KMeansModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<ClusterID> outputIDInfo, DenseVector[] centroidVectors, Distance dist) {
        super(name, description, featureIDMap, outputIDInfo, false);
        this.centroidVectors = centroidVectors;
        this.dist = dist;
    }

    /**
     * Deserialization factory.
     *
     * @param version   the serialized version
     * @param className the class name
     * @param message   the serialized data
     * @return the deserialized model
     * @throws InvalidProtocolBufferException if the message is not a {@link KMeansModelProto}
     * @throws IllegalArgumentException       if the version is unsupported
     * @throws IllegalStateException          if the proto is structurally invalid
     */
    public static KMeansModel deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
        }
        KMeansModelProto proto = message.unpack(KMeansModelProto.class);
        ModelDataCarrier<?> carrier = ModelDataCarrier.deserialize(proto.getMetadata());

        // Probe the first output to confirm this is a clustering domain; a missing output
        // is treated as invalid rather than allowed to surface as a NullPointerException.
        Output<?> firstOutput = carrier.outputDomain().getOutput(0);
        if (firstOutput == null || !firstOutput.getClass().equals(ClusterID.class)) {
            throw new IllegalStateException("Invalid protobuf, output domain is not a clustering domain, found " + carrier.outputDomain().getClass());
        }
        @SuppressWarnings("unchecked") // output type checked immediately above
        ImmutableOutputInfo<ClusterID> outputDomain = (ImmutableOutputInfo<ClusterID>) carrier.outputDomain();
        ImmutableFeatureMap featureDomain = carrier.featureDomain();

        if (proto.getCentroidVectorsCount() == 0) {
            throw new IllegalStateException("Invalid protobuf, no centroids were found");
        }
        List<TensorProto> centroidProtos = proto.getCentroidVectorsList();
        DenseVector[] centroids = new DenseVector[centroidProtos.size()];
        for (int i = 0; i < centroids.length; ++i) {
            Tensor centroidTensor = Tensor.deserialize(centroidProtos.get(i));
            if (!(centroidTensor instanceof DenseVector)) {
                throw new IllegalStateException("Invalid protobuf, expected centroid to be a dense vector, found " + centroidTensor.getClass());
            }
            DenseVector centroid = (DenseVector) centroidTensor;
            if (centroid.size() != featureDomain.size()) {
                throw new IllegalStateException("Invalid protobuf, centroid did not contain all the features, found " + centroid.size() + " expected " + featureDomain.size());
            }
            centroids[i] = centroid;
        }
        Distance dist = ProtoUtil.deserialize(proto.getDistance());
        return new KMeansModel(carrier.name(), carrier.provenance(), featureDomain, outputDomain, centroids, dist);
    }

    /**
     * Returns deep copies of the centroid vectors, so callers cannot mutate the model.
     *
     * @return a fresh array of copied centroids, indexed by cluster id
     */
    public DenseVector[] getCentroidVectors() {
        return copyCentroids();
    }

    /**
     * Returns the centroids as lists of named features.
     * <p>
     * For each centroid, one {@link Feature} is produced per element yielded by the centroid's
     * iterator, named via the model's feature map.
     *
     * @return one feature list per centroid, in cluster id order
     */
    public List<List<Feature>> getCentroids() {
        List<List<Feature>> output = new ArrayList<>(this.centroidVectors.length);
        for (DenseVector centroid : this.centroidVectors) {
            List<Feature> features = new ArrayList<>(this.featureIDMap.size());
            for (VectorTuple v : centroid) {
                features.add(new Feature(this.featureIDMap.get(v.index).getName(), v.value));
            }
            output.add(features);
        }
        return output;
    }

    /**
     * Assigns the example to the nearest centroid.
     *
     * @param example the example to cluster
     * @return a prediction whose output is the id of the closest centroid
     * @throws IllegalArgumentException if the example contains no features known to the model
     */
    @Override
    public Prediction<ClusterID> predict(Example<ClusterID> example) {
        // A dense vector is only built when the example is as wide as the feature map.
        SGDVector vector;
        if (example.size() == this.featureIDMap.size()) {
            vector = DenseVector.createDenseVector(example, this.featureIDMap, false);
        } else {
            vector = SparseVector.createSparseVector(example, this.featureIDMap, false);
        }
        if (vector.numActiveElements() == 0) {
            throw new IllegalArgumentException("No features found in Example " + example.toString());
        }
        double minDistance = Double.POSITIVE_INFINITY;
        int id = -1;
        for (int i = 0; i < this.centroidVectors.length; ++i) {
            double distance = this.dist.computeDistance(this.centroidVectors[i], vector);
            // Strict '<' keeps the lowest index on ties; NaN distances never win.
            if (distance < minDistance) {
                minDistance = distance;
                id = i;
            }
        }
        return new Prediction<>(new ClusterID(id), vector.size(), example);
    }

    /**
     * K-Means does not produce feature rankings.
     *
     * @param n ignored
     * @return an empty map
     */
    @Override
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
        return Collections.emptyMap();
    }

    /**
     * Explanations are not supported by this model.
     *
     * @param example ignored
     * @return {@link Optional#empty()}
     */
    @Override
    public Optional<Excuse<ClusterID>> getExcuse(Example<ClusterID> example) {
        return Optional.empty();
    }

    /**
     * Serializes this model (metadata, distance and centroids) into a {@link ModelProto}
     * tagged with {@link #CURRENT_VERSION}.
     *
     * @return the serialized model
     */
    @Override
    public ModelProto serialize() {
        ModelDataCarrier<ClusterID> carrier = this.createDataCarrier();
        KMeansModelProto.Builder modelBuilder = KMeansModelProto.newBuilder();
        modelBuilder.setMetadata(carrier.serialize());
        modelBuilder.setDistance(this.dist.serialize());
        for (DenseVector e : this.centroidVectors) {
            modelBuilder.addCentroidVectors(e.serialize());
        }
        ModelProto.Builder builder = ModelProto.newBuilder();
        builder.setSerializedData(Any.pack(modelBuilder.build()));
        builder.setClassName(KMeansModel.class.getName());
        builder.setVersion(CURRENT_VERSION);
        return builder.build();
    }

    /**
     * Copies this model with a new name and provenance. Centroids are deep-copied; the
     * feature map, output info and distance are shared with this model.
     */
    @Override
    protected KMeansModel copy(String newName, ModelProvenance newProvenance) {
        return new KMeansModel(newName, newProvenance, this.featureIDMap, this.outputIDInfo, copyCentroids(), this.dist);
    }

    /** Deep-copies the centroid array and each centroid vector. */
    private DenseVector[] copyCentroids() {
        DenseVector[] copies = new DenseVector[this.centroidVectors.length];
        for (int i = 0; i < copies.length; ++i) {
            copies[i] = this.centroidVectors[i].copy();
        }
        return copies;
    }

    /**
     * Java serialization hook: when {@link #dist} is missing (older serialized form), it is
     * rebuilt from the legacy {@link #distanceType} field.
     *
     * @throws InvalidObjectException if neither distance field is present
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        if (this.dist == null) {
            if (this.distanceType == null) {
                throw new InvalidObjectException("Invalid serialized KMeansModel, no distance function was found");
            }
            this.dist = this.distanceType.getDistanceType().getDistance();
        }
    }
}

