/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.math.neighbour.bruteforce;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;
import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.PropertyException;
import java.util.Objects;
import org.tribuo.math.distance.Distance;
import org.tribuo.math.distance.DistanceType;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.neighbour.NeighboursQueryFactory;
import org.tribuo.math.neighbour.bruteforce.NeighboursBruteForce;
import org.tribuo.math.protos.BruteForceFactoryProto;
import org.tribuo.math.protos.NeighbourFactoryProto;
import org.tribuo.protos.ProtoSerializable;
import org.tribuo.protos.ProtoSerializableClass;
import org.tribuo.protos.ProtoSerializableField;
import org.tribuo.protos.ProtoUtil;

/**
 * A factory which creates {@link NeighboursBruteForce} nearest neighbour query objects.
 * <p>
 * When constructed via the no-arg constructor, the factory uses the L2 distance and a single thread.
 */
@ProtoSerializableClass(version = NeighboursBruteForceFactory.CURRENT_VERSION, serializedDataClass = BruteForceFactoryProto.class)
public final class NeighboursBruteForceFactory implements NeighboursQueryFactory {
    private static final long serialVersionUID = 1L;

    /**
     * Protobuf serialization version.
     */
    public static final int CURRENT_VERSION = 0;

    @Config(description = "The distance function to use.")
    @ProtoSerializableField
    private Distance distance = DistanceType.L2.getDistance();

    @Config(description = "The number of threads to use for training.")
    @ProtoSerializableField
    private int numThreads = 1;

    /**
     * No-arg constructor leaving the field defaults in place.
     * NOTE(review): presumably used by the OLCUT configuration system (the fields carry {@link Config}) — confirm.
     */
    private NeighboursBruteForceFactory() {
    }

    /**
     * Constructs a brute-force nearest neighbour query factory.
     *
     * @param distance The distance function passed to each created query object; must not be null.
     * @param numThreads The number of threads passed to each created query object; must be greater than 0.
     * @throws PropertyException If {@code numThreads} is not positive or {@code distance} is null.
     */
    public NeighboursBruteForceFactory(Distance distance, int numThreads) {
        this.distance = distance;
        this.numThreads = numThreads;
        this.postConfig();
    }

    /**
     * Deserialization factory.
     *
     * @param version The serialized object version.
     * @param className The class name.
     * @param message The serialized data.
     * @return The deserialized factory.
     * @throws InvalidProtocolBufferException If the message is not a {@link BruteForceFactoryProto}.
     * @throws IllegalArgumentException If the version is outside [0, {@link #CURRENT_VERSION}].
     */
    public static NeighboursBruteForceFactory deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
        }
        BruteForceFactoryProto factoryProto = message.unpack(BruteForceFactoryProto.class);
        Distance deserializedDistance = ProtoUtil.deserialize(factoryProto.getDistance());
        // Route through the public constructor so the deserialized values are validated by postConfig.
        return new NeighboursBruteForceFactory(deserializedDistance, factoryProto.getNumThreads());
    }

    /**
     * Serializes this factory to a protobuf via {@link ProtoUtil#serialize}.
     *
     * @return The protobuf representation of this factory.
     */
    public NeighbourFactoryProto serialize() {
        return ProtoUtil.serialize(this);
    }

    /**
     * Creates a {@link NeighboursBruteForce} query over the supplied data using this factory's
     * distance function and thread count.
     *
     * @param data The points to query against.
     * @return A brute-force neighbour query object.
     */
    @Override
    public NeighboursBruteForce createNeighboursQuery(SGDVector[] data) {
        return new NeighboursBruteForce(data, this.distance, this.numThreads);
    }

    /**
     * Gets the distance function used by created query objects.
     *
     * @return The distance function.
     */
    @Override
    public Distance getDistance() {
        return this.distance;
    }

    /**
     * Gets the number of threads passed to created query objects.
     *
     * @return The number of threads.
     */
    @Override
    public int getNumThreads() {
        return this.numThreads;
    }

    /**
     * Validates the configured fields.
     *
     * @throws PropertyException If {@code numThreads} is not positive or {@code distance} is null.
     */
    public synchronized void postConfig() {
        // numThreads is checked first so inputs that were already rejected keep their original error.
        if (this.numThreads <= 0) {
            throw new PropertyException("numThreads", "The number of threads must be a number greater than 0.");
        }
        if (this.distance == null) {
            throw new PropertyException("distance", "The distance function must not be null.");
        }
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        NeighboursBruteForceFactory that = (NeighboursBruteForceFactory) o;
        // Objects.equals keeps equals null-safe and consistent with Objects.hash in hashCode.
        return this.numThreads == that.numThreads && Objects.equals(this.distance, that.distance);
    }

    @Override
    public int hashCode() {
        return Objects.hash(this.distance, this.numThreads);
    }

    @Override
    public String toString() {
        return "NeighboursBruteForceFactory(distance=" + this.distance + ",numThreads=" + this.numThreads + ")";
    }
}

