/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.classification.sequence;

import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;
import com.oracle.labs.mlrg.olcut.config.UsageException;
import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter;
import java.io.BufferedInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.util.logging.Logger;
import org.tribuo.classification.Label;
import org.tribuo.classification.sequence.LabelSequenceEvaluation;
import org.tribuo.classification.sequence.LabelSequenceEvaluator;
import org.tribuo.classification.sequence.example.SequenceDataGenerator;
import org.tribuo.sequence.SequenceDataset;
import org.tribuo.sequence.SequenceModel;
import org.tribuo.sequence.SequenceTrainer;
import org.tribuo.util.Util;

public class SeqTrainTest {
    private static final Logger logger = Logger.getLogger(SeqTrainTest.class.getName());

    public static void main(String[] args) throws ClassNotFoundException, IOException {
        SequenceDataset test;
        SequenceDataset train;
        ConfigurationManager cm;
        LabsLogFormatter.setAllLogFormatters();
        SeqTrainTestOptions o = new SeqTrainTestOptions();
        try {
            cm = new ConfigurationManager(args, (Options)o);
        }
        catch (UsageException e) {
            logger.info(e.getMessage());
            return;
        }
        switch (o.datasetName) {
            case "Gorilla": 
            case "gorilla": {
                logger.info("Generating gorilla dataset");
                train = SequenceDataGenerator.generateGorillaDataset(1);
                test = SequenceDataGenerator.generateGorillaDataset(1);
                break;
            }
            default: {
                if (o.trainDataset != null && o.testDataset != null) {
                    if (o.protobufFormat) {
                        logger.info("Loading protobuf format training data from " + o.trainDataset);
                        SequenceDataset tmpTrain = SequenceDataset.deserializeFromFile((Path)o.trainDataset);
                        train = SequenceDataset.castDataset((SequenceDataset)tmpTrain, Label.class);
                        logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
                        logger.info("Found " + train.getFeatureIDMap().size() + " features");
                        logger.info("Loading protobuf format testing data from " + o.testDataset);
                        SequenceDataset tmpTest = SequenceDataset.deserializeFromFile((Path)o.testDataset);
                        test = SequenceDataset.castDataset((SequenceDataset)tmpTest, Label.class);
                        logger.info(String.format("Loaded %d testing examples", test.size()));
                        break;
                    }
                    logger.info("Loading training data from " + o.trainDataset);
                    try (ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(Files.newInputStream(o.trainDataset, new OpenOption[0])));
                         ObjectInputStream oits = new ObjectInputStream(new BufferedInputStream(Files.newInputStream(o.testDataset, new OpenOption[0])));){
                        SequenceDataset tmpTest;
                        SequenceDataset tmpTrain;
                        train = tmpTrain = (SequenceDataset)ois.readObject();
                        logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
                        logger.info("Found " + train.getFeatureIDMap().size() + " features");
                        logger.info("Loading testing data from " + o.testDataset);
                        test = tmpTest = (SequenceDataset)oits.readObject();
                        logger.info(String.format("Loaded %d testing examples", test.size()));
                        break;
                    }
                }
                logger.warning("Unknown dataset " + o.datasetName);
                logger.info(cm.usage());
                return;
            }
        }
        logger.info("Training using " + o.trainer.toString());
        long trainStart = System.currentTimeMillis();
        SequenceModel model = o.trainer.train(train);
        long trainStop = System.currentTimeMillis();
        logger.info("Finished training classifier " + Util.formatDuration((long)trainStart, (long)trainStop));
        LabelSequenceEvaluator labelEvaluator = new LabelSequenceEvaluator();
        long testStart = System.currentTimeMillis();
        LabelSequenceEvaluation evaluation = (LabelSequenceEvaluation)labelEvaluator.evaluate(model, test);
        long testStop = System.currentTimeMillis();
        logger.info("Finished evaluating model " + Util.formatDuration((long)testStart, (long)testStop));
        System.out.println(evaluation.toString());
        System.out.println();
        System.out.println(evaluation.getConfusionMatrix().toString());
        if (o.outputPath != null) {
            if (o.writeProtobuf) {
                model.serializeToFile(o.outputPath);
            } else {
                try (ObjectOutputStream oos = new ObjectOutputStream(Files.newOutputStream(o.outputPath, new OpenOption[0]));){
                    oos.writeObject(model);
                }
            }
            logger.info("Serialized model to file: " + o.outputPath);
        }
    }

    public static class SeqTrainTestOptions
    implements Options {
        @Option(charName=100, longName="dataset-name", usage="Name of the example dataset, options are {gorilla}.")
        public String datasetName = "";
        @Option(charName=102, longName="output-path", usage="Path to serialize model to.")
        public Path outputPath;
        @Option(charName=117, longName="train-dataset", usage="Path to a serialised SequenceDataset used for training.")
        public Path trainDataset = null;
        @Option(charName=118, longName="test-dataset", usage="Path to a serialised SequenceDataset used for testing.")
        public Path testDataset = null;
        @Option(charName=116, longName="trainer-name", usage="Name of the trainer in the configuration file.")
        public SequenceTrainer<Label> trainer;
        @Option(charName=112, longName="protobuf-format-dataset", usage="Load the model from a protobuf. Optional")
        public boolean protobufFormat;
        @Option(longName="write-protobuf-model", usage="Write the model out in protobuf format.")
        public boolean writeProtobuf;

        public String getOptionsDescription() {
            return "Trains and tests a sequence classification model on the specified dataset.";
        }
    }
}

