/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.math.optimisers;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.Configurable;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import java.util.function.DoubleUnaryOperator;
import java.util.logging.Logger;
import org.tribuo.math.Parameters;
import org.tribuo.math.StochasticGradientOptimiser;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.optimisers.util.ShrinkingMatrix;
import org.tribuo.math.optimisers.util.ShrinkingVector;

/**
 * An RMSProp-style stochastic gradient optimiser.
 * <p>
 * One accumulator tensor is kept per parameter tensor. On every call to
 * {@link #step(Tensor[], double)} each incoming gradient {@code g} contributes
 * {@code (1 - rho) * g * g} to its accumulator, and the gradient is then rescaled
 * element-wise by {@code weight * learningRate / (epsilon + sqrt(accumulator))}.
 * <p>
 * The learning rate used on iteration {@code t} is
 * {@code initialLearningRate / (1 + decay * t)}, so in this implementation
 * {@code decay} anneals the learning rate.
 * <p>
 * Lifecycle: call {@link #initialise(Parameters)} before the first
 * {@link #step(Tensor[], double)}, and again after {@link #reset()}.
 */
public class RMSProp implements StochasticGradientOptimiser {
    private static final Logger logger = Logger.getLogger(RMSProp.class.getName());

    @Config(mandatory=true, description="Learning rate to scale the gradients by.")
    private double initialLearningRate;

    @Config(description="Momentum parameter.")
    private double rho = 0.9;

    @Config(description="Epsilon for numerical stability.")
    private double epsilon = 1.0E-8;

    @Config(description="Decay factor for the momentum.")
    private double decay = 0.0;

    // Derived state, populated by postConfig().
    /** Cached {@code 1 - rho}. */
    private double invRho;
    /** Maps a gradient element {@code a} to {@code invRho * a * a}. */
    private DoubleUnaryOperator square;

    // Per-run state, populated by initialise() and cleared by reset().
    /** Number of completed calls to {@code step}. */
    private int iteration = 0;
    /** Squared-gradient accumulators, one per parameter tensor; {@code null} until initialised. */
    private Tensor[] gradsSquared;

    /**
     * Constructs an RMSProp optimiser with every hyperparameter specified.
     *
     * @param initialLearningRate The learning rate used on the first iteration.
     * @param rho The accumulator coefficient; {@code 1 - rho} weights each new squared gradient.
     * @param epsilon Added to the square-rooted accumulator before dividing.
     * @param decay Learning rate decay coefficient; {@code 0} keeps the rate constant.
     */
    public RMSProp(double initialLearningRate, double rho, double epsilon, double decay) {
        this.initialLearningRate = initialLearningRate;
        this.rho = rho;
        this.epsilon = epsilon;
        this.decay = decay;
        this.postConfig();
    }

    /**
     * Constructs an RMSProp optimiser with {@code epsilon = 1e-8} and {@code decay = 0}.
     *
     * @param initialLearningRate The learning rate used on the first iteration.
     * @param rho The accumulator coefficient.
     */
    public RMSProp(double initialLearningRate, double rho) {
        this(initialLearningRate, rho, 1.0E-8, 0.0);
    }

    /**
     * No-argument constructor; the {@code @Config} fields keep their declared defaults.
     * NOTE(review): presumably invoked reflectively by the configuration system,
     * which would then call {@link #postConfig()} - confirm.
     */
    private RMSProp() {
    }

    /**
     * Computes the state derived from the configured fields: {@code invRho} and the
     * {@code square} operator. Must run after {@code rho} has its final value.
     */
    public void postConfig() {
        this.invRho = 1.0 - this.rho;
        // The lambda reads this.invRho each time it is applied rather than capturing its value.
        this.square = a -> this.invRho * a * a;
    }

    /**
     * Allocates the squared-gradient accumulators for the supplied parameters.
     * <p>
     * Each tensor returned by {@code parameters.getEmptyCopy()} is wrapped in a
     * {@link ShrinkingVector} or {@link ShrinkingMatrix} constructed with {@code invRho}
     * and {@code false}.
     * NOTE(review): the meaning of those two constructor arguments is defined by the
     * Shrinking* classes and is not visible here - confirm.
     *
     * @param parameters The parameters this optimiser will be updating.
     * @throws IllegalStateException If a tensor is neither a {@link DenseVector} nor a {@link DenseMatrix}.
     */
    @Override
    public void initialise(Parameters parameters) {
        this.gradsSquared = parameters.getEmptyCopy();
        for (int i = 0; i < this.gradsSquared.length; ++i) {
            Tensor empty = this.gradsSquared[i];
            if (empty instanceof DenseVector) {
                this.gradsSquared[i] = new ShrinkingVector((DenseVector) empty, this.invRho, false);
            } else if (empty instanceof DenseMatrix) {
                this.gradsSquared[i] = new ShrinkingMatrix((DenseMatrix) empty, this.invRho, false);
            } else {
                throw new IllegalStateException("Unknown Tensor subclass");
            }
        }
    }

    /**
     * Rescales the supplied gradients using the accumulated squared gradients.
     * <p>
     * The elements of {@code updates} are passed to in-place tensor operations and the
     * same array is returned.
     *
     * @param updates The gradients, one per parameter tensor, in the order used at initialisation.
     * @param weight A multiplier applied together with the learning rate.
     * @return The {@code updates} array that was passed in.
     * @throws IllegalStateException If called before {@link #initialise(Parameters)} or after
     *         {@link #reset()} without re-initialising.
     * @throws IllegalArgumentException If {@code updates} has more tensors than the optimiser
     *         was initialised with.
     */
    @Override
    public Tensor[] step(Tensor[] updates, double weight) {
        if (this.gradsSquared == null) {
            // Without this guard the loop below fails with a message-less NullPointerException.
            throw new IllegalStateException(
                "RMSProp has not been initialised; call initialise(Parameters) before step, and again after reset.");
        }
        if (updates.length > this.gradsSquared.length) {
            // Reject up front so no accumulator or gradient is modified before the failure.
            throw new IllegalArgumentException(
                "Received " + updates.length + " gradient tensors, but the optimiser was initialised with "
                    + this.gradsSquared.length + ".");
        }

        double learningRate = this.initialLearningRate / (1.0 + this.decay * this.iteration);
        // Built once per step rather than once per tensor.
        DoubleUnaryOperator scale = a -> weight * learningRate / (this.epsilon + Math.sqrt(a));

        for (int i = 0; i < updates.length; ++i) {
            Tensor curGradsSquared = this.gradsSquared[i];
            Tensor curGrad = updates[i];
            // Fold the transformed gradient into the accumulator first, so the scaling
            // below sees the accumulator including the current gradient.
            curGradsSquared.intersectAndAddInPlace(curGrad, this.square);
            curGrad.hadamardProductInPlace(curGradsSquared, scale);
        }
        ++this.iteration;
        return updates;
    }

    /**
     * Describes this optimiser by its four hyperparameters.
     *
     * @return A string of the form {@code RMSProp(initialLearningRate=..,rho=..,epsilon=..,decay=..)}.
     */
    @Override
    public String toString() {
        return "RMSProp(initialLearningRate=" + this.initialLearningRate + ",rho=" + this.rho + ",epsilon=" + this.epsilon + ",decay=" + this.decay + ")";
    }

    /**
     * Discards the accumulators and the iteration count. {@link #initialise(Parameters)}
     * must be called again before the next {@link #step(Tensor[], double)}.
     */
    @Override
    public void reset() {
        this.gradsSquared = null;
        this.iteration = 0;
    }

    /**
     * Creates a new optimiser with the same hyperparameters. The accumulators and the
     * iteration count are not carried over.
     *
     * @return A fresh, uninitialised {@code RMSProp}.
     */
    @Override
    public RMSProp copy() {
        return new RMSProp(this.initialLearningRate, this.rho, this.epsilon, this.decay);
    }

    /**
     * Builds the provenance for this optimiser from its configured fields.
     *
     * @return A provenance object tagged {@code "StochasticGradientOptimiser"}.
     */
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl((Configurable)this, "StochasticGradientOptimiser");
    }
}

