/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.dataset;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;
import com.oracle.labs.mlrg.olcut.provenance.ListProvenance;
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.ProvenanceException;
import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil;
import com.oracle.labs.mlrg.olcut.provenance.primitives.BooleanProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.LongProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Set;
import java.util.SplittableRandom;
import java.util.function.Predicate;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.FeatureMap;
import org.tribuo.ImmutableDataset;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Output;
import org.tribuo.OutputInfo;
import org.tribuo.protos.core.DatasetProto;
import org.tribuo.protos.core.DatasetViewProto;
import org.tribuo.protos.core.OutputDomainProto;
import org.tribuo.provenance.DataProvenance;
import org.tribuo.provenance.DatasetProvenance;
import org.tribuo.util.Util;

/**
 * A view over an existing {@link Dataset} which exposes the examples found at a list of indices.
 * <p>
 * Examples are not copied: {@link #getExample(int)}, {@link #getData()} and iteration all delegate
 * to the inner dataset through the stored index array. Indices are range-checked but are not
 * required to be unique, so a view can represent a sample drawn with repeats as well as a filtered
 * subset.
 * <p>
 * Views built from explicit indices (the public constructors and {@link #createView}) and weighted
 * bootstrap views record their indices in the provenance by default. Unweighted bootstrap views
 * record only the seed and size. This can be changed with {@link #setStoreIndices(boolean)}.
 *
 * @param <T> The output type.
 */
public final class DatasetView<T extends Output<T>>
extends ImmutableDataset<T> {
    private static final long serialVersionUID = 1L;

    /**
     * Protobuf serialization version written by {@link #serialize()} and accepted by
     * {@link #deserializeFromProto(int, String, Any)}.
     */
    public static final int CURRENT_VERSION = 0;

    private final Dataset<T> innerDataset;
    private final int size;
    /** Indices into {@link #innerDataset}; never exposed without copying. */
    private final int[] exampleIndices;
    /** Seed used to draw the sample; -1 for views built from explicit indices. */
    private final long seed;
    private final String tag;
    /** True when the indices were produced by one of the bootstrap factories. */
    private final boolean sampled;
    /** True when the bootstrap sample was drawn using per-example weights. */
    private final boolean weighted;
    /** Whether {@link #getProvenance()} records the full index list. */
    private boolean storeIndices = false;

    /**
     * Creates a view of the supplied dataset using its own feature and output domains.
     *
     * @param dataset The dataset to view.
     * @param exampleIndices The indices of the examples to expose, in order. The array is copied.
     * @param tag A descriptive tag for this view.
     * @throws IllegalArgumentException If any index is negative or not less than {@code dataset.size()}.
     */
    public DatasetView(Dataset<T> dataset, int[] exampleIndices, String tag) {
        this(dataset, exampleIndices, dataset.getFeatureIDMap(), dataset.getOutputIDInfo(), tag);
    }

    /**
     * Creates a view of the supplied dataset with the given feature and output domains.
     *
     * @param dataset The dataset to view.
     * @param exampleIndices The indices of the examples to expose, in order. The array is copied.
     * @param featureIDs The feature domain for this view.
     * @param labelIDs The output domain for this view.
     * @param tag A descriptive tag for this view.
     * @throws IllegalArgumentException If any index is negative or not less than {@code dataset.size()}.
     */
    public DatasetView(Dataset<T> dataset, int[] exampleIndices, ImmutableFeatureMap featureIDs, ImmutableOutputInfo<T> labelIDs, String tag) {
        super((DataProvenance) dataset.getProvenance(), dataset.getOutputFactory(), featureIDs, labelIDs);
        // Copy before validating so later mutation of the caller's array cannot invalidate this view.
        int[] indices = exampleIndices.clone();
        if (!validateIndices(dataset.size(), indices)) {
            throw new IllegalArgumentException("Invalid indices supplied, dataset.size() = " + dataset.size() + ", but found a negative index or a value greater than or equal to size.");
        }
        this.innerDataset = dataset;
        this.size = indices.length;
        this.exampleIndices = indices;
        this.seed = -1L;
        this.tag = tag;
        this.storeIndices = true;
        this.sampled = false;
        this.weighted = false;
    }

    /**
     * Constructor used by the bootstrap factories. The indices are generated internally and are
     * trusted. Indices are stored in the provenance only for weighted samples.
     */
    private DatasetView(Dataset<T> dataset, int[] exampleIndices, long seed, ImmutableFeatureMap featureIDs, ImmutableOutputInfo<T> outputIDs, boolean weighted) {
        this(dataset, exampleIndices, seed, "", featureIDs, outputIDs, true, weighted, weighted);
    }

    /**
     * Fully specified constructor. Performs no validation; callers must check the indices.
     */
    private DatasetView(Dataset<T> dataset, int[] exampleIndices, long seed, String tag, ImmutableFeatureMap featureIDs, ImmutableOutputInfo<T> outputIDs, boolean sampled, boolean weighted, boolean storeIndices) {
        super((DataProvenance) dataset.getProvenance(), dataset.getOutputFactory(), featureIDs, outputIDs);
        this.innerDataset = dataset;
        this.size = exampleIndices.length;
        this.exampleIndices = exampleIndices;
        this.tag = tag;
        this.seed = seed;
        this.sampled = sampled;
        this.weighted = weighted;
        this.storeIndices = storeIndices;
    }

    /**
     * Deserialization factory.
     *
     * @param version The serialized object version.
     * @param className The class name.
     * @param message The serialized data, expected to contain a {@link DatasetViewProto}.
     * @return The deserialized view.
     * @throws InvalidProtocolBufferException If the message does not contain a {@link DatasetViewProto}.
     * @throws IllegalArgumentException If the version is not supported.
     * @throws IllegalStateException If the serialized domains or indices are inconsistent with the inner dataset.
     */
    public static DatasetView<?> deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
        }
        DatasetViewProto proto = message.unpack(DatasetViewProto.class);
        Dataset<?> inner = Dataset.deserialize(proto.getInnerDataset());
        Class<?> outputClass = inner.getOutputFactory().getUnknownOutput().getClass();

        OutputInfo<?> outputDomain = OutputInfo.deserialize(proto.getOutputDomain());
        for (Object o : outputDomain.getDomain()) {
            if (!o.getClass().equals(outputClass)) {
                throw new IllegalStateException("Invalid protobuf, output domains do not match, expected " + outputClass + " found " + o.getClass());
            }
        }
        if (!(outputDomain instanceof ImmutableOutputInfo)) {
            throw new IllegalStateException("Invalid protobuf, output info was not immutable");
        }

        FeatureMap featureDomain = FeatureMap.deserialize(proto.getFeatureDomain());
        // Cheap type check first, before the full pass over the referenced examples below.
        if (!(featureDomain instanceof ImmutableFeatureMap)) {
            throw new IllegalStateException("Invalid protobuf, feature map was not immutable");
        }

        int[] indices = Util.toPrimitiveInt(proto.getIndicesList());
        if (!validateIndices(inner.size(), indices)) {
            throw new IllegalStateException("Invalid protobuf, indices are not all inside the range of the inner dataset");
        }
        for (int idx : indices) {
            Example<?> example = inner.getExample(idx);
            for (Feature f : example) {
                if (featureDomain.get(f.getName()) == null) {
                    throw new IllegalStateException("Invalid protobuf, feature domain does not contain feature " + f.getName() + " present in an example");
                }
            }
        }

        // The output domain's element class was checked against the inner dataset's output type above.
        @SuppressWarnings({"unchecked", "rawtypes"})
        DatasetView<?> view = new DatasetView(inner, indices, proto.getSeed(), proto.getTag(), (ImmutableFeatureMap) featureDomain, (ImmutableOutputInfo) outputDomain, proto.getSampled(), proto.getWeighted(), proto.getStoreIndices());
        return view;
    }

    /**
     * Creates a view containing the examples of {@code dataset} that satisfy {@code predicate},
     * in iteration order.
     *
     * @param dataset The dataset to filter.
     * @param predicate The selection test applied to each example.
     * @param tag A descriptive tag for the view.
     * @param <T> The output type.
     * @return A view over the selected examples.
     */
    public static <T extends Output<T>> DatasetView<T> createView(Dataset<T> dataset, Predicate<Example<T>> predicate, String tag) {
        List<Integer> selectedIndices = new ArrayList<>();
        int i = 0;
        for (Example<T> e : dataset) {
            if (predicate.test(e)) {
                selectedIndices.add(i);
            }
            ++i;
        }
        return new DatasetView<>(dataset, Util.toPrimitiveInt(selectedIndices), tag);
    }

    /**
     * Creates a bootstrap-sampled view using the dataset's own feature and output domains.
     *
     * @param dataset The dataset to sample from.
     * @param size The number of indices to draw.
     * @param seed The RNG seed.
     * @param <T> The output type.
     * @return A sampled view.
     */
    public static <T extends Output<T>> DatasetView<T> createBootstrapView(Dataset<T> dataset, int size, long seed) {
        return createBootstrapView(dataset, size, seed, dataset.getFeatureIDMap(), dataset.getOutputIDInfo());
    }

    /**
     * Creates a bootstrap-sampled view with the supplied feature and output domains.
     *
     * @param dataset The dataset to sample from.
     * @param size The number of indices to draw.
     * @param seed The RNG seed.
     * @param featureIDs The feature domain for the view.
     * @param outputIDs The output domain for the view.
     * @param <T> The output type.
     * @return A sampled view.
     */
    public static <T extends Output<T>> DatasetView<T> createBootstrapView(Dataset<T> dataset, int size, long seed, ImmutableFeatureMap featureIDs, ImmutableOutputInfo<T> outputIDs) {
        int[] bootstrapIndices = Util.generateBootstrapIndices(size, dataset.size(), new SplittableRandom(seed));
        return new DatasetView<>(dataset, bootstrapIndices, seed, featureIDs, outputIDs, false);
    }

    /**
     * Creates a weighted sampled view using the dataset's own feature and output domains.
     *
     * @param dataset The dataset to sample from.
     * @param size The number of indices to draw.
     * @param seed The RNG seed.
     * @param exampleWeights One weight per example in {@code dataset}.
     * @param <T> The output type.
     * @return A sampled view.
     * @throws IllegalArgumentException If the number of weights differs from {@code dataset.size()}.
     */
    public static <T extends Output<T>> DatasetView<T> createWeightedBootstrapView(Dataset<T> dataset, int size, long seed, float[] exampleWeights) {
        return createWeightedBootstrapView(dataset, size, seed, exampleWeights, dataset.getFeatureIDMap(), dataset.getOutputIDInfo());
    }

    /**
     * Creates a weighted sampled view with the supplied feature and output domains.
     *
     * @param dataset The dataset to sample from.
     * @param size The number of indices to draw.
     * @param seed The RNG seed.
     * @param exampleWeights One weight per example in {@code dataset}.
     * @param featureIDs The feature domain for the view.
     * @param outputIDs The output domain for the view.
     * @param <T> The output type.
     * @return A sampled view.
     * @throws IllegalArgumentException If the number of weights differs from {@code dataset.size()}.
     */
    public static <T extends Output<T>> DatasetView<T> createWeightedBootstrapView(Dataset<T> dataset, int size, long seed, float[] exampleWeights, ImmutableFeatureMap featureIDs, ImmutableOutputInfo<T> outputIDs) {
        if (dataset.size() != exampleWeights.length) {
            throw new IllegalArgumentException("There must be a weight for each example, dataset.size()=" + dataset.size() + ", exampleWeights.length=" + exampleWeights.length);
        }
        int[] bootstrapIndices = Util.generateWeightedIndicesSample(size, exampleWeights, new SplittableRandom(seed));
        return new DatasetView<>(dataset, bootstrapIndices, seed, featureIDs, outputIDs, true);
    }

    /**
     * Returns whether {@link #getProvenance()} records the full index list.
     *
     * @return True if indices are stored in the provenance.
     */
    public boolean storeIndicesInProvenance() {
        return this.storeIndices;
    }

    /**
     * Sets whether {@link #getProvenance()} records the full index list.
     *
     * @param storeIndices True to store indices in the provenance.
     */
    public void setStoreIndices(boolean storeIndices) {
        this.storeIndices = storeIndices;
    }

    @Override
    public String toString() {
        return "DatasetView(innerDataset=" + this.innerDataset.getSourceDescription()
                + ",size=" + this.size
                + ",seed=" + this.seed
                + ",tag=" + this.tag
                + ")";
    }

    /**
     * Delegates to the inner dataset, so the result may include outputs that do not occur in this
     * view's examples.
     */
    @Override
    public Set<T> getOutputs() {
        return this.innerDataset.getOutputs();
    }

    @Override
    public int size() {
        return this.size;
    }

    @Override
    public ImmutableFeatureMap getFeatureMap() {
        return this.featureIDMap;
    }

    @Override
    public ImmutableOutputInfo<T> getOutputInfo() {
        return this.outputIDInfo;
    }

    @Override
    public Iterator<Example<T>> iterator() {
        return new ViewIterator<>(this);
    }

    /**
     * Returns an unmodifiable list of the viewed examples, in index order.
     */
    @Override
    public List<Example<T>> getData() {
        List<Example<T>> data = new ArrayList<>(this.exampleIndices.length);
        for (int index : this.exampleIndices) {
            data.add(this.innerDataset.getExample(index));
        }
        return Collections.unmodifiableList(data);
    }

    /**
     * Returns the example at position {@code index} in this view.
     *
     * @throws IllegalArgumentException If {@code index} is outside {@code [0, size())}.
     */
    @Override
    public Example<T> getExample(int index) {
        if (index < 0 || index >= this.size()) {
            throw new IllegalArgumentException("Example index " + index + " is out of bounds.");
        }
        return this.innerDataset.getExample(this.exampleIndices[index]);
    }

    @Override
    public DatasetViewProvenance getProvenance() {
        return new DatasetViewProvenance(this, this.storeIndices);
    }

    /**
     * Returns the tag supplied when this view was created; empty for bootstrap views.
     *
     * @return The tag.
     */
    public String getTag() {
        return this.tag;
    }

    /**
     * Returns a copy of the indices into the inner dataset.
     *
     * @return A copy of the example indices.
     */
    public int[] getExampleIndices() {
        return Arrays.copyOf(this.exampleIndices, this.exampleIndices.length);
    }

    @Override
    public DatasetProto serialize() {
        DatasetViewProto.Builder datasetBuilder = DatasetViewProto.newBuilder();
        datasetBuilder.setInnerDataset((DatasetProto) this.innerDataset.serialize());
        datasetBuilder.setSize(this.size);
        for (int idx : this.exampleIndices) {
            datasetBuilder.addIndices(idx);
        }
        datasetBuilder.setSeed(this.seed);
        datasetBuilder.setTag(this.tag);
        datasetBuilder.setSampled(this.sampled);
        datasetBuilder.setWeighted(this.weighted);
        datasetBuilder.setStoreIndices(this.storeIndices);
        datasetBuilder.setFeatureDomain(this.featureIDMap.serialize());
        datasetBuilder.setOutputDomain((OutputDomainProto) this.outputIDInfo.serialize());

        DatasetProto.Builder builder = DatasetProto.newBuilder();
        builder.setVersion(CURRENT_VERSION);
        builder.setClassName(DatasetView.class.getName());
        builder.setSerializedData(Any.pack(datasetBuilder.build()));
        return builder.build();
    }

    /**
     * Checks that every index lies in {@code [0, size)}.
     *
     * @param size The exclusive upper bound.
     * @param indices The indices to check.
     * @return True if all indices are in range (vacuously true for an empty array).
     */
    private static boolean validateIndices(int size, int[] indices) {
        for (int idx : indices) {
            if (idx < 0 || idx >= size) {
                return false;
            }
        }
        return true;
    }

    /**
     * Iterator over the examples of a view, in index order.
     */
    private static final class ViewIterator<T extends Output<T>>
    implements Iterator<Example<T>> {
        private int counter = 0;
        private final DatasetView<T> dataset;

        ViewIterator(DatasetView<T> dataset) {
            this.dataset = dataset;
        }

        @Override
        public boolean hasNext() {
            return this.counter < this.dataset.size();
        }

        /**
         * @throws NoSuchElementException If the iteration has no more examples, per the {@link Iterator} contract.
         */
        @Override
        public Example<T> next() {
            if (!hasNext()) {
                throw new NoSuchElementException("No more examples in this view, size = " + this.dataset.size());
            }
            return this.dataset.getExample(this.counter++);
        }
    }

    /**
     * Provenance for a {@link DatasetView}. Records the size, seed, tag, sampling flags and
     * (optionally) the example indices.
     */
    public static final class DatasetViewProvenance
    extends DatasetProvenance {
        private static final long serialVersionUID = 1L;
        private static final String SIZE = "size";
        private static final String SEED = "seed";
        private static final String TAG = "tag";
        private static final String SAMPLED = "sampled";
        private static final String WEIGHTED = "weighted";
        private static final String INDICES = "indices";
        private final IntProvenance size;
        private final LongProvenance seed;
        private final StringProvenance tag;
        private final BooleanProvenance weighted;
        private final BooleanProvenance sampled;
        /** The view's indices, or an empty array when they are not stored. */
        private final int[] indices;

        <T extends Output<T>> DatasetViewProvenance(DatasetView<T> dataset, boolean storeIndices) {
            super(dataset.sourceProvenance, new ListProvenance<>(), dataset);
            this.size = new IntProvenance(SIZE, dataset.size);
            this.seed = new LongProvenance(SEED, dataset.seed);
            this.weighted = new BooleanProvenance(WEIGHTED, dataset.weighted);
            this.sampled = new BooleanProvenance(SAMPLED, dataset.sampled);
            this.tag = new StringProvenance(TAG, dataset.tag);
            this.indices = storeIndices ? dataset.exampleIndices : new int[0];
        }

        /**
         * Deserialization constructor.
         *
         * @param map The provenances.
         * @throws ProvenanceException If a required key is missing or the indices are not all {@link IntProvenance}.
         */
        public DatasetViewProvenance(Map<String, Provenance> map) {
            super(map);
            String className = DatasetViewProvenance.class.getSimpleName();
            this.size = ObjectProvenance.checkAndExtractProvenance(map, SIZE, IntProvenance.class, className);
            this.seed = ObjectProvenance.checkAndExtractProvenance(map, SEED, LongProvenance.class, className);
            this.tag = ObjectProvenance.checkAndExtractProvenance(map, TAG, StringProvenance.class, className);
            this.weighted = ObjectProvenance.checkAndExtractProvenance(map, WEIGHTED, BooleanProvenance.class, className);
            this.sampled = ObjectProvenance.checkAndExtractProvenance(map, SAMPLED, BooleanProvenance.class, className);
            ListProvenance<?> rawIndices = ObjectProvenance.checkAndExtractProvenance(map, INDICES, ListProvenance.class, className);
            // Check every element, not just the first, so a malformed list fails here with a clear message.
            for (Object element : rawIndices.getList()) {
                if (!(element instanceof IntProvenance)) {
                    throw new ProvenanceException("Loaded another class when expecting an ListProvenance<IntProvenance>, found " + (element == null ? "null" : element.getClass().getName()));
                }
            }
            @SuppressWarnings("unchecked") // every element was checked to be an IntProvenance above
            ListProvenance<IntProvenance> listIndices = (ListProvenance<IntProvenance>) rawIndices;
            this.indices = Util.toPrimitiveInt(ProvenanceUtil.unwrap(listIndices));
        }

        /**
         * Regenerates bootstrap indices from the recorded size and seed.
         * <p>
         * NOTE(review): this calls the two-argument {@code Util.generateBootstrapIndices} overload,
         * while {@code createBootstrapView} passes the inner dataset size separately. It also ignores
         * weighting. Confirm it reproduces the original sample in the cases where it is used.
         *
         * @return The regenerated indices.
         */
        public int[] generateBootstrap() {
            return Util.generateBootstrapIndices(this.size.getValue(), new SplittableRandom(this.seed.getValue()));
        }

        /**
         * @return True if the view was produced by a bootstrap factory.
         */
        public boolean isSampled() {
            return this.sampled.getValue();
        }

        /**
         * @return True if the view was produced by a weighted bootstrap factory.
         */
        public boolean isWeighted() {
            return this.weighted.getValue();
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof DatasetViewProvenance)) {
                return false;
            }
            if (!super.equals(o)) {
                return false;
            }
            DatasetViewProvenance other = (DatasetViewProvenance) o;
            // Compare every field emitted by allProvenances so equality matches the serialized form.
            return this.size.equals(other.size)
                    && this.seed.equals(other.seed)
                    && this.tag.equals(other.tag)
                    && this.weighted.equals(other.weighted)
                    && this.sampled.equals(other.sampled)
                    && Arrays.equals(this.indices, other.indices);
        }

        @Override
        public int hashCode() {
            int result = Objects.hash(super.hashCode(), this.size, this.seed, this.tag, this.weighted, this.sampled);
            return 31 * result + Arrays.hashCode(this.indices);
        }

        @Override
        protected List<Pair<String, Provenance>> allProvenances() {
            return provenancesWith(boxIndices());
        }

        /**
         * Appends this view's fields to the superclass provenances, using {@code indexProvenance} for
         * the indices entry.
         */
        private List<Pair<String, Provenance>> provenancesWith(ListProvenance<IntProvenance> indexProvenance) {
            List<Pair<String, Provenance>> provenances = super.allProvenances();
            provenances.add(new Pair<>(SIZE, this.size));
            provenances.add(new Pair<>(SEED, this.seed));
            provenances.add(new Pair<>(TAG, this.tag));
            provenances.add(new Pair<>(WEIGHTED, this.weighted));
            provenances.add(new Pair<>(SAMPLED, this.sampled));
            provenances.add(new Pair<>(INDICES, indexProvenance));
            return provenances;
        }

        /** Wraps each stored index in an {@link IntProvenance}. */
        private ListProvenance<IntProvenance> boxIndices() {
            List<IntProvenance> list = new ArrayList<>(this.indices.length);
            for (int idx : this.indices) {
                list.add(new IntProvenance(INDICES, idx));
            }
            return new ListProvenance<>(list);
        }

        /**
         * String form of this provenance. The indices are deliberately replaced with an empty list
         * to keep the output short.
         */
        @Override
        public String toString() {
            StringBuilder sb = new StringBuilder("DatasetView(");
            for (Pair<String, Provenance> p : provenancesWith(new ListProvenance<>())) {
                sb.append(p.getA());
                sb.append('=');
                sb.append(p.getB().toString());
                sb.append(',');
            }
            sb.replace(sb.length() - 1, sb.length(), ")");
            return sb.toString();
        }
    }
}

