/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.dataset;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;
import com.google.protobuf.ProtocolStringList;
import com.oracle.labs.mlrg.olcut.provenance.ListProvenance;
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.logging.Logger;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.FeatureMap;
import org.tribuo.ImmutableDataset;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.MutableFeatureMap;
import org.tribuo.MutableOutputInfo;
import org.tribuo.Output;
import org.tribuo.OutputFactory;
import org.tribuo.SelectedFeatureSet;
import org.tribuo.impl.ArrayExample;
import org.tribuo.impl.DatasetDataCarrier;
import org.tribuo.protos.ProtoUtil;
import org.tribuo.protos.core.DatasetProto;
import org.tribuo.protos.core.ExampleProto;
import org.tribuo.protos.core.SelectedFeatureDatasetProto;
import org.tribuo.provenance.DataProvenance;
import org.tribuo.provenance.DatasetProvenance;
import org.tribuo.provenance.FeatureSetProvenance;

/**
 * An immutable dataset built by applying a {@link SelectedFeatureSet} to another dataset.
 * <p>
 * Each example is copied and every feature whose name is not in the selected set is
 * removed from the copy. Examples which have no features left after this filtering are
 * dropped, and the number dropped is available from {@link #getNumExamplesRemoved()}.
 *
 * @param <T> The output type of the examples in this dataset.
 */
public final class SelectedFeatureDataset<T extends Output<T>> extends ImmutableDataset<T> {
    private static final long serialVersionUID = 1L;
    private static final Logger logger = Logger.getLogger(SelectedFeatureDataset.class.getName());

    /**
     * Protobuf serialization version written by {@link #serialize()} and the highest
     * version accepted by {@link #deserializeFromProto(int, String, Any)}.
     */
    public static final int CURRENT_VERSION = 0;

    /**
     * Sentinel value of {@code k} meaning "use every feature in the feature set".
     */
    private static final int SELECT_ALL = -1;

    private final int k;
    private final SelectedFeatureSet featureSet;
    private final Set<String> selectedFeatures;
    private final int numExamplesRemoved;

    /**
     * Constructs a dataset which keeps every feature named in the supplied feature set.
     * <p>
     * Equivalent to calling the three argument constructor with {@code k == -1}.
     *
     * @param dataset    The dataset to copy examples from.
     * @param featureSet The feature set to apply.
     */
    public SelectedFeatureDataset(Dataset<T> dataset, SelectedFeatureSet featureSet) {
        this(dataset, featureSet, SELECT_ALL);
    }

    /**
     * Constructs a dataset which keeps either all the features in the feature set
     * ({@code k == -1}) or the first {@code k} features in its feature name list.
     *
     * @param dataset    The dataset to copy examples from.
     * @param featureSet The feature set to apply.
     * @param k          The number of features to keep, or -1 to keep all of them.
     * @throws IllegalArgumentException If {@code k} is zero or less than -1, if the feature
     *                                  set is empty, if {@code k} is positive but the feature
     *                                  set is unordered, if {@code k} is larger than the
     *                                  feature set, or if none of the selected features
     *                                  appear in the dataset's feature map.
     */
    public SelectedFeatureDataset(Dataset<T> dataset, SelectedFeatureSet featureSet, int k) {
        super((DataProvenance) dataset.getProvenance(), dataset.getOutputFactory());
        this.featureSet = featureSet;
        this.k = k;

        List<String> featureNames = featureSet.featureNames();
        int numAvailable = featureNames.size();

        // Reject nonsensical negative values first so the caller gets the precise reason
        // rather than the "unordered feature set" message below.
        if (k < SELECT_ALL) {
            throw new IllegalArgumentException("Supplied k " + k + " but only k == -1 or 1 <= k <= N is allowed.");
        }
        if (k == 0 || numAvailable == 0) {
            throw new IllegalArgumentException("Tried to select zero features.");
        }
        if (k != SELECT_ALL && !featureSet.isOrdered()) {
            throw new IllegalArgumentException("Tried to select the top " + k + " features from an unordered feature set.");
        }
        if (k > numAvailable) {
            throw new IllegalArgumentException("Tried to select more features than are available in feature set, requested " + k + ", found " + numAvailable);
        }

        // LinkedHashSet keeps the order of the feature set's name list.
        Set<String> tmpFeatures = new LinkedHashSet<String>();
        if (k == SELECT_ALL) {
            tmpFeatures.addAll(featureNames);
        } else {
            tmpFeatures.addAll(featureNames.subList(0, k));
        }
        this.selectedFeatures = Collections.unmodifiableSet(tmpFeatures);

        // Only a non-empty overlap with the dataset's features is required, not a full one.
        Set<String> datasetFeatures = new HashSet<String>(dataset.getFeatureMap().keySet());
        datasetFeatures.retainAll(this.selectedFeatures);
        if (datasetFeatures.isEmpty()) {
            throw new IllegalArgumentException("The selected feature set had no overlap with the supplied dataset.");
        }

        int tmpNumExamplesRemoved = 0;
        MutableFeatureMap featureMap = new MutableFeatureMap();
        MutableOutputInfo<T> outputInfo = dataset.getOutputFactory().generateInfo();
        // Scratch list of the features to strip from the current example, reused across examples.
        List<Feature> toRemove = new ArrayList<Feature>();
        for (Example<T> ex : dataset) {
            toRemove.clear();
            ArrayExample<T> copy = new ArrayExample<T>(ex);
            for (Feature f : ex) {
                if (this.selectedFeatures.contains(f.getName())) {
                    featureMap.add(f.getName(), f.getValue());
                } else {
                    toRemove.add(f);
                }
            }
            if (!toRemove.isEmpty()) {
                copy.removeFeatures(toRemove);
            }
            if (copy.size() > 0) {
                this.data.add(copy);
                outputInfo.observe(ex.getOutput());
            } else {
                // Nothing left in this example after filtering, so it is dropped.
                ++tmpNumExamplesRemoved;
            }
        }
        this.numExamplesRemoved = tmpNumExamplesRemoved;
        this.featureIDMap = new ImmutableFeatureMap(featureMap);
        this.outputIDInfo = outputInfo.generateImmutableOutputInfo();
        if (this.numExamplesRemoved > 0) {
            logger.info(String.format("filtered out %d examples because they had zero features after the selected feature set was applied.", this.numExamplesRemoved));
        }
    }

    /**
     * Deserialization constructor, the arguments are stored without any filtering.
     * <p>
     * The supplied {@code selectedFeatures} set is wrapped in an unmodifiable view rather
     * than copied, so the caller must not retain a mutable reference to it.
     *
     * @param provenance         The source data provenance.
     * @param factory            The output factory.
     * @param tribuoVersion      The Tribuo version recorded in the serialized form.
     * @param fmap               The feature domain.
     * @param outputInfo         The output domain.
     * @param examples           The examples.
     * @param k                  The number of features selected, or -1 for all.
     * @param featureSet         The feature set that was applied.
     * @param selectedFeatures   The names of the selected features.
     * @param numExamplesRemoved The number of examples dropped when the dataset was built.
     */
    private SelectedFeatureDataset(DataProvenance provenance, OutputFactory<T> factory, String tribuoVersion, ImmutableFeatureMap fmap, ImmutableOutputInfo<T> outputInfo, List<Example<T>> examples, int k, SelectedFeatureSet featureSet, Set<String> selectedFeatures, int numExamplesRemoved) {
        super(provenance, factory, tribuoVersion, fmap, outputInfo, examples, false);
        this.k = k;
        this.selectedFeatures = Collections.unmodifiableSet(selectedFeatures);
        this.featureSet = featureSet;
        this.numExamplesRemoved = numExamplesRemoved;
    }

    /**
     * Deserialization factory.
     *
     * @param version   The serialized object version.
     * @param className The class name (not used by this method).
     * @param message   The serialized data, expected to hold a {@link SelectedFeatureDatasetProto}.
     * @return The deserialized dataset.
     * @throws InvalidProtocolBufferException If the message cannot be unpacked as a
     *                                        {@link SelectedFeatureDatasetProto}.
     * @throws IllegalArgumentException       If the version is not supported.
     * @throws IllegalStateException          If the contents of the protobuf are inconsistent.
     */
    // The output type is only known at runtime, so the final constructor call is necessarily raw.
    @SuppressWarnings({"rawtypes", "unchecked"})
    public static SelectedFeatureDataset<?> deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
        }
        SelectedFeatureDatasetProto proto = message.unpack(SelectedFeatureDatasetProto.class);
        DatasetDataCarrier<?> carrier = DatasetDataCarrier.deserialize(proto.getMetadata());
        Class<?> outputClass = carrier.outputFactory().getUnknownOutput().getClass();
        FeatureMap fmap = carrier.featureDomain();

        // Every example must use the factory's output class and only features in the domain.
        List<ExampleProto> exampleProtos = proto.getExamplesList();
        List<Example<?>> examples = new ArrayList<Example<?>>(exampleProtos.size());
        int idx = 0;
        for (ExampleProto e : exampleProtos) {
            Example<?> example = Example.deserialize(e);
            Class<?> exampleOutputClass = example.getOutput().getClass();
            if (!exampleOutputClass.equals(outputClass)) {
                throw new IllegalStateException("Invalid protobuf, expected all examples to have output class " + outputClass + ", but found " + exampleOutputClass + " in example idx " + idx);
            }
            for (Feature f : example) {
                if (fmap.get(f.getName()) == null) {
                    throw new IllegalStateException("Invalid protobuf, feature domain does not contain feature " + f.getName() + " present in example at idx " + idx);
                }
            }
            examples.add(example);
            ++idx;
        }
        if (!(fmap instanceof ImmutableFeatureMap)) {
            throw new IllegalStateException("Invalid protobuf, feature map was not immutable");
        }
        if (!(carrier.outputDomain() instanceof ImmutableOutputInfo)) {
            throw new IllegalStateException("Invalid protobuf, output info was not immutable");
        }

        int k = proto.getK();
        if (k < 1 && k != SELECT_ALL) {
            throw new IllegalStateException("Invalid protobuf, k must be positive or -1, found " + k);
        }
        int numRemoved = proto.getNumExamplesRemoved();
        if (numRemoved < 0) {
            throw new IllegalStateException("Invalid protobuf, number of examples removed must be non-negative, found " + numRemoved);
        }
        SelectedFeatureSet featureSet = (SelectedFeatureSet) ProtoUtil.deserialize(proto.getFeatureSet());

        ProtocolStringList featureList = proto.getSelectedFeaturesList();
        Set<String> selectedFeatures = new LinkedHashSet<String>(featureList);
        if (selectedFeatures.size() != featureList.size()) {
            throw new IllegalStateException("Invalid protobuf, selected features contained duplicates, features = " + featureList);
        }
        // NOTE(review): the public constructor only requires a non-empty overlap between the
        // selected features and the dataset, while this check requires every selected feature
        // to be in the feature domain. A dataset built from a partially overlapping feature
        // set looks like it would serialize but fail here - confirm whether that is intended.
        for (String s : selectedFeatures) {
            if (fmap.get(s) == null) {
                throw new IllegalStateException("Invalid protobuf, some selected features were not found in the feature domain.");
            }
        }
        return new SelectedFeatureDataset(carrier.provenance(), carrier.outputFactory(), carrier.tribuoVersion(), (ImmutableFeatureMap) fmap, (ImmutableOutputInfo) carrier.outputDomain(), examples, k, featureSet, selectedFeatures, numRemoved);
    }

    /**
     * The number of examples dropped because they had no features left after selection.
     *
     * @return The number of examples removed.
     */
    public int getNumExamplesRemoved() {
        return this.numExamplesRemoved;
    }

    /**
     * The number of features requested, or -1 if every feature in the feature set was used.
     *
     * @return The value of k this dataset was built with.
     */
    public int getK() {
        return this.k;
    }

    /**
     * The feature set which was applied to build this dataset.
     *
     * @return The feature set.
     */
    public SelectedFeatureSet getFeatureSet() {
        return this.featureSet;
    }

    /**
     * The names of the features which were selected from the feature set.
     *
     * @return An unmodifiable set of feature names.
     */
    public Set<String> getSelectedFeatures() {
        return this.selectedFeatures;
    }

    /**
     * Builds a new provenance object describing this dataset on each call.
     *
     * @return The dataset provenance.
     */
    @Override
    public DatasetProvenance getProvenance() {
        return new SelectedFeatureDatasetProvenance(this);
    }

    /**
     * Serializes this dataset, including its examples, k, feature set and selected feature
     * names, into a {@link DatasetProto} tagged with {@link #CURRENT_VERSION}.
     *
     * @return The protobuf representation of this dataset.
     */
    @Override
    public DatasetProto serialize() {
        SelectedFeatureDatasetProto.Builder datasetBuilder = SelectedFeatureDatasetProto.newBuilder();
        datasetBuilder.setMetadata(this.createDataCarrier(this.featureIDMap, this.outputIDInfo).serialize());
        for (Example<T> e : this.data) {
            datasetBuilder.addExamples((ExampleProto) e.serialize());
        }
        datasetBuilder.setNumExamplesRemoved(this.numExamplesRemoved);
        datasetBuilder.setK(this.k);
        datasetBuilder.setFeatureSet(this.featureSet.serialize());
        datasetBuilder.addAllSelectedFeatures(this.selectedFeatures);

        DatasetProto.Builder builder = DatasetProto.newBuilder();
        builder.setVersion(CURRENT_VERSION);
        builder.setClassName(SelectedFeatureDataset.class.getName());
        builder.setSerializedData(Any.pack(datasetBuilder.build()));
        return builder.build();
    }

    /**
     * Provenance for {@link SelectedFeatureDataset}.
     * <p>
     * On top of the fields recorded by {@link DatasetProvenance} it records k, the
     * provenance of the feature set and the provenance of the original data.
     */
    public static final class SelectedFeatureDatasetProvenance extends DatasetProvenance {
        private static final long serialVersionUID = 1L;
        private static final String K = "k";
        private static final String FEATURE_SET_PROVENANCE = "feature-set-provenance";
        private static final String DATASET_PROVENANCE = "original-data-provenance";
        private final IntProvenance k;
        private final FeatureSetProvenance featureSetProvenance;
        private final DataProvenance datasetProvenance;

        /**
         * Builds the provenance from a live dataset, using an empty transform provenance list.
         *
         * @param dataset The dataset to describe.
         * @param <T>     The output type of the dataset.
         */
        <T extends Output<T>> SelectedFeatureDatasetProvenance(SelectedFeatureDataset<T> dataset) {
            super(dataset.sourceProvenance, new ListProvenance<ObjectProvenance>(), dataset);
            this.k = new IntProvenance(K, dataset.k);
            this.featureSetProvenance = dataset.featureSet.getProvenance();
            this.datasetProvenance = dataset.sourceProvenance;
        }

        /**
         * Rebuilds the provenance from a map of named provenance objects, reading the
         * {@code "k"}, {@code "feature-set-provenance"} and {@code "original-data-provenance"}
         * entries in addition to whatever the superclass constructor reads.
         *
         * @param map The provenance map.
         */
        public SelectedFeatureDatasetProvenance(Map<String, Provenance> map) {
            super(map);
            String provClassName = SelectedFeatureDatasetProvenance.class.getSimpleName();
            this.k = (IntProvenance) ObjectProvenance.checkAndExtractProvenance(map, K, IntProvenance.class, provClassName);
            this.featureSetProvenance = (FeatureSetProvenance) ObjectProvenance.checkAndExtractProvenance(map, FEATURE_SET_PROVENANCE, FeatureSetProvenance.class, provClassName);
            this.datasetProvenance = (DataProvenance) ObjectProvenance.checkAndExtractProvenance(map, DATASET_PROVENANCE, DataProvenance.class, provClassName);
        }

        /**
         * Two provenances are equal if they have the same class, the superclass considers
         * them equal, and their k, feature set provenance and original data provenance match.
         */
        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            if (!super.equals(o)) {
                return false;
            }
            SelectedFeatureDatasetProvenance other = (SelectedFeatureDatasetProvenance) o;
            return this.k.equals(other.k)
                    && this.featureSetProvenance.equals(other.featureSetProvenance)
                    && this.datasetProvenance.equals(other.datasetProvenance);
        }

        @Override
        public int hashCode() {
            return Objects.hash(super.hashCode(), this.k, this.featureSetProvenance, this.datasetProvenance);
        }

        /**
         * Appends this class's three provenance fields to the list produced by the superclass.
         *
         * @return The named provenances of this object.
         */
        @Override
        protected List<Pair<String, Provenance>> allProvenances() {
            List<Pair<String, Provenance>> provenances = super.allProvenances();
            provenances.add(new Pair<String, Provenance>(K, this.k));
            provenances.add(new Pair<String, Provenance>(FEATURE_SET_PROVENANCE, this.featureSetProvenance));
            provenances.add(new Pair<String, Provenance>(DATASET_PROVENANCE, this.datasetProvenance));
            return provenances;
        }
    }
}

