/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.common.sgd;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.time.OffsetDateTime;
import java.util.Collections;
import java.util.Map;
import java.util.SplittableRandom;
import java.util.logging.Logger;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.Trainer;
import org.tribuo.WeightedExamples;
import org.tribuo.common.sgd.SGDObjective;
import org.tribuo.math.FeedForwardParameters;
import org.tribuo.math.StochasticGradientOptimiser;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.optimisers.AdaGrad;
import org.tribuo.provenance.DatasetProvenance;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;

/**
 * Base class for trainers that fit a {@link FeedForwardParameters} model by stochastic gradient
 * descent, either one example per update ({@code minibatchSize == 1}) or in minibatches.
 * <p>
 * Subclasses supply:
 * <ul>
 *   <li>the objective;</li>
 *   <li>the conversion from an {@link Output} into the objective's target type;</li>
 *   <li>parameter construction;</li>
 *   <li>the final model construction.</li>
 * </ul>
 * <p>
 * The shared RNG and the invocation counter are only mutated while holding this object's
 * monitor. Each call to {@link #train(Dataset, Map, int)} splits off its own RNG, copies the
 * optimiser and creates fresh parameters, so the descent loop itself works on call-local state.
 *
 * @param <T> the output type.
 * @param <U> the target type consumed by the {@link SGDObjective}, produced by {@link #getTarget}.
 * @param <V> the model type produced by this trainer.
 * @param <X> the parameter type being optimised.
 */
public abstract class AbstractSGDTrainer<T extends Output<T>, U, V extends Model<T>, X extends FeedForwardParameters>
        implements Trainer<T>, WeightedExamples {
    private static final Logger logger = Logger.getLogger(AbstractSGDTrainer.class.getName());

    @Config(description="The gradient optimiser to use.")
    protected StochasticGradientOptimiser optimiser = new AdaGrad(1.0, 0.1);

    @Config(description="The number of gradient descent epochs.")
    protected int epochs = 5;

    /** Log the running loss every this many updates; values {@code <= 0} disable logging. */
    @Config(description="Log values after this many updates.")
    protected int loggingInterval = -1;

    /** Number of examples per gradient update; must be at least 1. */
    @Config(description="Minibatch size in SGD.")
    protected int minibatchSize = 1;

    @Config(description="Seed for the RNG used to shuffle elements.")
    protected long seed = 12345L;

    @Config(description="Shuffle the data before each epoch. Only turn off for debugging.")
    protected boolean shuffle = true;

    /**
     * Passed through to {@link DenseVector#createDenseVector} / {@link SparseVector#createSparseVector}
     * when featurising examples (presumably adds a bias feature — confirm against those factories).
     */
    protected final boolean addBias;

    /**
     * Parent RNG from which each training call splits its own generator.
     * It is only assigned in {@link #postConfig()} and {@link #setInvocationCount(int)}.
     */
    protected SplittableRandom rng;

    /** Number of completed (or simulated, via {@link #setInvocationCount(int)}) train invocations. */
    private int trainInvocationCounter;

    /**
     * Constructs a trainer with explicit hyperparameters.
     * <p>
     * NOTE(review): this constructor does not initialise {@link #rng}. Presumably subclass
     * constructors call {@link #postConfig()}; otherwise {@code train} fails on the null RNG.
     *
     * @param optimiser       the gradient optimiser (copied for each training call).
     * @param epochs          the number of passes over the data.
     * @param loggingInterval log the loss every this many updates; {@code <= 0} disables logging.
     * @param minibatchSize   the number of examples per update.
     * @param seed            the RNG seed.
     * @param addBias         passed to the feature-vector factories.
     */
    protected AbstractSGDTrainer(StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, int minibatchSize, long seed, boolean addBias) {
        this.optimiser = optimiser;
        this.epochs = epochs;
        this.loggingInterval = loggingInterval;
        this.minibatchSize = minibatchSize;
        this.seed = seed;
        this.addBias = addBias;
    }

    /**
     * Constructor used by the configuration system.
     * The remaining fields keep their declared defaults or are injected via {@link Config}.
     *
     * @param addBias passed to the feature-vector factories.
     */
    protected AbstractSGDTrainer(boolean addBias) {
        this.addBias = addBias;
    }

    /**
     * (Re)initialises the parent RNG from {@link #seed}.
     */
    @Override
    public synchronized void postConfig() {
        this.rng = new SplittableRandom(seed);
    }

    /**
     * Sets whether the data is shuffled before each epoch.
     *
     * @param shuffle {@code true} to shuffle, {@code false} to iterate in dataset order.
     */
    public void setShuffle(boolean shuffle) {
        this.shuffle = shuffle;
    }

    /**
     * Trains a model with no extra run provenance and the current invocation count.
     *
     * @param examples the training data.
     * @return the trained model.
     */
    @Override
    public V train(Dataset<T> examples) {
        return train(examples, Collections.emptyMap());
    }

    /**
     * Trains a model using the current invocation count.
     *
     * @param examples      the training data.
     * @param runProvenance extra provenance stored in the model.
     * @return the trained model.
     */
    @Override
    public V train(Dataset<T> examples, Map<String, Provenance> runProvenance) {
        return train(examples, runProvenance, -1);
    }

    /**
     * Trains a model by SGD.
     *
     * @param examples        the training data; must not contain unknown outputs.
     * @param runProvenance   extra provenance stored in the model.
     * @param invocationCount if not {@code -1}, the invocation count (and hence the RNG state)
     *                        is reset to this value before training.
     * @return the trained model.
     * @throws IllegalArgumentException if the dataset contains unknown outputs, or
     *                                  {@code invocationCount} is negative and not {@code -1}.
     * @throws IllegalStateException    if {@code minibatchSize < 1}.
     */
    public V train(Dataset<T> examples, Map<String, Provenance> runProvenance, int invocationCount) {
        if (examples.getOutputInfo().getUnknownCount() > 0) {
            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
        }
        // Fail fast: a non-positive minibatch would otherwise surface as an array-size/index error mid-training.
        if (minibatchSize < 1) {
            throw new IllegalStateException("minibatchSize must be at least 1, found " + minibatchSize);
        }

        // Take call-local copies of shared state under the monitor. Provenance is captured
        // before the counter is incremented, so it records the state this call started from.
        SplittableRandom localRNG;
        StochasticGradientOptimiser localOptimiser;
        TrainerProvenance trainerProvenance;
        synchronized (this) {
            if (invocationCount != -1) {
                setInvocationCount(invocationCount);
            }
            localRNG = rng.split();
            localOptimiser = optimiser.copy();
            trainerProvenance = getProvenance();
            trainInvocationCounter++;
        }

        SGDObjective<U> objective = getObjective();
        ImmutableOutputInfo<T> outputIDInfo = examples.getOutputIDInfo();
        ImmutableFeatureMap featureIDMap = examples.getFeatureIDMap();
        int featureSpaceSize = featureIDMap.size();

        // Featurise every example once up front; the three arrays are kept index-aligned
        // (including through shuffling) so features, targets and weights stay paired.
        int numExamples = examples.size();
        SGDVector[] sgdFeatures = new SGDVector[numExamples];
        // Generic arrays cannot be created directly. U is erased, and the array only ever
        // holds values returned by getTarget, so this cast is safe.
        @SuppressWarnings("unchecked")
        U[] sgdTargets = (U[]) new Object[numExamples];
        double[] weights = new double[numExamples];
        int n = 0;
        long featureSize = 0L;
        long denseCount = 0L;
        for (Example<T> example : examples) {
            weights[n] = example.getWeight();
            // Examples with one entry per feature are stored densely; everything else sparsely.
            if (example.size() == featureSpaceSize) {
                sgdFeatures[n] = DenseVector.createDenseVector(example, featureIDMap, addBias);
                denseCount++;
            } else {
                sgdFeatures[n] = SparseVector.createSparseVector(example, featureIDMap, addBias);
            }
            sgdTargets[n] = getTarget(outputIDInfo, example.getOutput());
            featureSize += sgdFeatures[n].numActiveElements();
            n++;
        }
        logger.info(String.format("Training SGD model with %d examples", n));
        logger.fine("Mean number of active features = " + (double) featureSize / (double) n);
        logger.fine("Number of dense examples = " + denseCount);
        logger.info("Outputs - " + outputIDInfo.toReadableString());

        X parameters = createParameters(featureSpaceSize, outputIDInfo.size(), localRNG);
        localOptimiser.initialise(parameters);
        // Loss accumulated since the last log line.
        double loss = 0.0;
        // long so very long runs (epochs * examples > Integer.MAX_VALUE) don't wrap negative.
        long iteration = 0;
        for (int epoch = 0; epoch < epochs; epoch++) {
            if (shuffle) {
                shuffleInPlace(sgdFeatures, sgdTargets, weights, localRNG);
            }
            if (minibatchSize == 1) {
                // One optimiser step per example, weighted by that example's weight.
                for (int j = 0; j < sgdFeatures.length; j++) {
                    DenseVector pred = parameters.predict(sgdFeatures[j]);
                    Pair<Double, SGDVector> output = objective.lossAndGradient(sgdTargets[j], pred);
                    loss += output.getA() * weights[j];
                    Tensor[] updates = localOptimiser.step(parameters.gradients(output, sgdFeatures[j]), weights[j]);
                    parameters.update(updates);
                    iteration++;
                    if (loggingInterval > 0 && iteration % loggingInterval == 0) {
                        logger.info("At iteration " + iteration + ", average loss = " + loss / (double) loggingInterval);
                        loss = 0.0;
                    }
                }
            } else {
                // Reused across minibatches; merge only reads the first curSize entries.
                Tensor[][] gradients = new Tensor[minibatchSize][];
                for (int j = 0; j < sgdFeatures.length; j += minibatchSize) {
                    double tempWeight = 0.0;
                    int curSize = 0;
                    // The final minibatch may be shorter than minibatchSize.
                    for (int k = j; k < j + minibatchSize && k < sgdFeatures.length; k++) {
                        DenseVector pred = parameters.predict(sgdFeatures[k]);
                        Pair<Double, SGDVector> output = objective.lossAndGradient(sgdTargets[k], pred);
                        loss += output.getA() * weights[k];
                        tempWeight += weights[k];
                        gradients[k - j] = parameters.gradients(output, sgdFeatures[k]);
                        curSize++;
                    }
                    // Combine the per-example gradients (semantics of merge live in X — confirm),
                    // then scale by minibatchSize and take one step with the mean example weight.
                    Tensor[] updates = parameters.merge(gradients, curSize);
                    for (Tensor update : updates) {
                        update.scaleInPlace((double) minibatchSize);
                    }
                    tempWeight /= (double) minibatchSize;
                    updates = localOptimiser.step(updates, tempWeight);
                    parameters.update(updates);
                    iteration++;
                    if (loggingInterval > 0 && iteration % loggingInterval == 0) {
                        logger.info("At iteration " + iteration + ", average loss = " + loss / (double) loggingInterval);
                        loss = 0.0;
                    }
                }
            }
        }
        localOptimiser.finalise();
        ModelProvenance provenance = new ModelProvenance(getModelClassName(), OffsetDateTime.now(), (DatasetProvenance) examples.getProvenance(), trainerProvenance, runProvenance);
        V model = createModel(getName(), provenance, featureIDMap, outputIDInfo, parameters);
        localOptimiser.reset();
        return model;
    }

    /**
     * Returns the number of times this trainer has been invoked.
     * This is synchronized so that it observes writes made under the monitor.
     *
     * @return the invocation count.
     */
    @Override
    public synchronized int getInvocationCount() {
        return trainInvocationCounter;
    }

    /**
     * Resets the RNG to its seeded state and advances it as if {@code invocationCount} training
     * calls had already been made, so the next call reproduces that call's randomness.
     *
     * @param invocationCount the number of prior invocations to simulate; must be non-negative.
     * @throws IllegalArgumentException if {@code invocationCount < 0}.
     */
    public synchronized void setInvocationCount(int invocationCount) {
        if (invocationCount < 0) {
            throw new IllegalArgumentException("The supplied invocationCount is less than zero.");
        }
        rng = new SplittableRandom(seed);
        for (trainInvocationCounter = 0; trainInvocationCounter < invocationCount; trainInvocationCounter++) {
            // The split generator is discarded; only the advance of rng's internal state matters.
            rng.split();
        }
    }

    /**
     * Converts an output into the target representation used by the objective.
     *
     * @param outputInfo the output domain with id mapping.
     * @param output     the output to convert.
     * @return the objective target.
     */
    protected abstract U getTarget(ImmutableOutputInfo<T> outputInfo, T output);

    /**
     * Returns the objective to minimise.
     *
     * @return the SGD objective.
     */
    protected abstract SGDObjective<U> getObjective();

    /**
     * Builds the final model from the trained parameters.
     *
     * @param name         the model name.
     * @param provenance   the model provenance.
     * @param featureMap   the feature domain.
     * @param outputInfo   the output domain.
     * @param parameters   the trained parameters.
     * @return the model.
     */
    protected abstract V createModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureMap, ImmutableOutputInfo<T> outputInfo, X parameters);

    /**
     * Returns the class name recorded in the model provenance.
     *
     * @return the model class name.
     */
    protected abstract String getModelClassName();

    /**
     * Returns the name passed to {@link #createModel}.
     *
     * @return the model name.
     */
    protected abstract String getName();

    /**
     * Creates the initial parameters.
     *
     * @param numFeatures the size of the feature domain.
     * @param numOutputs  the size of the output domain.
     * @param localRNG    the call-local RNG, for parameter initialisation.
     * @return fresh parameters.
     */
    protected abstract X createParameters(int numFeatures, int numOutputs, SplittableRandom localRNG);

    /**
     * Returns the provenance of this trainer's current configuration.
     *
     * @return the trainer provenance.
     */
    @Override
    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl(this);
    }

    /**
     * Shuffles three index-aligned arrays in place with the same permutation, using the
     * Fisher–Yates algorithm driven by {@code rng}.
     * <p>
     * Only the first {@code features.length} positions are permuted. {@code labels} and
     * {@code weights} are assumed to be at least that long.
     *
     * @param features the feature vectors.
     * @param labels   the targets, permuted together with {@code features}.
     * @param weights  the example weights, permuted together with {@code features}.
     * @param rng      the source of randomness.
     * @param <T>      the target type.
     */
    public static <T> void shuffleInPlace(SGDVector[] features, T[] labels, double[] weights, SplittableRandom rng) {
        for (int i = features.length; i > 1; i--) {
            int j = rng.nextInt(i);
            int last = i - 1;

            SGDVector tmpFeature = features[last];
            features[last] = features[j];
            features[j] = tmpFeature;

            T tmpLabel = labels[last];
            labels[last] = labels[j];
            labels[j] = tmpLabel;

            double tmpWeight = weights[last];
            weights[last] = weights[j];
            weights[j] = tmpWeight;
        }
    }
}

