/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.ml.common.optimizer;

import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import org.apache.commons.collections.IteratorUtils;
import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.iteration.DataStreamList;
import org.apache.flink.iteration.IterationBody;
import org.apache.flink.iteration.IterationBodyResult;
import org.apache.flink.iteration.IterationConfig;
import org.apache.flink.iteration.IterationListener;
import org.apache.flink.iteration.Iterations;
import org.apache.flink.iteration.ReplayableDataStreamList;
import org.apache.flink.iteration.operator.OperatorStateUtils;
import org.apache.flink.ml.common.datastream.DataStreamUtils;
import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
import org.apache.flink.ml.common.lossfunc.LossFunc;
import org.apache.flink.ml.common.optimizer.Optimizer;
import org.apache.flink.ml.common.optimizer.RegularizationUtils;
import org.apache.flink.ml.linalg.BLAS;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
import org.apache.flink.runtime.state.StateInitializationContext;
import org.apache.flink.runtime.state.StateSnapshotContext;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.util.Collector;
import org.apache.flink.util.OutputTag;

/**
 * Mini-batch stochastic gradient descent (SGD) optimizer.
 *
 * <p>How training runs:
 *
 * <ul>
 *   <li>Training runs as a bounded Flink iteration. The model coefficients travel through the
 *       iteration as a {@code double[]}.
 *   <li>In each round, every subtask computes the gradient, total weight and total loss of its
 *       next local mini-batch.
 *   <li>The per-subtask results are summed with {@link DataStreamUtils#allReduceSum}.
 *   <li>Each subtask then applies the summed result to its copy of the coefficients.
 *   <li>{@link TerminateOnMaxIterOrTol} decides when to stop. Its input is the weighted average
 *       loss of each round ({@code totalLoss / totalWeight}).
 * </ul>
 */
@Internal
public class SGD implements Optimizer {
    private final SGDParams params;

    /**
     * Creates an SGD optimizer.
     *
     * @param maxIter maximum number of iterations, passed to {@link TerminateOnMaxIterOrTol}.
     * @param learningRate step size applied to the weight-averaged gradient in each update.
     * @param globalBatchSize mini-batch size summed over all parallel subtasks.
     * @param tol loss tolerance, passed to {@link TerminateOnMaxIterOrTol}.
     * @param reg regularization parameter, passed to {@link RegularizationUtils#regularize}.
     * @param elasticNet elastic-net parameter, passed to {@link RegularizationUtils#regularize}.
     */
    public SGD(
            int maxIter,
            double learningRate,
            int globalBatchSize,
            double tol,
            double reg,
            double elasticNet) {
        this.params = new SGDParams(maxIter, learningRate, globalBatchSize, tol, reg, elasticNet);
    }

    @Override
    public DataStream<DenseVector> optimize(
            DataStream<DenseVector> initModelData,
            DataStream<LabeledPointWithWeight> trainData,
            LossFunc lossFunc) {
        // Broadcast the initial model so that every subtask starts from the same coefficients.
        // The model travels through the iteration as a raw double[].
        DataStream<double[]> initModelArray =
                initModelData.broadcast().map(modelVec -> modelVec.values);
        // Rebalance so that every subtask caches a share of the training data.
        // NOTE(review): the identity map is kept from the original topology. It presumably gives
        // the rebalanced stream its own operator before it enters the iteration.
        DataStream<LabeledPointWithWeight> balancedTrainData = trainData.rebalance().map(x -> x);

        DataStreamList resultList =
                Iterations.iterateBoundedStreamsUntilTermination(
                        DataStreamList.of(initModelArray),
                        ReplayableDataStreamList.notReplay(balancedTrainData),
                        IterationConfig.newBuilder().build(),
                        new TrainIterationBody(lossFunc, this.params));
        return resultList.get(0);
    }

    /** Immutable, serializable holder for the SGD hyper-parameters. */
    private static class SGDParams implements Serializable {
        public final int maxIter;
        public final double learningRate;
        public final int globalBatchSize;
        public final double tol;
        public final double reg;
        public final double elasticNet;

        private SGDParams(
                int maxIter,
                double learningRate,
                int globalBatchSize,
                double tol,
                double reg,
                double elasticNet) {
            this.maxIter = maxIter;
            this.learningRate = learningRate;
            this.globalBatchSize = globalBatchSize;
            this.tol = tol;
            this.reg = reg;
            this.elasticNet = elasticNet;
        }
    }

    /**
     * Operator that caches the training data of its subtask and performs one SGD round per epoch.
     *
     * <p>Inputs and outputs:
     *
     * <ul>
     *   <li>Input 1 carries training samples. They are buffered in operator state.
     *   <li>Input 2 carries the feedback array. In round 0 this is the initial model; in later
     *       rounds it is the all-reduced {@code [gradient, totalWeight, totalLoss]} array.
     *   <li>Each round, the operator emits its local {@code [gradient, totalWeight, totalLoss]}
     *       array, provided it has cached any data.
     *   <li>On termination, subtask 0 emits the final coefficients to the model-data side output.
     * </ul>
     */
    private static class CacheDataAndDoTrain extends AbstractStreamOperator<double[]>
            implements TwoInputStreamOperator<LabeledPointWithWeight, double[], double[]>,
                    IterationListener<double[]> {
        private final SGDParams params;
        private final LossFunc lossFunc;
        private final OutputTag<DenseVector> modelDataOutputTag;

        /** Lazily materialized copy of {@link #trainDataState}. */
        private List<LabeledPointWithWeight> trainData;

        private ListState<LabeledPointWithWeight> trainDataState;

        /** Index into {@link #trainData} where the next local mini-batch starts. */
        private int nextBatchOffset = 0;

        private ListState<Integer> nextBatchOffsetState;
        private DenseVector coefficient;
        private ListState<DenseVector> coefficientState;
        private int coefficientDim;

        /**
         * Buffer exchanged with the rest of the iteration.
         *
         * <p>In round 0 it holds the initial model received on input 2. Afterwards its layout is
         * {@code [gradient (coefficientDim entries), totalWeight, totalLoss]}.
         */
        private double[] feedbackArray;

        private ListState<double[]> feedbackArrayState;

        /** Number of samples this subtask contributes to each global mini-batch. */
        private int localBatchSize;

        private CacheDataAndDoTrain(
                LossFunc lossFunc, SGDParams params, OutputTag<DenseVector> modelDataOutputTag) {
            this.lossFunc = lossFunc;
            this.params = params;
            this.modelDataOutputTag = modelDataOutputTag;
        }

        @Override
        public void open() throws Exception {
            super.open();
            int numTasks = getRuntimeContext().getNumberOfParallelSubtasks();
            int taskId = getRuntimeContext().getIndexOfThisSubtask();
            // Split the global batch evenly across subtasks. The first
            // (globalBatchSize % numTasks) subtasks each take one extra sample, so the local
            // sizes add up to globalBatchSize.
            localBatchSize = params.globalBatchSize / numTasks;
            if (params.globalBatchSize % numTasks > taskId) {
                localBatchSize++;
            }
        }

        private double getTotalWeight() {
            return feedbackArray[coefficientDim];
        }

        private void setTotalWeight(double totalWeight) {
            feedbackArray[coefficientDim] = totalWeight;
        }

        private double getTotalLoss() {
            return feedbackArray[coefficientDim + 1];
        }

        private void setTotalLoss(double totalLoss) {
            feedbackArray[coefficientDim + 1] = totalLoss;
        }

        /**
         * Applies the summed gradient stored in {@link #feedbackArray} to {@link #coefficient}.
         *
         * <p>It then regularizes the coefficients and adds the regularization loss to the total
         * loss. It does nothing when the summed weight is not positive.
         */
        private void updateModel() {
            double totalWeight = getTotalWeight();
            if (totalWeight > 0.0) {
                // Computes coefficient += (-learningRate / totalWeight) * gradient over the first
                // coefficientDim entries.
                BLAS.axpy(
                        -params.learningRate / totalWeight,
                        new DenseVector(feedbackArray),
                        coefficient,
                        coefficientDim);
                double regLoss =
                        RegularizationUtils.regularize(
                                coefficient, params.reg, params.elasticNet, params.learningRate);
                setTotalLoss(getTotalLoss() + regLoss);
            }
        }

        @Override
        public void onEpochWatermarkIncremented(
                int epochWatermark, IterationListener.Context context, Collector<double[]> collector)
                throws Exception {
            if (epochWatermark == 0) {
                // Round 0: the feedback array is the broadcast initial model. Adopt it as the
                // coefficients and allocate the [gradient, totalWeight, totalLoss] buffer.
                coefficient = new DenseVector(feedbackArray);
                coefficientDim = coefficient.size();
                feedbackArray = new double[coefficientDim + 2];
            } else {
                // Later rounds: the feedback array holds the all-reduced result of the last round.
                updateModel();
            }

            if (trainData == null) {
                // commons-collections' IteratorUtils is not generic.
                @SuppressWarnings("unchecked")
                List<LabeledPointWithWeight> cached =
                        IteratorUtils.toList(trainDataState.get().iterator());
                trainData = cached;
            }

            if (trainData.isEmpty()) {
                // Nothing to contribute: this subtask emits no record in this round.
                return;
            }

            int batchEnd = Math.min(nextBatchOffset + localBatchSize, trainData.size());
            List<LabeledPointWithWeight> miniBatchData = trainData.subList(nextBatchOffset, batchEnd);
            // Advance the cursor. Wrap to the start once the cached data is exhausted.
            nextBatchOffset += localBatchSize;
            if (nextBatchOffset >= trainData.size()) {
                nextBatchOffset = 0;
            }

            Arrays.fill(feedbackArray, 0.0);
            double totalLoss = 0.0;
            double totalWeight = 0.0;
            // computeGradient accumulates into this wrapper around feedbackArray, so the emitted
            // array carries the mini-batch gradient in its first coefficientDim entries.
            DenseVector cumGradientsWrapper = new DenseVector(feedbackArray);
            for (LabeledPointWithWeight dataPoint : miniBatchData) {
                totalLoss += lossFunc.computeLoss(dataPoint, coefficient);
                lossFunc.computeGradient(dataPoint, coefficient, cumGradientsWrapper);
                totalWeight += dataPoint.getWeight();
            }
            setTotalLoss(totalLoss);
            setTotalWeight(totalWeight);
            collector.collect(feedbackArray);
        }

        @Override
        public void onIterationTerminated(
                IterationListener.Context context, Collector<double[]> collector) {
            trainDataState.clear();
            if (getRuntimeContext().getIndexOfThisSubtask() == 0) {
                // Apply the last round's reduced feedback, then emit the final model once.
                // NOTE(review): this assumes all subtasks hold identical coefficients (same
                // broadcast initial model and same all-reduced updates), so subtask 0's copy
                // is sufficient.
                updateModel();
                context.output(modelDataOutputTag, coefficient);
            }
        }

        @Override
        public void processElement1(StreamRecord<LabeledPointWithWeight> streamRecord)
                throws Exception {
            trainDataState.add(streamRecord.getValue());
        }

        @Override
        public void processElement2(StreamRecord<double[]> streamRecord) {
            feedbackArray = streamRecord.getValue();
        }

        @Override
        public void initializeState(StateInitializationContext context) throws Exception {
            super.initializeState(context);

            coefficientState =
                    context.getOperatorStateStore()
                            .getListState(
                                    new ListStateDescriptor<>(
                                            "coefficientState", DenseVectorTypeInfo.INSTANCE));
            OperatorStateUtils.getUniqueElement(coefficientState, "coefficientState")
                    .ifPresent(x -> coefficient = x);
            if (coefficient != null) {
                coefficientDim = coefficient.size();
            }

            feedbackArrayState =
                    context.getOperatorStateStore()
                            .getListState(
                                    new ListStateDescriptor<>(
                                            "feedbackArrayState",
                                            PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO));
            OperatorStateUtils.getUniqueElement(feedbackArrayState, "feedbackArrayState")
                    .ifPresent(x -> feedbackArray = x);

            trainDataState =
                    context.getOperatorStateStore()
                            .getListState(
                                    new ListStateDescriptor<>(
                                            "trainDataState",
                                            TypeInformation.of(LabeledPointWithWeight.class)));

            nextBatchOffsetState =
                    context.getOperatorStateStore()
                            .getListState(
                                    new ListStateDescriptor<>(
                                            "nextBatchOffsetState", BasicTypeInfo.INT_TYPE_INFO));
            nextBatchOffset =
                    OperatorStateUtils.getUniqueElement(nextBatchOffsetState, "nextBatchOffsetState")
                            .orElse(0);
        }

        @Override
        public void snapshotState(StateSnapshotContext context) throws Exception {
            super.snapshotState(context);

            coefficientState.clear();
            if (coefficient != null) {
                coefficientState.add(coefficient);
            }

            feedbackArrayState.clear();
            if (feedbackArray != null) {
                feedbackArrayState.add(feedbackArray);
            }

            nextBatchOffsetState.clear();
            nextBatchOffsetState.add(nextBatchOffset);
        }
    }

    /**
     * Iteration body that wires together the training operator, the per-round all-reduce and the
     * termination check.
     */
    private static class TrainIterationBody implements IterationBody {
        private final LossFunc lossFunc;
        private final SGDParams params;

        public TrainIterationBody(LossFunc lossFunc, SGDParams params) {
            this.lossFunc = lossFunc;
            this.params = params;
        }

        @Override
        public IterationBodyResult process(
                DataStreamList variableStreams, DataStreamList dataStreams) {
            DataStream<double[]> variableStream = variableStreams.get(0);
            DataStream<LabeledPointWithWeight> trainData = dataStreams.get(0);
            OutputTag<DenseVector> modelDataOutputTag =
                    new OutputTag<DenseVector>("MODEL_OUTPUT") {};

            SingleOutputStreamOperator<double[]> modelUpdateAndWeightAndLoss =
                    trainData
                            .connect(variableStream)
                            .transform(
                                    "CacheDataAndDoTrain",
                                    PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO,
                                    new CacheDataAndDoTrain(lossFunc, params, modelDataOutputTag));

            // Sum the per-subtask [gradient, totalWeight, totalLoss] arrays in every round.
            DataStreamList feedbackVariableStream =
                    IterationBody.forEachRound(
                            DataStreamList.of(modelUpdateAndWeightAndLoss),
                            input -> {
                                DataStream<double[]> feedback =
                                        DataStreamUtils.allReduceSum(input.get(0));
                                return DataStreamList.of(feedback);
                            });
            DataStream<double[]> reducedFeedback = feedbackVariableStream.get(0);

            // Termination is driven by the weighted average loss, totalLoss / totalWeight.
            DataStream<Integer> terminationCriteria =
                    reducedFeedback
                            .map(value -> value[value.length - 1] / value[value.length - 2])
                            .flatMap(new TerminateOnMaxIterOrTol(params.maxIter, params.tol));

            return new IterationBodyResult(
                    DataStreamList.of(reducedFeedback),
                    DataStreamList.of(
                            modelUpdateAndWeightAndLoss.getSideOutput(modelDataOutputTag)),
                    terminationCriteria);
        }
    }
}

