/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.controlprogram.paramserv;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.concurrent.Callable;
import java.util.concurrent.Future;
import java.util.stream.Collectors;
import org.apache.commons.lang.NotImplementedException;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.parser.DataIdentifier;
import org.apache.sysds.parser.Statement;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.BasicProgramBlock;
import org.apache.sysds.runtime.controlprogram.FunctionProgramBlock;
import org.apache.sysds.runtime.controlprogram.ProgramBlock;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.controlprogram.paramserv.PSWorker;
import org.apache.sysds.runtime.controlprogram.paramserv.ParamServer;
import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils;
import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysds.runtime.functionobjects.Multiply;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction;
import org.apache.sysds.runtime.instructions.cp.IntObject;
import org.apache.sysds.runtime.instructions.cp.ListObject;
import org.apache.sysds.runtime.instructions.cp.StringObject;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.RightScalarOperator;
import org.apache.sysds.runtime.util.ProgramConverter;
import org.apache.sysds.utils.Statistics;

/**
 * Parameter-server control thread for a single federated worker.
 * <p>
 * The thread runs on the coordinator and drives training in four steps:
 * <ol>
 * <li>{@link #setup(double)} serializes the gradient program (plus the aggregation program for
 * epoch-wise updates). It ships the program and the batching configuration to the federated
 * site that holds the features.</li>
 * <li>Training then repeats: pull the global model from the {@link ParamServer}, put it on the
 * site, and trigger remote gradient computation over site-local batches.</li>
 * <li>The returned gradients are optionally scaled by the weighting factor and pushed back to
 * the parameter server.</li>
 * <li>{@link #teardown()} removes the configuration variables from the site again.</li>
 * </ol>
 */
public class FederatedPSControlThread
extends PSWorker
implements Callable<Void> {
    private static final long serialVersionUID = 6846648059569648791L;
    // NOTE(review): the logger is registered under ParamServer's name rather than this class's.
    protected static final Log LOG = LogFactory.getLog(ParamServer.class.getName());

    // Keys of the configuration variables kept in the federated site's execution context
    // between UDF calls. SetupFederatedWorker writes them, federatedComputeGradientsForNBatches
    // reads them and TeardownFederatedWorker removes them, so all three must use identical strings.
    private static final String FED_BATCH_SIZE = "1701-NCC-batch_size";
    private static final String FED_DATA_SIZE = "1701-NCC-data_size";
    private static final String FED_POSS_BATCHES_LOCAL = "1701-NCC-poss_batches_local";
    private static final String FED_NAMESPACE = "1701-NCC-namespace";
    private static final String FED_GRADIENTS_FNAME = "1701-NCC-gradients_fname";
    private static final String FED_AGGREGATION_FNAME = "1701-NCC-aggregation_fname";
    private static final String FED_MODEL_VARID = "1701-NCC-model_varid";
    // Names under which inputs are placed in the execution context before the gradient /
    // aggregation function calls. Presumably these match the user functions' parameter names.
    private static final String HYPERPARAMS = "hyperparams";
    private static final String MODEL = "model";
    private static final String FEATURES = "features";
    private static final String LABELS = "labels";

    private FederatedData _featuresData;
    private FederatedData _labelsData;
    /** Variable ID under which the current model is stored on the federated site. */
    private final long _modelVarID;
    private final Statement.PSRuntimeBalancing _runtimeBalancing;
    /** Number of batches this worker runs per epoch (may be overridden in setup). */
    private int _numBatchesPerEpoch;
    /** Number of distinct batches the local data can be split into. */
    private int _possibleBatchesPerLocalEpoch;
    private final boolean _weighting;
    private double _weightingFactor = 1.0;
    /** If true, every epoch starts at local batch 0 instead of continuing the batch cycle. */
    private boolean _cycleStartAt0 = false;

    /**
     * Creates the control thread; call {@link #setup(double)} before {@link #call()}.
     *
     * @param workerID                 ID of this worker at the parameter server
     * @param updFunc                  name of the gradient (update) function
     * @param freq                     update frequency (BATCH or EPOCH are supported)
     * @param runtimeBalancing         runtime balancing strategy
     * @param weighting                whether gradients are scaled by a weighting factor before pushing
     * @param epochs                   number of epochs to run
     * @param batchSize                requested batch size (may be rescaled for SCALE_BATCH)
     * @param numBatchesPerGlobalEpoch number of batches per epoch agreed globally
     * @param ec                       coordinator execution context
     * @param ps                       parameter server to pull from and push to
     */
    public FederatedPSControlThread(int workerID, String updFunc, Statement.PSFrequency freq, Statement.PSRuntimeBalancing runtimeBalancing, boolean weighting, int epochs, long batchSize, int numBatchesPerGlobalEpoch, ExecutionContext ec, ParamServer ps) {
        super(workerID, updFunc, freq, epochs, batchSize, ec, ps);
        _numBatchesPerEpoch = numBatchesPerGlobalEpoch;
        _runtimeBalancing = runtimeBalancing;
        _weighting = weighting;
        _modelVarID = FederationUtils.getNextFedDataID();
    }

    /**
     * Derives the batching configuration from the local data size and the balancing strategy.
     * Then sends the serialized program and configuration to the federated site.
     *
     * @param weightingFactor factor applied to the gradients before pushing (only if weighting is enabled)
     * @throws DMLRuntimeException if the setup UDF cannot be executed or reports failure
     */
    public void setup(double weightingFactor) {
        incWorkerNumber();
        _featuresData = _features.getFedMapping().getFederatedData()[0];
        _labelsData = _labels.getFedMapping().getFederatedData()[0];
        _weightingFactor = weightingFactor;

        long dataSize = _features.getNumRows();
        // SCALE_BATCH: pick the batch size so the local data yields the globally agreed batch count
        if (_runtimeBalancing == Statement.PSRuntimeBalancing.SCALE_BATCH) {
            _batchSize = (int) Math.ceil((double) dataSize / (double) _numBatchesPerEpoch);
        }
        _possibleBatchesPerLocalEpoch = (int) Math.ceil((double) dataSize / (double) _batchSize);
        // NONE: simply run over all local batches once per epoch
        if (_runtimeBalancing == Statement.PSRuntimeBalancing.NONE) {
            _numBatchesPerEpoch = _possibleBatchesPerLocalEpoch;
        }
        // BASELINE: restart at the first local batch in every epoch
        if (_runtimeBalancing == Statement.PSRuntimeBalancing.BASELINE) {
            _cycleStartAt0 = true;
        }

        if (LOG.isInfoEnabled()) {
            LOG.info("Setup config for worker " + getWorkerName());
            LOG.info("Batch size: " + _batchSize + " possible batches: " + _possibleBatchesPerLocalEpoch
                + " batches to run: " + _numBatchesPerEpoch + " weighting factor: " + _weightingFactor);
        }

        // Program shipped to the site: the gradient call, plus the aggregation call for epoch updates
        ArrayList<ProgramBlock> pbs = new ArrayList<ProgramBlock>();
        BasicProgramBlock gradientProgramBlock = new BasicProgramBlock(_ec.getProgram());
        gradientProgramBlock.setInstructions(new ArrayList<Instruction>(Collections.singletonList(_inst)));
        pbs.add(gradientProgramBlock);
        if (_freq == Statement.PSFrequency.EPOCH) {
            BasicProgramBlock aggProgramBlock = new BasicProgramBlock(_ec.getProgram());
            aggProgramBlock.setInstructions(new ArrayList<Instruction>(Collections.singletonList(_ps.getAggInst())));
            pbs.add(aggProgramBlock);
        }
        String programSerialized = InstructionUtils.concatStrings(" PROG\u23a8", "\n",
            ProgramConverter.serializeProgram(_ec.getProgram(), pbs, new HashMap<String, byte[]>()), "\u23ac");

        SetupFederatedWorker setupUDF = new SetupFederatedWorker(_batchSize, dataSize, _possibleBatchesPerLocalEpoch,
            programSerialized, _inst.getNamespace(), _inst.getFunctionName(), _ps.getAggInst().getFunctionName(),
            _ec.getListObject(HYPERPARAMS), _modelVarID);
        executeUDFAndCheck(setupUDF, "Setup UDF");
    }

    /**
     * Removes the configuration variables installed by {@link #setup(double)} from the federated site.
     *
     * @throws DMLRuntimeException if the teardown UDF cannot be executed or reports failure
     */
    public void teardown() {
        executeUDFAndCheck(new TeardownFederatedWorker(), "Teardown UDF");
    }

    /**
     * Runs training with the configured update frequency, then tears down the remote state.
     *
     * @return always {@code null}
     * @throws DMLRuntimeException wrapping any failure during training
     */
    @Override
    public Void call() throws Exception {
        try {
            switch (_freq) {
                case BATCH:
                    computeWithBatchUpdates();
                    break;
                case EPOCH:
                    computeWithEpochUpdates();
                    break;
                default:
                    throw new DMLRuntimeException(String.format("%s not support update frequency %s", getWorkerName(), _freq));
            }
        }
        catch (Exception e) {
            throw new DMLRuntimeException(String.format("%s failed", getWorkerName()), e);
        }
        // NOTE(review): teardown only runs after successful training; on failure the remote
        // configuration variables stay on the federated site.
        teardown();
        return null;
    }

    /** Pulls the current global model from the parameter server. */
    protected ListObject pullModel() {
        return _ps.pull(_workerID);
    }

    /**
     * Scales every gradient matrix in place by the weighting factor (if weighting is enabled and the
     * factor is not 1), then pushes the gradients to the parameter server.
     *
     * @param gradients list of gradient matrices; modified in place when weighting applies
     */
    protected void weighAndPushGradients(ListObject gradients) {
        if (_weighting && _weightingFactor != 1.0) {
            Timing tWeighting = DMLScript.STATISTICS ? new Timing(true) : null;
            gradients.getData().parallelStream().forEach(matrix -> {
                MatrixObject matrixObject = (MatrixObject) matrix;
                MatrixBlock input = matrixObject.acquireReadAndRelease().scalarOperations(
                    new RightScalarOperator(Multiply.getMultiplyFnObject(), _weightingFactor), new MatrixBlock());
                matrixObject.acquireModify(input);
                matrixObject.release();
            });
            accFedPSGradientWeightingTime(tWeighting);
        }
        _ps.push(_workerID, gradients);
    }

    /**
     * Wraps a running batch counter into the range of available local batches.
     *
     * @param currentLocalBatchNumber      running (unbounded) batch counter
     * @param possibleBatchesPerLocalEpoch number of distinct local batches
     * @return local batch index in {@code [0, possibleBatchesPerLocalEpoch)} for non-negative input
     */
    protected static int getNextLocalBatchNum(int currentLocalBatchNumber, int possibleBatchesPerLocalEpoch) {
        return currentLocalBatchNumber % possibleBatchesPerLocalEpoch;
    }

    /**
     * Returns the local batch index an epoch starts at. Epochs continue the batch cycle where the
     * previous one stopped, unless {@code _cycleStartAt0} is set.
     * The product is computed in {@code long} so that large epoch/batch counts cannot overflow
     * into a negative index.
     */
    private int getEpochStartBatchNum(int epochCounter) {
        if (_cycleStartAt0) {
            return 0;
        }
        return (int) (((long) _numBatchesPerEpoch * epochCounter) % _possibleBatchesPerLocalEpoch);
    }

    /** Pulls the model, computes gradients of a single batch remotely and pushes them, per batch. */
    protected void computeWithBatchUpdates() {
        for (int epochCounter = 0; epochCounter < _epochs; ++epochCounter) {
            int currentLocalBatchNumber = getEpochStartBatchNum(epochCounter);
            for (int batchCounter = 0; batchCounter < _numBatchesPerEpoch; ++batchCounter) {
                int localStartBatchNum = getNextLocalBatchNum(currentLocalBatchNumber++, _possibleBatchesPerLocalEpoch);
                ListObject model = pullModel();
                ListObject gradients = computeGradientsForNBatches(model, 1, localStartBatchNum);
                weighAndPushGradients(gradients);
                ParamservUtils.cleanupListObject(model);
                ParamservUtils.cleanupListObject(gradients);
            }
        }
    }

    protected void computeWithNBatchUpdates() {
        throw new NotImplementedException();
    }

    /**
     * Pulls the model once per epoch and computes the epoch's batches remotely with local model
     * updates between batches, then pushes the accrued gradients.
     */
    protected void computeWithEpochUpdates() {
        for (int epochCounter = 0; epochCounter < _epochs; ++epochCounter) {
            int localStartBatchNum = getEpochStartBatchNum(epochCounter);
            ListObject model = pullModel();
            ListObject gradients = computeGradientsForNBatches(model, _numBatchesPerEpoch, localStartBatchNum, true);
            weighAndPushGradients(gradients);
            ParamservUtils.cleanupListObject(model);
            ParamservUtils.cleanupListObject(gradients);
        }
    }

    /** Same as {@link #computeGradientsForNBatches(ListObject, int, int, boolean)} without local updates. */
    protected ListObject computeGradientsForNBatches(ListObject model, int numBatchesToCompute, int localStartBatchNum) {
        return computeGradientsForNBatches(model, numBatchesToCompute, localStartBatchNum, false);
    }

    /**
     * Puts {@code model} on the federated site and computes the accrued gradients of
     * {@code numBatchesToCompute} consecutive local batches there.
     *
     * @param model               model to compute gradients for
     * @param numBatchesToCompute number of batches to process
     * @param localStartBatchNum  local index of the first batch
     * @param localUpdate         whether the site updates its model copy between batches
     * @return the accrued gradients returned by the site
     * @throws DMLRuntimeException if the put or the gradient UDF fails
     */
    protected ListObject computeGradientsForNBatches(ListObject model, int numBatchesToCompute, int localStartBatchNum, boolean localUpdate) {
        Timing tFedCommunication = DMLScript.STATISTICS ? new Timing(true) : null;
        Future<FederatedResponse> putParamsResponse = _featuresData.executeFederatedOperation(
            new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, _modelVarID, model));
        awaitSuccess(putParamsResponse, "put");

        federatedComputeGradientsForNBatches udf = new federatedComputeGradientsForNBatches(
            new long[]{_featuresData.getVarID(), _labelsData.getVarID(), _modelVarID},
            numBatchesToCompute, localUpdate, localStartBatchNum);
        Future<FederatedResponse> udfResponse = _featuresData.executeFederatedOperation(
            new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, _featuresData.getVarID(), udf));
        try {
            // responseData = {accrued gradients, worker compute time}
            Object[] responseData = udfResponse.get().getData();
            if (DMLScript.STATISTICS) {
                long total = (long) tFedCommunication.stop();
                long workerComputing = ((DoubleObject) responseData[1]).getLongValue();
                Statistics.accFedPSWorkerComputing(workerComputing);
                Statistics.accFedPSCommunicationTime(total - workerComputing);
            }
            return (ListObject) responseData[0];
        }
        catch (Exception e) {
            if (DMLScript.STATISTICS) {
                tFedCommunication.stop();
            }
            if (e instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            throw new DMLRuntimeException("FederatedLocalPSThread: failed to execute UDF: " + e.getMessage(), e);
        }
    }

    /** Sends {@code udf} to the features' federated site and waits for a successful response. */
    private void executeUDFAndCheck(FederatedUDF udf, String description) {
        Future<FederatedResponse> response = _featuresData.executeFederatedOperation(
            new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, _featuresData.getVarID(), udf));
        awaitSuccess(response, description);
    }

    /**
     * Waits for {@code future} and fails if the request could not be executed or was unsuccessful.
     * The original cause is preserved, and the interrupt status is restored if waiting was interrupted.
     */
    private static void awaitSuccess(Future<FederatedResponse> future, String description) {
        FederatedResponse response;
        try {
            response = future.get();
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new DMLRuntimeException("FederatedLocalPSThread: failed to execute " + description + ": " + e.getMessage(), e);
        }
        catch (Exception e) {
            throw new DMLRuntimeException("FederatedLocalPSThread: failed to execute " + description + ": " + e.getMessage(), e);
        }
        if (!response.isSuccessful()) {
            throw new DMLRuntimeException("FederatedLocalPSThread: " + description + " failed");
        }
    }

    protected void accFedPSGradientWeightingTime(Timing time) {
        if (DMLScript.STATISTICS && time != null) {
            Statistics.accFedPSGradientWeightingTime((long) time.stop());
        }
    }

    @Override
    public String getWorkerName() {
        return String.format("Federated worker_%d", _workerID);
    }

    @Override
    protected void incWorkerNumber() {
        if (DMLScript.STATISTICS) {
            Statistics.incWorkerNumber();
        }
    }

    @Override
    protected void accLocalModelUpdateTime(Timing time) {
        throw new NotImplementedException();
    }

    @Override
    protected void accBatchIndexingTime(Timing time) {
        throw new NotImplementedException();
    }

    @Override
    protected void accGradientComputeTime(Timing time) {
        throw new NotImplementedException();
    }

    /**
     * UDF executed on the federated site. It computes and accrues the gradients of a run of local
     * batches, optionally updating the site's model copy between batches.
     * Input IDs: features, labels, model. Response data: {accrued gradients, compute time}.
     */
    private static class federatedComputeGradientsForNBatches
    extends FederatedUDF {
        private static final long serialVersionUID = -3075901536748794832L;
        int _numBatchesToCompute;
        boolean _localUpdate;
        int _localStartBatchNum;

        protected federatedComputeGradientsForNBatches(long[] inIDs, int numBatchesToCompute, boolean localUpdate, int localStartBatchNum) {
            super(inIDs);
            _numBatchesToCompute = numBatchesToCompute;
            _localUpdate = localUpdate;
            _localStartBatchNum = localStartBatchNum;
        }

        @Override
        public FederatedResponse execute(ExecutionContext ec, Data... data) {
            Timing tGradients = new Timing(true);
            MatrixObject features = (MatrixObject) data[0];
            MatrixObject labels = (MatrixObject) data[1];
            ListObject model = (ListObject) data[2];

            // configuration installed by SetupFederatedWorker
            long batchSize = ((IntObject) ec.getVariable(FED_BATCH_SIZE)).getLongValue();
            long dataSize = ((IntObject) ec.getVariable(FED_DATA_SIZE)).getLongValue();
            int possibleBatchesPerLocalEpoch = (int) ((IntObject) ec.getVariable(FED_POSS_BATCHES_LOCAL)).getLongValue();
            String namespace = ((StringObject) ec.getVariable(FED_NAMESPACE)).getStringValue();
            String gradientsFunc = ((StringObject) ec.getVariable(FED_GRADIENTS_FNAME)).getStringValue();
            String aggFunc = ((StringObject) ec.getVariable(FED_AGGREGATION_FNAME)).getStringValue();

            boolean opt = !ec.getProgram().containsFunctionProgramBlock(namespace, gradientsFunc, false);
            FunctionProgramBlock gradientsBlock = ec.getProgram().getFunctionProgramBlock(namespace, gradientsFunc, opt);
            FunctionCallCPInstruction gradientsInstruction = createFunctionCall(gradientsBlock, namespace, gradientsFunc, opt, "gradient function");
            DataIdentifier gradientsOutput = gradientsBlock.getOutputParams().get(0);

            // the aggregation function is only needed if the model is updated between batches
            FunctionCallCPInstruction aggregationInstruction = null;
            DataIdentifier aggregationOutput = null;
            if (_localUpdate && _numBatchesToCompute > 1) {
                FunctionProgramBlock aggBlock = ec.getProgram().getFunctionProgramBlock(namespace, aggFunc, opt);
                aggregationInstruction = createFunctionCall(aggBlock, namespace, aggFunc, opt, "aggregation function");
                aggregationOutput = aggBlock.getOutputParams().get(0);
            }

            ListObject accGradients = null;
            int currentLocalBatchNumber = _localStartBatchNum;
            ec.setVariable(MODEL, model);
            for (int batchCounter = 0; batchCounter < _numBatchesToCompute; ++batchCounter) {
                int localBatchNum = FederatedPSControlThread.getNextLocalBatchNum(currentLocalBatchNumber++, possibleBatchesPerLocalEpoch);
                // rows begin..end of this batch (1-based); the last batch may be shorter
                long begin = (long) localBatchNum * batchSize + 1L;
                long end = Math.min((long) (localBatchNum + 1) * batchSize, dataSize);
                MatrixObject bFeatures = ParamservUtils.sliceMatrix(features, begin, end);
                MatrixObject bLabels = ParamservUtils.sliceMatrix(labels, begin, end);
                ec.setVariable(FEATURES, bFeatures);
                ec.setVariable(LABELS, bLabels);

                gradientsInstruction.processInstruction(ec);
                ListObject gradients = ec.getListObject(gradientsOutput.getName());
                accGradients = ParamservUtils.accrueGradients(accGradients, gradients, false);

                // update the local model for the next batch (not after the last one)
                if (_localUpdate && batchCounter < _numBatchesToCompute - 1) {
                    assert aggregationInstruction != null;
                    aggregationInstruction.processInstruction(ec);
                    model = ec.getListObject(aggregationOutput.getName());
                    ec.setVariable(MODEL, model);
                    ParamservUtils.cleanupListObject(ec, aggregationOutput.getName());
                }
                ParamservUtils.cleanupListObject(ec, gradientsOutput.getName());
                ParamservUtils.cleanupData(ec, FEATURES);
                ParamservUtils.cleanupData(ec, LABELS);
            }
            ParamservUtils.cleanupListObject(ec, ec.getVariable(FED_MODEL_VARID).toString());
            ParamservUtils.cleanupListObject(ec, MODEL);

            DoubleObject gradientsTime = new DoubleObject(tGradients.stop());
            return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, new Object[]{accGradients, gradientsTime});
        }

        /**
         * Builds a call instruction for {@code func}. Each input is bound to the variable of the same
         * name, and the outputs are written to variables named like the function's output parameters.
         */
        private static FunctionCallCPInstruction createFunctionCall(FunctionProgramBlock func, String namespace, String fname, boolean opt, String istr) {
            CPOperand[] boundInputs = func.getInputParams().stream()
                .map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType()))
                .toArray(CPOperand[]::new);
            ArrayList<String> outputNames = func.getOutputParams().stream()
                .map(DataIdentifier::getName)
                .collect(Collectors.toCollection(ArrayList::new));
            return new FunctionCallCPInstruction(namespace, fname, opt, boundInputs, func.getInputParamNames(), outputNames, istr);
        }

        @Override
        public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
            return null;
        }
    }

    /** UDF executed on the federated site that removes the variables installed by SetupFederatedWorker. */
    private static class TeardownFederatedWorker
    extends FederatedUDF {
        private static final long serialVersionUID = -153650281873318969L;

        protected TeardownFederatedWorker() {
            super(new long[0]);
        }

        @Override
        public FederatedResponse execute(ExecutionContext ec, Data... data) {
            ec.removeVariable(FED_BATCH_SIZE);
            ec.removeVariable(FED_DATA_SIZE);
            ec.removeVariable(FED_POSS_BATCHES_LOCAL);
            ec.removeVariable(FED_NAMESPACE);
            ec.removeVariable(FED_GRADIENTS_FNAME);
            ec.removeVariable(FED_AGGREGATION_FNAME);
            ec.removeVariable(FED_MODEL_VARID);
            ParamservUtils.cleanupListObject(ec, HYPERPARAMS);
            return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
        }

        @Override
        public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
            return null;
        }
    }

    /**
     * UDF executed on the federated site. It installs the serialized program and stores the
     * batching configuration, function names, hyper-parameters and model variable ID in the
     * site's execution context.
     */
    private static class SetupFederatedWorker
    extends FederatedUDF {
        private static final long serialVersionUID = -3148991224792675607L;
        private final long _batchSize;
        private final long _dataSize;
        private final int _possibleBatchesPerLocalEpoch;
        private final String _programString;
        private final String _namespace;
        private final String _gradientsFunctionName;
        private final String _aggregationFunctionName;
        private final ListObject _hyperParams;
        private final long _modelVarID;

        protected SetupFederatedWorker(long batchSize, long dataSize, int possibleBatchesPerLocalEpoch, String programString, String namespace, String gradientsFunctionName, String aggregationFunctionName, ListObject hyperParams, long modelVarID) {
            super(new long[0]);
            _batchSize = batchSize;
            _dataSize = dataSize;
            _possibleBatchesPerLocalEpoch = possibleBatchesPerLocalEpoch;
            _programString = programString;
            _namespace = namespace;
            _gradientsFunctionName = gradientsFunctionName;
            _aggregationFunctionName = aggregationFunctionName;
            _hyperParams = hyperParams;
            _modelVarID = modelVarID;
        }

        @Override
        public FederatedResponse execute(ExecutionContext ec, Data... data) {
            ec.setProgram(ProgramConverter.parseProgram(_programString, 0));
            ec.setVariable(FED_BATCH_SIZE, new IntObject(_batchSize));
            ec.setVariable(FED_DATA_SIZE, new IntObject(_dataSize));
            ec.setVariable(FED_POSS_BATCHES_LOCAL, new IntObject(_possibleBatchesPerLocalEpoch));
            ec.setVariable(FED_NAMESPACE, new StringObject(_namespace));
            ec.setVariable(FED_GRADIENTS_FNAME, new StringObject(_gradientsFunctionName));
            ec.setVariable(FED_AGGREGATION_FNAME, new StringObject(_aggregationFunctionName));
            ec.setVariable(HYPERPARAMS, _hyperParams);
            ec.setVariable(FED_MODEL_VARID, new IntObject(_modelVarID));
            return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
        }

        @Override
        public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
            return null;
        }
    }
}

