/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.classification.sgd.crf;

import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;
import com.oracle.labs.mlrg.olcut.config.UsageException;
import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter;
import java.io.BufferedInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.util.logging.Logger;
import org.tribuo.classification.Label;
import org.tribuo.classification.sequence.LabelSequenceEvaluation;
import org.tribuo.classification.sequence.LabelSequenceEvaluator;
import org.tribuo.classification.sequence.example.SequenceDataGenerator;
import org.tribuo.classification.sgd.crf.CRFModel;
import org.tribuo.classification.sgd.crf.CRFTrainer;
import org.tribuo.hash.HashCodeHasher;
import org.tribuo.hash.Hasher;
import org.tribuo.hash.HashingOptions;
import org.tribuo.hash.MessageDigestHasher;
import org.tribuo.math.StochasticGradientOptimiser;
import org.tribuo.math.optimisers.GradientOptimiserOptions;
import org.tribuo.sequence.HashingSequenceTrainer;
import org.tribuo.sequence.MutableSequenceDataset;
import org.tribuo.sequence.SequenceDataset;
import org.tribuo.sequence.SequenceModel;
import org.tribuo.sequence.SequenceTrainer;
import org.tribuo.util.Util;

/**
 * Command-line driver that trains a linear-chain CRF on a sequence dataset, evaluates it on a
 * test dataset, prints the evaluation, and optionally writes the trained model to disk.
 */
public class SeqTest {
    private static final Logger logger = Logger.getLogger(SeqTest.class.getName());

    /**
     * Entry point. Parses {@link CRFOptions} from {@code args}, builds or loads the train/test
     * datasets, trains a {@link CRFTrainer} (optionally wrapped in a hashing trainer), prints the
     * evaluation and confusion matrix, and serializes the model if an output path was given.
     *
     * @param args command-line arguments, parsed into {@link CRFOptions}.
     * @throws ClassNotFoundException if a Java-serialized dataset references an unknown class.
     * @throws IOException if a dataset cannot be read or the model cannot be written.
     */
    public static void main(String[] args) throws ClassNotFoundException, IOException {
        LabsLogFormatter.setAllLogFormatters();
        CRFOptions o = new CRFOptions();
        ConfigurationManager cm;
        try {
            cm = new ConfigurationManager(args, o);
        } catch (UsageException e) {
            logger.info(e.getMessage());
            return;
        }
        logger.info("Configuring gradient optimiser");
        StochasticGradientOptimiser grad = o.gradientOptions.getOptimiser();
        logger.info(String.format("Set logging interval to %d", o.loggingInterval));

        SequenceDataset<Label> train;
        SequenceDataset<Label> test;
        switch (o.datasetName) {
            case "Gorilla":
            case "gorilla":
                logger.info("Generating gorilla dataset");
                train = SequenceDataGenerator.generateGorillaDataset(1);
                test = SequenceDataGenerator.generateGorillaDataset(1);
                break;
            default:
                // Without a built-in dataset name, both file paths are required.
                if (o.trainDataset == null || o.testDataset == null) {
                    logger.warning("Unknown dataset " + o.datasetName);
                    logger.info(cm.usage());
                    return;
                }
                train = loadDataset(o.trainDataset, "training", o.protobufFormat);
                logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
                logger.info("Found " + train.getFeatureIDMap().size() + " features");
                test = loadDataset(o.testDataset, "testing", o.protobufFormat);
                logger.info(String.format("Loaded %d testing examples", test.size()));
                break;
        }

        // Configure the concrete CRFTrainer before (optionally) hiding it behind a hashing wrapper.
        CRFTrainer crfTrainer = new CRFTrainer(grad, o.epochs, o.loggingInterval, o.seed);
        crfTrainer.setShuffle(o.shuffle);
        SequenceTrainer<Label> trainer = wrapWithHasher(crfTrainer, o.modelHashingAlgorithm, o.modelHashingSalt);

        logger.info("Training using " + trainer);
        long trainStart = System.currentTimeMillis();
        // NOTE(review): the cast assumes a hashing wrapper returns the inner CRFTrainer's model type.
        CRFModel model = (CRFModel) trainer.train(train);
        long trainStop = System.currentTimeMillis();
        logger.info("Finished training classifier " + Util.formatDuration(trainStart, trainStop));
        if (o.logModel) {
            System.out.println("FeatureMap = " + model.getFeatureIDMap().toString());
            System.out.println("LabelMap = " + model.getOutputIDInfo().toString());
            System.out.println("Features - " + model.generateWeightsString());
        }

        LabelSequenceEvaluator labelEvaluator = new LabelSequenceEvaluator();
        long testStart = System.currentTimeMillis();
        LabelSequenceEvaluation evaluation = labelEvaluator.evaluate(model, test);
        long testStop = System.currentTimeMillis();
        logger.info("Finished evaluating model " + Util.formatDuration(testStart, testStop));
        System.out.println(evaluation.toString());
        System.out.println();
        System.out.println(evaluation.getConfusionMatrix().toString());

        if (o.outputPath != null) {
            if (o.protobufFormat) {
                model.serializeToFile(o.outputPath);
            } else {
                try (ObjectOutputStream oos = new ObjectOutputStream(Files.newOutputStream(o.outputPath))) {
                    oos.writeObject(model);
                }
            }
            logger.info("Serialized model to file: " + o.outputPath);
        }
    }

    /**
     * Loads one sequence dataset from disk and checks, via {@link SequenceDataset#castDataset},
     * that its outputs are {@link Label}s. Both storage formats go through the same check.
     *
     * @param path the file to read.
     * @param role "training" or "testing"; only used in the log message.
     * @param protobuf if true read the protobuf format, otherwise Java serialization.
     * @return the loaded dataset.
     * @throws IOException if the file cannot be read.
     * @throws ClassNotFoundException if the Java-serialized stream references an unknown class.
     */
    private static SequenceDataset<Label> loadDataset(Path path, String role, boolean protobuf)
            throws IOException, ClassNotFoundException {
        if (protobuf) {
            logger.info("Loading protobuf format " + role + " data from " + path);
            return SequenceDataset.castDataset(SequenceDataset.deserializeFromFile(path), Label.class);
        }
        logger.info("Loading " + role + " data from " + path);
        // NOTE(review): Java native deserialization is unsafe on untrusted input; only load trusted files.
        try (ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(Files.newInputStream(path)))) {
            SequenceDataset<?> raw = (SequenceDataset<?>) ois.readObject();
            return SequenceDataset.castDataset(raw, Label.class);
        }
    }

    /**
     * Wraps {@code trainer} in a {@link HashingSequenceTrainer} according to the requested
     * algorithm. {@code NONE} (and any unrecognised value, which is logged) returns it unchanged.
     *
     * @param trainer the trainer to wrap.
     * @param type the hashing algorithm to apply.
     * @param salt the salt passed to the hasher.
     * @return the possibly-wrapped trainer.
     */
    private static SequenceTrainer<Label> wrapWithHasher(SequenceTrainer<Label> trainer,
                                                         HashingOptions.ModelHashingType type,
                                                         String salt) {
        switch (type) {
            case NONE:
                return trainer;
            case HC:
                return new HashingSequenceTrainer<>(trainer, new HashCodeHasher(salt));
            case SHA1:
                return new HashingSequenceTrainer<>(trainer, new MessageDigestHasher("SHA1", salt));
            case SHA256:
                return new HashingSequenceTrainer<>(trainer, new MessageDigestHasher("SHA-256", salt));
            default:
                logger.info("Unknown hasher " + type);
                return trainer;
        }
    }

    /**
     * Command-line options for {@link SeqTest}.
     */
    public static class CRFOptions implements Options {
        public GradientOptimiserOptions gradientOptions;
        @Option(charName = 'd', longName = "dataset-name", usage = "Name of the example dataset, options are {gorilla}.")
        public String datasetName = "";
        @Option(charName = 'f', longName = "output-path", usage = "Path to serialize model to.")
        public Path outputPath;
        @Option(charName = 'i', longName = "epochs", usage = "Number of SGD epochs.")
        public int epochs = 5;
        @Option(charName = 'o', longName = "print-model", usage = "Print out feature, label and other model details.")
        public boolean logModel = false;
        @Option(charName = 'p', longName = "logging-interval", usage = "Log the objective after <int> examples.")
        public int loggingInterval = 100;
        @Option(charName = 'r', longName = "seed", usage = "RNG seed.")
        public long seed = 1L;
        @Option(longName = "shuffle", usage = "Shuffle the data each epoch (default: true).")
        public boolean shuffle = true;
        @Option(charName = 'u', longName = "train-dataset", usage = "Path to a serialised SequenceDataset used for training.")
        public Path trainDataset = null;
        @Option(charName = 'v', longName = "test-dataset", usage = "Path to a serialised SequenceDataset used for testing.")
        public Path testDataset = null;
        @Option(longName = "model-hashing-algorithm", usage = "Hash the model during training. Defaults to no hashing.")
        public HashingOptions.ModelHashingType modelHashingAlgorithm = HashingOptions.ModelHashingType.NONE;
        @Option(longName = "model-hashing-salt", usage = "Salt for hashing the model.")
        public String modelHashingSalt = "";
        // The flag selects protobuf for both reading the datasets and writing the model.
        @Option(longName = "protobuf-model", usage = "Load the datasets from protobuf format and write the model out as a protobuf. Optional")
        public boolean protobufFormat;

        @Override
        public String getOptionsDescription() {
            return "Tests a linear chain CRF model on the specified dataset.";
        }
    }
}

