/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.sequence;

import com.oracle.labs.mlrg.olcut.provenance.Provenancable;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.protos.ProtoSerializable;
import org.tribuo.protos.ProtoUtil;
import org.tribuo.protos.core.SequenceModelProto;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.sequence.SequenceDataset;
import org.tribuo.sequence.SequenceExample;

/**
 * Base class for models that make predictions over {@link SequenceExample}s.
 * <p>
 * Holds the model name, its {@link ModelProvenance} (plus a cached string form of it), and the
 * immutable feature and output domains the model was trained with. Subclasses implement
 * {@link #predict(SequenceExample)} and {@link #getTopFeatures(int)}, and may override
 * {@link #serialize()} to support protobuf serialization.
 * <p>
 * NOTE(review): {@link #setName(String)} mutates {@code name} without synchronization; this
 * class makes no thread-safety guarantees of its own.
 *
 * @param <T> the output type produced by this model
 */
public abstract class SequenceModel<T extends Output<T>>
implements ProtoSerializable<SequenceModelProto>,
Provenancable<ModelProvenance>,
Serializable {
    private static final long serialVersionUID = 1L;
    /** Human-readable model name; may be changed later via {@link #setName(String)}. */
    protected String name;
    private final ModelProvenance provenance;
    /** {@code provenance.toString()} computed once at construction, used by {@link #toString()}. */
    protected final String provenanceOutput;
    protected final ImmutableFeatureMap featureIDMap;
    protected final ImmutableOutputInfo<T> outputIDMap;

    /**
     * Builds a sequence model.
     *
     * @param name the model name (may be {@code null} or empty; see {@link #toString()})
     * @param provenance the model provenance; must be non-null, since its string form is
     *     computed eagerly here
     * @param featureIDMap the feature domain
     * @param outputIDMap the output domain
     */
    public SequenceModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDMap) {
        this.name = name;
        this.provenance = provenance;
        this.provenanceOutput = provenance.toString();
        this.featureIDMap = featureIDMap;
        this.outputIDMap = outputIDMap;
    }

    /**
     * Checks whether every output in this model's output domain is an instance of {@code clazz}.
     *
     * @param clazz the output class to test against
     * @return {@code true} if all domain elements are instances of {@code clazz}
     *     (vacuously {@code true} for an empty domain), {@code false} otherwise
     */
    public boolean validate(Class<? extends Output<?>> clazz) {
        Set<T> domain = this.outputIDMap.getDomain();
        // allMatch short-circuits on the first mismatch; isInstance is side-effect free,
        // so this yields the same answer as checking every element.
        return domain.stream().allMatch(clazz::isInstance);
    }

    /**
     * Returns the model name.
     *
     * @return the name, possibly {@code null}
     */
    public String getName() {
        return this.name;
    }

    /**
     * Replaces the model name.
     *
     * @param name the new name
     */
    public void setName(String name) {
        this.name = name;
    }

    /**
     * Returns the provenance supplied at construction.
     *
     * @return the model provenance
     */
    @Override
    public ModelProvenance getProvenance() {
        return this.provenance;
    }

    /**
     * Returns {@code "<name> - <provenance>"} when the name is non-null and non-empty,
     * otherwise just the cached provenance string.
     */
    @Override
    public String toString() {
        if (this.name != null && !this.name.isEmpty()) {
            return this.name + " - " + this.provenanceOutput;
        }
        return this.provenanceOutput;
    }

    /**
     * Returns the feature domain this model was constructed with.
     *
     * @return the immutable feature map
     */
    public ImmutableFeatureMap getFeatureIDMap() {
        return this.featureIDMap;
    }

    /**
     * Returns the output domain this model was constructed with.
     *
     * @return the immutable output info
     */
    public ImmutableOutputInfo<T> getOutputIDInfo() {
        return this.outputIDMap;
    }

    /**
     * Predicts outputs for a single sequence example.
     * NOTE(review): presumably one prediction per element of the sequence — confirm in subclasses.
     *
     * @param example the sequence to predict
     * @return the predictions for {@code example}
     */
    public abstract List<Prediction<T>> predict(SequenceExample<T> example);

    /**
     * Predicts each sequence in {@code examples} via {@link #predict(SequenceExample)}.
     *
     * @param examples the sequences to predict
     * @return one prediction list per input sequence, in iteration order
     */
    public List<List<Prediction<T>>> predict(Iterable<SequenceExample<T>> examples) {
        return predictEach(examples);
    }

    /**
     * Predicts each sequence in {@code examples} via {@link #predict(SequenceExample)}.
     *
     * @param examples the dataset to predict
     * @return one prediction list per sequence, in the dataset's iteration order
     */
    public List<List<Prediction<T>>> predict(SequenceDataset<T> examples) {
        return predictEach(examples);
    }

    /** Shared loop for the collection {@code predict} overloads; delegates per example. */
    private List<List<Prediction<T>>> predictEach(Iterable<? extends SequenceExample<T>> examples) {
        List<List<Prediction<T>>> predictions = new ArrayList<>();
        for (SequenceExample<T> example : examples) {
            predictions.add(this.predict(example));
        }
        return predictions;
    }

    /**
     * Serializes this model to a protobuf.
     *
     * @return the protobuf form of this model
     * @throws UnsupportedOperationException always, unless overridden by a subclass
     */
    @Override
    public SequenceModelProto serialize() {
        throw new UnsupportedOperationException("The default implementation of SequenceModel.serialize() must be overridden to support protobuf serialization.");
    }

    /**
     * Writes the protobuf form of this model to {@code path}. With no open options,
     * {@link Files#newOutputStream} creates the file or truncates an existing one.
     *
     * @param path the destination file
     * @throws IOException if the file cannot be opened or written
     */
    public void serializeToFile(Path path) throws IOException {
        try (OutputStream os = new BufferedOutputStream(Files.newOutputStream(path))) {
            this.serializeToStream(os);
        }
    }

    /**
     * Writes the protobuf form of this model to {@code stream}. The stream is not closed.
     *
     * @param stream the destination stream
     * @throws IOException if writing fails
     */
    public void serializeToStream(OutputStream stream) throws IOException {
        SequenceModelProto proto = this.serialize();
        proto.writeTo(stream);
    }

    /**
     * Deserializes a sequence model from its protobuf form via {@link ProtoUtil#deserialize}.
     *
     * @param proto the protobuf to decode
     * @return the decoded model
     */
    public static SequenceModel<?> deserialize(SequenceModelProto proto) {
        return (SequenceModel<?>) ProtoUtil.deserialize(proto);
    }

    /**
     * Reads a protobuf-serialized sequence model from {@code path}.
     *
     * @param path the file to read
     * @return the decoded model
     * @throws IOException if the file cannot be read or parsed
     */
    public static SequenceModel<?> deserializeFromFile(Path path) throws IOException {
        try (InputStream is = new BufferedInputStream(Files.newInputStream(path))) {
            return SequenceModel.deserializeFromStream(is);
        }
    }

    /**
     * Parses a {@link SequenceModelProto} from {@code is} and decodes it. The stream is not closed.
     *
     * @param is the source stream
     * @return the decoded model
     * @throws IOException if reading or protobuf parsing fails
     */
    public static SequenceModel<?> deserializeFromStream(InputStream is) throws IOException {
        SequenceModelProto proto = SequenceModelProto.parseFrom(is);
        return SequenceModel.deserialize(proto);
    }

    /**
     * Packages this model's name, provenance, domains and the provenance's Tribuo version into
     * a {@link ModelDataCarrier}.
     * NOTE(review): the literal {@code false} argument's meaning is defined by
     * {@code ModelDataCarrier} (presumably a "generates probabilities" flag) — confirm there.
     *
     * @return a data carrier for this model
     */
    protected ModelDataCarrier<T> createDataCarrier() {
        return new ModelDataCarrier<>(this.name, this.provenance, this.featureIDMap, this.outputIDMap, false, this.provenance.getTribuoVersion());
    }

    /**
     * Returns the top features of this model.
     *
     * @param n the number of features requested (meaning of special values is subclass-defined)
     * @return a map from a string key to a ranked list of (feature name, score) pairs
     */
    public abstract Map<String, List<Pair<String, Double>>> getTopFeatures(int n);

    /**
     * Extracts the predicted output from each prediction, preserving order.
     *
     * @param predictions the predictions
     * @param <T> the output type
     * @return a list of {@link Prediction#getOutput()} values, one per prediction
     */
    public static <T extends Output<T>> List<T> toMaxLabels(List<Prediction<T>> predictions) {
        return predictions.stream().map(Prediction::getOutput).collect(Collectors.toList());
    }

    /**
     * Casts this model to a {@code SequenceModel<U>} after checking the output domain.
     *
     * @param outputType the target output class
     * @param <U> the target output type
     * @return this model, typed as {@code SequenceModel<U>}
     * @throws ClassCastException if {@link #validate(Class)} rejects {@code outputType}
     */
    public <U extends Output<U>> SequenceModel<U> castModel(Class<U> outputType) {
        if (!this.validate(outputType)) {
            throw new ClassCastException("Attempted to cast sequence model to " + outputType.getName() + " which is not valid for model " + this.toString());
        }
        // Safe: validate() confirmed every element of the output domain is an instance of U.
        @SuppressWarnings("unchecked")
        SequenceModel<U> castedModel = (SequenceModel<U>) this;
        return castedModel;
    }
}

