/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training;

import ai.djl.Device;
import ai.djl.engine.Engine;
import ai.djl.nn.Parameter;
import ai.djl.training.TrainingConfig;
import ai.djl.training.evaluator.Evaluator;
import ai.djl.training.initializer.Initializer;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.optimizer.Adam;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.util.PairList;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ForkJoinPool;
import java.util.function.Predicate;

public class DefaultTrainingConfig
implements TrainingConfig {
    private PairList<Initializer, Predicate<Parameter>> initializers = new PairList();
    private Optimizer optimizer;
    private Device[] devices;
    private Loss loss;
    private ExecutorService executorService;
    private List<Evaluator> evaluators;
    private List<TrainingListener> listeners;

    public DefaultTrainingConfig(Loss loss) {
        this.loss = loss;
        this.optimizer = Adam.builder().build();
        this.evaluators = new ArrayList<Evaluator>();
        this.listeners = new ArrayList<TrainingListener>();
    }

    public DefaultTrainingConfig optInitializer(Initializer initializer, Parameter.Type type) {
        this.initializers.add(initializer, parameter -> parameter.getType().equals((Object)type));
        return this;
    }

    public DefaultTrainingConfig optInitializer(Initializer initializer, String name) {
        this.initializers.add(initializer, parameter -> parameter.getName().equals(name));
        return this;
    }

    public DefaultTrainingConfig optInitializer(Initializer initializer, Predicate<Parameter> predicate) {
        this.initializers.add(initializer, predicate);
        return this;
    }

    public DefaultTrainingConfig optDevices(Device[] devices) {
        this.devices = devices;
        return this;
    }

    public DefaultTrainingConfig optOptimizer(Optimizer optimizer) {
        this.optimizer = optimizer;
        return this;
    }

    public DefaultTrainingConfig optExecutorService() {
        return this.optExecutorService(ForkJoinPool.commonPool());
    }

    public DefaultTrainingConfig optExecutorService(ExecutorService executorService) {
        this.executorService = executorService;
        return this;
    }

    public <T extends Evaluator> DefaultTrainingConfig addEvaluators(Collection<T> evaluators) {
        evaluators.forEach(this::addEvaluator);
        return this;
    }

    public DefaultTrainingConfig addEvaluator(Evaluator evaluator) {
        this.evaluators.add(evaluator);
        return this;
    }

    public DefaultTrainingConfig addTrainingListeners(TrainingListener ... listeners) {
        this.listeners.addAll(Arrays.asList(listeners));
        return this;
    }

    @Override
    public Device[] getDevices() {
        if (this.devices == null) {
            return Engine.getInstance().getDevices();
        }
        return this.devices;
    }

    @Override
    public PairList<Initializer, Predicate<Parameter>> getInitializers() {
        return this.initializers;
    }

    @Override
    public Optimizer getOptimizer() {
        return this.optimizer;
    }

    @Override
    public Loss getLossFunction() {
        return this.loss;
    }

    @Override
    public ExecutorService getExecutorService() {
        return this.executorService;
    }

    @Override
    public List<Evaluator> getEvaluators() {
        return this.evaluators;
    }

    @Override
    public List<TrainingListener> getTrainingListeners() {
        return this.listeners;
    }
}

