/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.sequence;

import com.oracle.labs.mlrg.olcut.provenance.Provenancable;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.logging.Logger;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.FeatureMap;
import org.tribuo.ImmutableDataset;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Output;
import org.tribuo.OutputFactory;
import org.tribuo.OutputInfo;
import org.tribuo.impl.DatasetDataCarrier;
import org.tribuo.protos.ProtoSerializable;
import org.tribuo.protos.ProtoUtil;
import org.tribuo.protos.core.SequenceDatasetProto;
import org.tribuo.protos.core.SequenceExampleProto;
import org.tribuo.provenance.DataProvenance;
import org.tribuo.provenance.DatasetProvenance;
import org.tribuo.sequence.SequenceExample;

/**
 * Base class for datasets whose elements are {@link SequenceExample}s.
 * <p>
 * Holds the examples, the {@link OutputFactory} for the output type, the provenance of the
 * source data, and the Tribuo version string recorded at construction time. Subclasses supply
 * the feature and output domains.
 *
 * @param <T> The output type of the examples.
 */
public abstract class SequenceDataset<T extends Output<T>>
implements Iterable<SequenceExample<T>>,
ProtoSerializable<SequenceDatasetProto>,
Provenancable<DatasetProvenance>,
Serializable {
    private static final Logger logger = Logger.getLogger(SequenceDataset.class.getName());
    private static final long serialVersionUID = 2L;

    /**
     * Version string used by the two-argument constructor and when the stored version is null.
     */
    private static final String DEFAULT_TRIBUO_VERSION = "4.3.1";

    protected final OutputFactory<T> outputFactory;
    protected final List<SequenceExample<T>> data = new ArrayList<SequenceExample<T>>();
    protected final String tribuoVersion;
    protected final DataProvenance sourceProvenance;

    /**
     * Creates an empty sequence dataset stamped with the default Tribuo version.
     *
     * @param sourceProvenance The provenance of the source data.
     * @param outputFactory The output factory for this dataset's output type.
     */
    protected SequenceDataset(DataProvenance sourceProvenance, OutputFactory<T> outputFactory) {
        this(sourceProvenance, outputFactory, DEFAULT_TRIBUO_VERSION);
    }

    /**
     * Creates an empty sequence dataset stamped with an explicit Tribuo version.
     *
     * @param sourceProvenance The provenance of the source data.
     * @param outputFactory The output factory for this dataset's output type.
     * @param tribuoVersion The Tribuo version string to record (may be null, in which case
     *                      {@link #createDataCarrier} substitutes the default version).
     */
    protected SequenceDataset(DataProvenance sourceProvenance, OutputFactory<T> outputFactory, String tribuoVersion) {
        this.sourceProvenance = sourceProvenance;
        this.outputFactory = outputFactory;
        this.tribuoVersion = tribuoVersion;
    }

    /**
     * Returns a human-readable description of the data source.
     *
     * @return A description string built from the source provenance.
     */
    public String getSourceDescription() {
        return "SequenceDataset(source=" + this.sourceProvenance.toString() + ")";
    }

    /**
     * Returns an unmodifiable view of the examples in this dataset.
     *
     * @return A read-only view over the backing example list.
     */
    public List<SequenceExample<T>> getData() {
        return Collections.unmodifiableList(this.data);
    }

    /**
     * Returns the provenance of the data this dataset was built from.
     *
     * @return The source provenance.
     */
    public DataProvenance getSourceProvenance() {
        return this.sourceProvenance;
    }

    /**
     * Returns the set of outputs present in this dataset.
     *
     * @return The outputs.
     */
    public abstract Set<T> getOutputs();

    /**
     * Returns the example at the supplied index.
     *
     * @param index The example index.
     * @return The sequence example at that index.
     * @throws IllegalArgumentException If the index is negative or not less than {@link #size()}.
     */
    public SequenceExample<T> getExample(int index) {
        if (index < 0 || index >= this.size()) {
            throw new IllegalArgumentException("Example index " + index + " is out of bounds.");
        }
        return this.data.get(index);
    }

    /**
     * Flattens this dataset into a non-sequence {@link Dataset}.
     * <p>
     * Each example of every sequence is copied into the result, in iteration order. The
     * result uses this dataset's feature ID map and output ID info.
     *
     * @return A flat dataset containing every element of every sequence.
     */
    public Dataset<T> getFlatDataset() {
        return new FlatDataset<>(this);
    }

    /**
     * Returns the number of sequence examples in this dataset.
     *
     * @return The number of sequences.
     */
    public int size() {
        return this.data.size();
    }

    /**
     * Returns the output domain as an {@link ImmutableOutputInfo}.
     *
     * @return The immutable output info.
     */
    public abstract ImmutableOutputInfo<T> getOutputIDInfo();

    /**
     * Returns the output domain.
     *
     * @return The output info.
     */
    public abstract OutputInfo<T> getOutputInfo();

    /**
     * Returns the feature domain as an {@link ImmutableFeatureMap}.
     *
     * @return The immutable feature map.
     */
    public abstract ImmutableFeatureMap getFeatureIDMap();

    /**
     * Returns the feature domain.
     *
     * @return The feature map.
     */
    public abstract FeatureMap getFeatureMap();

    /**
     * Returns the output factory for this dataset's output type.
     *
     * @return The output factory.
     */
    public OutputFactory<T> getOutputFactory() {
        return this.outputFactory;
    }

    /**
     * Iterates directly over the backing example list.
     *
     * @return An iterator over the sequence examples.
     */
    @Override
    public Iterator<SequenceExample<T>> iterator() {
        return this.data.iterator();
    }

    @Override
    public String toString() {
        return "SequenceDataset(source=" + this.sourceProvenance.toString() + ")";
    }

    /**
     * Checks that every output in this dataset's output domain is an instance of the supplied class.
     *
     * @param clazz The output class to check against.
     * @return True if every element of {@code getOutputInfo().getDomain()} is an instance of
     *         {@code clazz} (trivially true for an empty domain).
     */
    public boolean validate(Class<? extends Output<?>> clazz) {
        Set<T> domain = this.getOutputInfo().getDomain();
        boolean output = true;
        for (T type : domain) {
            output &= clazz.isInstance(type);
        }
        return output;
    }

    /**
     * Casts a wildcard-typed sequence dataset to a specific output type, after checking via
     * {@link #validate(Class)} that its output domain only contains instances of that type.
     *
     * @param inputDataset The dataset to cast.
     * @param outputType The target output class.
     * @param <T> The target output type.
     * @return The same dataset instance, typed as {@code SequenceDataset<T>}.
     * @throws ClassCastException If the output domain contains outputs of another type.
     */
    public static <T extends Output<T>> SequenceDataset<T> castDataset(SequenceDataset<?> inputDataset, Class<T> outputType) {
        if (inputDataset.validate(outputType)) {
            // validate() confirmed the output domain only holds instances of outputType,
            // so this unchecked cast is sound for the domain.
            @SuppressWarnings("unchecked")
            SequenceDataset<T> castedDataset = (SequenceDataset<T>) inputDataset;
            return castedDataset;
        }
        throw new ClassCastException("Attempted to cast dataset to " + outputType.getName() + " which is not valid for dataset " + inputDataset.toString());
    }

    /**
     * Deserializes a sequence dataset from its protobuf representation.
     *
     * @param sequenceProto The protobuf to deserialize.
     * @return The deserialized sequence dataset.
     */
    public static SequenceDataset<?> deserialize(SequenceDatasetProto sequenceProto) {
        return (SequenceDataset<?>) ProtoUtil.deserialize(sequenceProto);
    }

    /**
     * Reads a protobuf-serialized sequence dataset from a file.
     *
     * @param path The file to read.
     * @return The deserialized sequence dataset.
     * @throws IOException If the file could not be read or parsed.
     */
    public static SequenceDataset<?> deserializeFromFile(Path path) throws IOException {
        try (BufferedInputStream is = new BufferedInputStream(Files.newInputStream(path))) {
            return SequenceDataset.deserializeFromStream(is);
        }
    }

    /**
     * Reads a protobuf-serialized sequence dataset from a stream. The stream is not closed.
     *
     * @param is The stream to read from.
     * @return The deserialized sequence dataset.
     * @throws IOException If the stream could not be read or parsed.
     */
    public static SequenceDataset<?> deserializeFromStream(InputStream is) throws IOException {
        SequenceDatasetProto proto = SequenceDatasetProto.parseFrom(is);
        return SequenceDataset.deserialize(proto);
    }

    /**
     * Writes this dataset in protobuf form to a file. The file is opened with
     * {@link Files#newOutputStream}'s default options, so an existing file is truncated.
     *
     * @param path The file to write.
     * @throws IOException If the file could not be written.
     */
    public void serializeToFile(Path path) throws IOException {
        try (BufferedOutputStream os = new BufferedOutputStream(Files.newOutputStream(path))) {
            this.serializeToStream(os);
        }
    }

    /**
     * Writes this dataset in protobuf form to a stream. The stream is neither flushed nor closed.
     *
     * @param stream The stream to write to.
     * @throws IOException If the stream could not be written.
     */
    public void serializeToStream(OutputStream stream) throws IOException {
        SequenceDatasetProto proto = this.serialize();
        proto.writeTo(stream);
    }

    /**
     * Builds a {@link DatasetDataCarrier} for serialization. The carrier has no transformation
     * provenance. If the recorded Tribuo version is null, the default version is used.
     *
     * @param featureMap The feature domain to store.
     * @param outputInfo The output domain to store.
     * @return The data carrier.
     */
    protected DatasetDataCarrier<T> createDataCarrier(FeatureMap featureMap, OutputInfo<T> outputInfo) {
        String version = this.tribuoVersion == null ? DEFAULT_TRIBUO_VERSION : this.tribuoVersion;
        return new DatasetDataCarrier<T>(this.sourceProvenance, featureMap, outputInfo, this.outputFactory, Collections.emptyList(), version);
    }

    /**
     * Deserializes and validates a list of sequence example protobufs.
     * <p>
     * Each element of every sequence must have an output of exactly {@code outputClass}.
     * Every feature it contains must be present in {@code fmap}.
     *
     * @param examplesList The example protobufs.
     * @param outputClass The required output class.
     * @param fmap The feature domain the examples must conform to.
     * @return The deserialized sequence examples, in input order.
     * @throws IllegalStateException If an example has the wrong output class or an unknown feature.
     */
    protected static List<SequenceExample<?>> deserializeExamples(List<SequenceExampleProto> examplesList, Class<?> outputClass, FeatureMap fmap) {
        List<SequenceExample<?>> examples = new ArrayList<>();
        for (SequenceExampleProto e : examplesList) {
            SequenceExample<?> seq = SequenceExample.deserialize(e);
            for (Example<?> example : seq) {
                if (!example.getOutput().getClass().equals(outputClass)) {
                    throw new IllegalStateException("Invalid protobuf, expected all examples to have output class " + outputClass + ", but found " + example.getOutput().getClass());
                }
                for (Feature f : example) {
                    if (fmap.get(f.getName()) == null) {
                        throw new IllegalStateException("Invalid protobuf, feature domain does not contain feature " + f.getName() + " present in an example");
                    }
                }
            }
            examples.add(seq);
        }
        return examples;
    }

    /**
     * An {@link ImmutableDataset} made by concatenating the elements of every sequence in a
     * {@link SequenceDataset}, sharing its feature ID map and output ID info.
     */
    private static class FlatDataset<T extends Output<T>>
    extends ImmutableDataset<T> {
        private static final long serialVersionUID = 1L;

        FlatDataset(SequenceDataset<T> sequenceDataset) {
            super(sequenceDataset.sourceProvenance, sequenceDataset.outputFactory, sequenceDataset.getFeatureIDMap(), sequenceDataset.getOutputIDInfo());
            for (SequenceExample<T> seq : sequenceDataset) {
                for (Example<T> e : seq) {
                    this.data.add(e);
                }
            }
        }
    }
}

