/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.regression.sgd.fm;

import com.oracle.labs.mlrg.olcut.config.Config;
import java.util.logging.Logger;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Output;
import org.tribuo.common.sgd.AbstractFMTrainer;
import org.tribuo.common.sgd.FMParameters;
import org.tribuo.common.sgd.SGDObjective;
import org.tribuo.math.StochasticGradientOptimiser;
import org.tribuo.math.la.DenseVector;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.regression.ImmutableRegressionInfo;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.sgd.RegressionObjective;
import org.tribuo.regression.sgd.fm.FMRegressionModel;

public class FMRegressionTrainer
extends AbstractFMTrainer<Regressor, DenseVector, FMRegressionModel> {
    private static final Logger logger = Logger.getLogger(FMRegressionTrainer.class.getName());
    @Config(mandatory=true, description="The regression objective to use.")
    private RegressionObjective objective;
    @Config(mandatory=true, description="Standardise the output variables before fitting the model.")
    private boolean standardise;

    public FMRegressionTrainer(RegressionObjective objective, StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, int minibatchSize, long seed, int factorizedDimSize, double variance, boolean standardise) {
        super(optimiser, epochs, loggingInterval, minibatchSize, seed, factorizedDimSize, variance);
        this.objective = objective;
        this.standardise = standardise;
    }

    public FMRegressionTrainer(RegressionObjective objective, StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, long seed, int factorizedDimSize, double variance, boolean standardise) {
        this(objective, optimiser, epochs, loggingInterval, 1, seed, factorizedDimSize, variance, standardise);
    }

    public FMRegressionTrainer(RegressionObjective objective, StochasticGradientOptimiser optimiser, int epochs, long seed, int factorizedDimSize, double variance, boolean standardise) {
        this(objective, optimiser, epochs, 1000, 1, seed, factorizedDimSize, variance, standardise);
    }

    private FMRegressionTrainer() {
    }

    protected DenseVector getTarget(ImmutableOutputInfo<Regressor> outputInfo, Regressor output) {
        ImmutableRegressionInfo regressionInfo = (ImmutableRegressionInfo)outputInfo;
        double[] regressorsBuffer = new double[outputInfo.size()];
        for (Regressor.DimensionTuple r : output) {
            int id = outputInfo.getID((Output)r);
            double curValue = r.getValue();
            if (this.standardise) {
                curValue = (curValue - regressionInfo.getMean(id)) / regressionInfo.getVariance(id);
            }
            regressorsBuffer[id] = curValue;
        }
        return DenseVector.createDenseVector((double[])regressorsBuffer);
    }

    protected SGDObjective<DenseVector> getObjective() {
        return this.objective;
    }

    protected FMRegressionModel createModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureMap, ImmutableOutputInfo<Regressor> outputInfo, FMParameters parameters) {
        String[] dimensionNames = new String[outputInfo.size()];
        for (Regressor r : outputInfo.getDomain()) {
            int id = outputInfo.getID((Output)r);
            dimensionNames[id] = r.getNames()[0];
        }
        return new FMRegressionModel(name, dimensionNames, provenance, featureMap, outputInfo, parameters, this.standardise);
    }

    protected String getModelClassName() {
        return FMRegressionModel.class.getName();
    }

    public String toString() {
        return "FMRegressionTrainer(objective=" + this.objective.toString() + ",optimiser=" + this.optimiser.toString() + ",epochs=" + this.epochs + ",minibatchSize=" + this.minibatchSize + ",seed=" + this.seed + ",factorizedDimSize=" + this.factorizedDimSize + ",variance=" + this.variance + ",standardise=" + this.standardise + ")";
    }
}

