/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.classification;

import com.oracle.labs.mlrg.olcut.config.ArgumentException;
import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.IOException;
import java.util.logging.Logger;
import org.tribuo.Dataset;
import org.tribuo.Model;
import org.tribuo.OutputFactory;
import org.tribuo.Trainer;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.classification.evaluation.LabelEvaluation;
import org.tribuo.data.DataOptions;
import org.tribuo.util.Util;

/**
 * Static helper that trains a classifier on a training dataset and evaluates it on a
 * separate testing dataset, both described by a {@link DataOptions} instance.
 */
public final class TrainTestHelper {
    private static final Logger logger = Logger.getLogger(TrainTestHelper.class.getName());
    private static final LabelFactory factory = new LabelFactory();

    /** Utility class; not instantiable. */
    private TrainTestHelper() {
    }

    /**
     * Loads the training and testing data described by {@code dataOptions}, trains a model with
     * {@code trainer}, evaluates it on the testing data and reports the results.
     * <p>
     * Training and evaluation timings, the feature count of the training data and (for models that
     * generate probabilities) the average and weighted-average AUC are logged at INFO level. The
     * evaluation summary and the confusion matrix are printed to standard output. If
     * {@code dataOptions.outputPath} is non-null the trained model is saved via
     * {@link DataOptions#saveModel}.
     *
     * @param cm          the configuration manager, used only to log usage when arguments are missing
     * @param dataOptions the data loading options; both the training and testing paths must be set
     * @param trainer     the trainer used to build the model
     * @return the trained model
     * @throws ArgumentException if either the training path or the testing path is {@code null}
     * @throws IOException       propagated from loading the data or saving the model
     */
    public static Model<Label> run(ConfigurationManager cm, DataOptions dataOptions, Trainer<Label> trainer) throws IOException {
        LabsLogFormatter.setAllLogFormatters();
        if (dataOptions.trainingPath == null || dataOptions.testingPath == null) {
            // Show the user the full usage before failing so they can see which options are required.
            logger.info(cm.usage());
            logger.info("Training Path = " + dataOptions.trainingPath + ", Testing Path = " + dataOptions.testingPath);
            throw new ArgumentException("training-file", "test-file", "Must supply both training and testing data.");
        }

        // getA() is the training split, getB() the testing split.
        Pair<Dataset<Label>, Dataset<Label>> data = dataOptions.load(factory);
        Dataset<Label> train = data.getA();
        logger.info("Training data has " + train.getFeatureIDMap().size() + " features.");
        Dataset<Label> test = data.getB();

        logger.info("Training using " + trainer.toString());
        long trainStart = System.currentTimeMillis();
        Model<Label> model = trainer.train(train);
        long trainStop = System.currentTimeMillis();
        logger.info("Finished training classifier " + Util.formatDuration(trainStart, trainStop));

        long testStart = System.currentTimeMillis();
        LabelEvaluation evaluation = factory.getEvaluator().evaluate(model, test);
        long testStop = System.currentTimeMillis();
        logger.info("Finished evaluating model " + Util.formatDuration(testStart, testStop));

        // AUC is only reported when the model produces probability scores.
        if (model.generatesProbabilities()) {
            logger.info("Average AUC = " + evaluation.averageAUCROC(false));
            logger.info("Average weighted AUC = " + evaluation.averageAUCROC(true));
        }

        System.out.println(evaluation.toString());
        System.out.println(evaluation.getConfusionMatrix().toString());

        if (dataOptions.outputPath != null) {
            dataOptions.saveModel(model);
        }
        return model;
    }
}

