/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress.colgroup;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Arrays;
import org.apache.commons.lang3.NotImplementedException;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.AColGroupCompressed;
import org.apache.sysds.runtime.compress.colgroup.AMorphingMMColGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupConst;
import org.apache.sysds.runtime.compress.colgroup.ColGroupDDC;
import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty;
import org.apache.sysds.runtime.compress.colgroup.ColGroupIO;
import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils;
import org.apache.sysds.runtime.compress.colgroup.IFrameOfReferenceGroup;
import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory;
import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary;
import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData;
import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory;
import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme;
import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator;
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup;
import org.apache.sysds.runtime.compress.estim.EstimationFactors;
import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory;
import org.apache.sysds.runtime.compress.estim.encoding.IEncode;
import org.apache.sysds.runtime.compress.utils.Util;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.Divide;
import org.apache.sysds.runtime.functionobjects.Minus;
import org.apache.sysds.runtime.functionobjects.Multiply;
import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.CMOperator;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
import org.apache.sysds.runtime.matrix.operators.UnaryOperator;

/**
 * Column group combining dense dictionary coding (DDC) with a frame-of-reference (FOR) offset.
 * <p>
 * Each row is mapped by {@code _data} to a dictionary entry, and the value at row {@code r} and
 * column position {@code c} of this group is {@code dict(_data[r], c) + _reference[c]} (see
 * {@link #getIdx(int, int)}). Because the offset is stored outside the dictionary, shift-like
 * operations can be applied by editing only the reference.
 */
public class ColGroupDDCFOR extends AMorphingMMColGroup implements IFrameOfReferenceGroup {
    private static final long serialVersionUID = -5769772089913918987L;

    /** Mapping from each row to the index of its dictionary entry. */
    protected final AMapToData _data;

    /** Per-column offset added to every dictionary value; one entry per column of this group. */
    protected final double[] _reference;

    /**
     * Direct constructor that performs no simplification; prefer {@code create(...)}, which may
     * return a cheaper group type.
     */
    private ColGroupDDCFOR(IColIndex colIndexes, IDictionary dict, double[] reference, AMapToData data, int[] cachedCounts) {
        super(colIndexes, dict, cachedCounts);
        _data = data;
        _reference = reference;
    }

    /**
     * Creates the most compact column group representing {@code dict(data[r]) + reference}.
     * <ul>
     * <li>no dictionary and an all-zero reference: an empty group;</li>
     * <li>no dictionary: a constant group holding {@code reference};</li>
     * <li>a mapping with a single unique value: a constant group holding dictionary + reference;</li>
     * <li>an all-zero reference: a plain DDC group;</li>
     * <li>otherwise: a new {@code ColGroupDDCFOR}.</li>
     * </ul>
     *
     * @param colIndexes   column indexes of the group
     * @param dict         dictionary of (reference-relative) values, may be null
     * @param data         row-to-dictionary-entry mapping
     * @param cachedCounts optional cached per-entry counts, may be null
     * @param reference    per-column offset
     * @return the created column group
     */
    public static AColGroup create(IColIndex colIndexes, IDictionary dict, AMapToData data, int[] cachedCounts, double[] reference) {
        final boolean allZero = ColGroupUtils.allZero(reference);
        if (dict == null) {
            return allZero ? new ColGroupEmpty(colIndexes) : ColGroupConst.create(colIndexes, reference);
        }
        if (data.getUnique() == 1) {
            // Only one dictionary entry is used: fold it into the reference and make a constant.
            return ColGroupConst.create(colIndexes, dict.binOpRight(new BinaryOperator(Plus.getPlusFnObject()), reference));
        }
        if (allZero) {
            // No offset: a plain DDC group represents the same values more cheaply.
            return ColGroupDDC.create(colIndexes, dict, data, cachedCounts);
        }
        return new ColGroupDDCFOR(colIndexes, dict, reference, data, cachedCounts);
    }

    /**
     * Converts a DDC group into a DDCFOR group.
     * <p>
     * The reference comes from {@code ColGroupUtils.extractMostCommonValueInColumns} on the
     * dictionary. Presumably this is the per-column most frequent dictionary value; TODO confirm.
     * That reference is subtracted from the dictionary.
     *
     * @param g the DDC group to convert
     * @return the converted group, or {@code g} itself when no reference is extracted
     * @throws NotImplementedException if the dictionary cannot be turned into a matrix block dictionary
     */
    public static AColGroup sparsifyFOR(ColGroupDDC g) {
        final int nCol = g.getColIndices().size();
        final MatrixBlockDictionary mbd = g._dict.getMBDict(nCol);
        if (mbd == null) {
            throw new NotImplementedException("The dictionary was empty... highly unlikely");
        }
        final MatrixBlock mb = mbd.getMatrixBlock();
        final double[] ref = ColGroupUtils.extractMostCommonValueInColumns(mb);
        if (ref == null) {
            return g;
        }
        final MatrixBlockDictionary mDict = mbd.binOpRight(new BinaryOperator(Minus.getMinusFnObject()), ref);
        return ColGroupDDCFOR.create(g.getColIndices(), mDict, g._data, g.getCachedCounts(), ref);
    }

    @Override
    public AColGroup.CompressionType getCompType() {
        return AColGroup.CompressionType.DDCFOR;
    }

    /** Returns {@code dict(_data[r], colIdx) + _reference[colIdx]}. */
    @Override
    public double getIdx(int r, int colIdx) {
        return _dict.getValue(_data.getIndex(r), colIdx, _colIndexes.size()) + _reference[colIdx];
    }

    /** Adds, for each row in [rl, ru), the pre-aggregate of that row's dictionary entry. */
    @Override
    protected void computeRowSums(double[] c, int rl, int ru, double[] preAgg) {
        for (int rix = rl; rix < ru; rix++) {
            c[rix] += preAgg[_data.getIndex(rix)];
        }
    }

    /** Combines, for each row in [rl, ru), c[row] with that row's pre-aggregate using {@code builtin}. */
    @Override
    protected void computeRowMxx(double[] c, Builtin builtin, int rl, int ru, double[] preAgg) {
        for (int rix = rl; rix < ru; rix++) {
            c[rix] = builtin.execute(c[rix], preAgg[_data.getIndex(rix)]);
        }
    }

    @Override
    public int[] getCounts(int[] counts) {
        return _data.getCounts(counts);
    }

    @Override
    public AColGroup.ColGroupType getColGroupType() {
        return AColGroup.ColGroupType.DDCFOR;
    }

    /** Base estimate plus the mapping plus 8 bytes per reference value (array header not counted). */
    @Override
    public long estimateInMemorySize() {
        // 8L: widen before multiplying so the product is computed in long arithmetic.
        return super.estimateInMemorySize() + _data.getInMemorySize() + 8L * _colIndexes.size();
    }

    /**
     * Applies a scalar operation to every cell.
     * <p>
     * The reference is always transformed with {@code op}. For Plus/Minus the dictionary is reused
     * unchanged. For Multiply/Divide the dictionary is transformed independently. Any other
     * operator is applied through the reference-aware dictionary path.
     * <p>
     * NOTE(review): the Plus/Minus and Multiply/Divide shortcuts are only valid when the cell is the
     * left operand ({@code x - s}, {@code x / s}). If a scalar operator with the constant on the
     * left ({@code s - x}, {@code s / x}) can reach this method, those cases need the
     * reference-aware path instead. Confirm against callers.
     */
    @Override
    public AColGroup scalarOperation(ScalarOperator op) {
        final double[] newRef = new double[_reference.length];
        for (int i = 0; i < _reference.length; i++) {
            newRef[i] = op.executeScalar(_reference[i]);
        }
        if (op.fn instanceof Plus || op.fn instanceof Minus) {
            // (d + r) +- s == d + (r +- s): only the reference changes.
            return ColGroupDDCFOR.create(_colIndexes, _dict, _data, getCachedCounts(), newRef);
        }
        if (op.fn instanceof Multiply || op.fn instanceof Divide) {
            // (d + r) */ s == d */ s + r */ s: dictionary and reference transform independently.
            final IDictionary newDict = _dict.applyScalarOp(op);
            return ColGroupDDCFOR.create(_colIndexes, newDict, _data, getCachedCounts(), newRef);
        }
        final IDictionary newDict = _dict.applyScalarOpWithReference(op, _reference, newRef);
        return ColGroupDDCFOR.create(_colIndexes, newDict, _data, getCachedCounts(), newRef);
    }

    /** Applies a unary operation using the reference-aware dictionary path. */
    @Override
    public AColGroup unaryOperation(UnaryOperator op) {
        final double[] newRef = ColGroupUtils.unaryOperator(op, _reference);
        final IDictionary newDict = _dict.applyUnaryOpWithReference(op, _reference, newRef);
        return ColGroupDDCFOR.create(_colIndexes, newDict, _data, getCachedCounts(), newRef);
    }

    /**
     * Computes {@code v op cell} for each cell, with the row vector {@code v} as the left operand.
     * <p>
     * Only Plus ({@code v + (d + r) == d + (v + r)}) can keep the dictionary as is. Only Multiply
     * ({@code v * (d + r) == v*d + v*r}) decomposes into independent dictionary and reference
     * updates. Minus ({@code v - (d + r) == -d + (v - r)}) would need a negated dictionary, and
     * Divide ({@code v / (d + r)}) does not decompose at all. Both therefore use the
     * reference-aware path, like all other operators.
     */
    @Override
    public AColGroup binaryRowOpLeft(BinaryOperator op, double[] v, boolean isRowSafe) {
        final double[] newRef = new double[_reference.length];
        for (int i = 0; i < _reference.length; i++) {
            newRef[i] = op.fn.execute(v[_colIndexes.get(i)], _reference[i]);
        }
        if (op.fn instanceof Plus) {
            return ColGroupDDCFOR.create(_colIndexes, _dict, _data, getCachedCounts(), newRef);
        }
        if (op.fn instanceof Multiply) {
            final IDictionary newDict = _dict.binOpLeft(op, v, _colIndexes);
            return ColGroupDDCFOR.create(_colIndexes, newDict, _data, getCachedCounts(), newRef);
        }
        final IDictionary newDict = _dict.binOpLeftWithReference(op, v, _colIndexes, _reference, newRef);
        return ColGroupDDCFOR.create(_colIndexes, newDict, _data, getCachedCounts(), newRef);
    }

    /**
     * Computes {@code cell op v} for each cell, with the row vector {@code v} as the right operand.
     * <p>
     * Plus and Minus only shift the reference. Multiply and Divide transform the dictionary and the
     * reference independently. Other operators use the reference-aware path.
     */
    @Override
    public AColGroup binaryRowOpRight(BinaryOperator op, double[] v, boolean isRowSafe) {
        final double[] newRef = new double[_reference.length];
        for (int i = 0; i < _reference.length; i++) {
            newRef[i] = op.fn.execute(_reference[i], v[_colIndexes.get(i)]);
        }
        if (op.fn instanceof Plus || op.fn instanceof Minus) {
            return ColGroupDDCFOR.create(_colIndexes, _dict, _data, getCachedCounts(), newRef);
        }
        if (op.fn instanceof Multiply || op.fn instanceof Divide) {
            final IDictionary newDict = _dict.binOpRight(op, v, _colIndexes);
            return ColGroupDDCFOR.create(_colIndexes, newDict, _data, getCachedCounts(), newRef);
        }
        final IDictionary newDict = _dict.binOpRightWithReference(op, v, _colIndexes, _reference, newRef);
        return ColGroupDDCFOR.create(_colIndexes, newDict, _data, getCachedCounts(), newRef);
    }

    /**
     * Serializes the group in three parts: the base group data written by {@code super.write}, then
     * the mapping, then one double per reference entry. {@link #read} reads the base part as column
     * indexes followed by a dictionary (assumed to match super.write; confirm).
     */
    @Override
    public void write(DataOutput out) throws IOException {
        super.write(out);
        _data.write(out);
        for (double d : _reference) {
            out.writeDouble(d);
        }
    }

    /**
     * Reads a group in the format produced by {@link #write}. Cached counts are not restored.
     *
     * @param in the input to read from
     * @return the deserialized group
     * @throws IOException if reading fails
     */
    public static ColGroupDDCFOR read(DataInput in) throws IOException {
        final IColIndex cols = ColIndexFactory.read(in);
        final IDictionary dict = DictionaryFactory.read(in);
        final AMapToData data = MapToFactory.readIn(in);
        final double[] ref = ColGroupIO.readDoubleArray(cols.size(), in);
        return new ColGroupDDCFOR(cols, dict, ref, data, null);
    }

    /** Size of {@link #write}'s output: base part + mapping + 8 bytes per reference double. */
    @Override
    public long getExactSizeOnDisk() {
        return super.getExactSizeOnDisk() + _data.getExactSizeOnDisk() + 8L * _colIndexes.size();
    }

    @Override
    public double getCost(ComputationCostEstimator e, int nRows) {
        final int nVals = getNumValues();
        final int nCols = getNumCols();
        return e.getCost(nRows, nRows, nCols, nVals, _dict.getSparsity());
    }

    /**
     * Replaces every cell equal to {@code pattern} with {@code replace}.
     * <p>
     * The reference entries are matched with the same {@code Util.eq} comparison in one pass. The
     * previous code detected matches with {@code !=} but substituted with {@code Util.eq}, so the
     * two checks could disagree (e.g. for NaN, if {@code Util.eq} treats NaN as equal; confirm).
     */
    @Override
    public AColGroup replace(double pattern, double replace) {
        final IDictionary newDict = _dict.replaceWithReference(pattern, replace, _reference);
        double[] nRef = null; // allocated lazily, only if some reference entry matches
        for (int i = 0; i < _reference.length; i++) {
            if (Util.eq(pattern, _reference[i])) {
                if (nRef == null) {
                    nRef = _reference.clone();
                }
                nRef[i] = replace;
            }
        }
        return ColGroupDDCFOR.create(_colIndexes, newDict, _data, getCachedCounts(), nRef == null ? _reference : nRef);
    }

    @Override
    protected double computeMxx(double c, Builtin builtin) {
        return _dict.aggregateWithReference(c, builtin, _reference, false);
    }

    @Override
    protected void computeColMxx(double[] c, Builtin builtin) {
        _dict.aggregateColsWithReference(c, builtin, _colIndexes, _reference, false);
    }

    /** Sum of the dictionary part (computed by super) plus the reference sum times nRows. */
    @Override
    protected void computeSum(double[] c, int nRows) {
        super.computeSum(c, nRows);
        c[0] += ColGroupUtils.refSum(_reference) * nRows;
    }

    /** Column sums of the dictionary part (computed by super) plus each reference value times nRows. */
    @Override
    public void computeColSums(double[] c, int nRows) {
        super.computeColSums(c, nRows);
        for (int i = 0; i < _colIndexes.size(); i++) {
            c[_colIndexes.get(i)] += _reference[i] * nRows;
        }
    }

    @Override
    protected void computeSumSq(double[] c, int nRows) {
        c[0] += _dict.sumSqWithReference(getCounts(), _reference);
    }

    @Override
    protected void computeColSumsSq(double[] c, int nRows) {
        _dict.colSumSqWithReference(c, getCounts(), _colIndexes, _reference);
    }

    @Override
    protected double[] preAggSumRows() {
        return _dict.sumAllRowsToDoubleWithReference(_reference);
    }

    @Override
    protected double[] preAggSumSqRows() {
        return _dict.sumAllRowsToDoubleSqWithReference(_reference);
    }

    @Override
    protected double[] preAggProductRows() {
        return _dict.productAllRowsToDoubleWithReference(_reference);
    }

    @Override
    protected double[] preAggBuiltinRows(Builtin builtin) {
        return _dict.aggregateRowsWithReference(builtin, _reference);
    }

    @Override
    protected void computeProduct(double[] c, int nRows) {
        _dict.productWithReference(c, getCounts(), _reference, 0);
    }

    /** Multiplies, for each row in [rl, ru), c[row] by that row's pre-aggregate. */
    @Override
    protected void computeRowProduct(double[] c, int rl, int ru, double[] preAgg) {
        for (int rix = rl; rix < ru; rix++) {
            c[rix] *= preAgg[_data.getIndex(rix)];
        }
    }

    @Override
    protected void computeColProduct(double[] c, int nRows) {
        _dict.colProductWithReference(c, getCounts(), _colIndexes, _reference);
    }

    /** Slices out column positions [idStart, idEnd) of both the dictionary and the reference. */
    @Override
    protected AColGroup sliceMultiColumns(int idStart, int idEnd, IColIndex outputCols) {
        final IDictionary retDict = _dict.sliceOutColumnRange(idStart, idEnd, _colIndexes.size());
        final double[] newDef = Arrays.copyOfRange(_reference, idStart, idEnd);
        return ColGroupDDCFOR.create(outputCols, retDict, _data, getCounts(), newDef);
    }

    /** Slices out column position {@code idx}; the result uses column index set of size 1. */
    @Override
    protected AColGroup sliceSingleColumn(int idx) {
        final IColIndex retIndexes = ColIndexFactory.create(1);
        if (_colIndexes.size() == 1) {
            return ColGroupDDCFOR.create(retIndexes, _dict, _data, getCounts(), _reference);
        }
        final double[] newDef = new double[] {_reference[idx]};
        final IDictionary retDict = _dict.sliceOutColumnRange(idx, idx + 1, _colIndexes.size());
        return ColGroupDDCFOR.create(retIndexes, retDict, _data, getCounts(), newDef);
    }

    /**
     * NaN/Inf patterns are checked in the reference and the raw dictionary separately. Other values
     * are checked on dictionary + reference.
     */
    @Override
    public boolean containsValue(double pattern) {
        if (Double.isNaN(pattern) || Double.isInfinite(pattern)) {
            return ColGroupUtils.containsInfOrNan(pattern, _reference) || _dict.containsValue(pattern);
        }
        return _dict.containsValueWithReference(pattern, _reference);
    }

    /** Reports every cell of the group as non-zero (an upper bound, not an exact count). */
    @Override
    public long getNumberNonZeros(int nRows) {
        return (long) _colIndexes.size() * (long) nRows;
    }

    /**
     * Adds the reference into {@code constV} (at this group's column indexes) and returns the
     * remaining values as a plain DDC group.
     */
    @Override
    public AColGroup extractCommon(double[] constV) {
        for (int i = 0; i < _colIndexes.size(); i++) {
            constV[_colIndexes.get(i)] += _reference[i];
        }
        return ColGroupDDC.create(_colIndexes, _dict, _data, getCounts());
    }

    /**
     * Expands this group into {@code max} one-hot columns.
     * <p>
     * Only {@code _reference[0]} is used, truncated to int; this presumably assumes a
     * single-column group (confirm).
     */
    @Override
    public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int nRows) {
        final int def = (int) _reference[0];
        final IDictionary d = _dict.rexpandColsWithReference(max, ignore, cast, def);
        if (d == null) {
            if (def <= 0 || def > max) {
                return ColGroupEmpty.create(max);
            }
            final double[] retDef = new double[max];
            retDef[def - 1] = 1.0;
            return ColGroupConst.create(retDef);
        }
        final IColIndex outCols = ColIndexFactory.create(max);
        if (def <= 0) {
            if (!ignore) {
                throw new DMLRuntimeException("Invalid content of zero in rexpand");
            }
            return ColGroupDDC.create(outCols, d, _data, getCachedCounts());
        }
        if (def > max) {
            // Reference falls outside the expanded range: no one-hot offset is needed.
            return ColGroupDDC.create(outCols, d, _data, getCachedCounts());
        }
        final double[] retDef = new double[max];
        retDef[def - 1] = 1.0;
        return ColGroupDDCFOR.create(outCols, d, _data, getCachedCounts(), retDef);
    }

    /** Central moment using only {@code _reference[0]}; presumably for single-column groups (confirm). */
    @Override
    public CM_COV_Object centralMoment(CMOperator op, int nRows) {
        return _dict.centralMomentWithReference(op.fn, getCounts(), _reference[0], nRows);
    }

    /** Returns the internal reference array (not a copy). */
    @Override
    public double[] getCommon() {
        return _reference;
    }

    @Override
    protected AColGroup allocateRightMultiplicationCommon(double[] common, IColIndex colIndexes, IDictionary preAgg) {
        return ColGroupDDCFOR.create(colIndexes, preAgg, _data, getCachedCounts(), common);
    }

    /** Row slice [rl, ru); shares dictionary and reference, drops cached counts. */
    @Override
    public AColGroup sliceRows(int rl, int ru) {
        final AMapToData sliceMap = _data.slice(rl, ru);
        return new ColGroupDDCFOR(_colIndexes, _dict, _reference, sliceMap, null);
    }

    @Override
    protected AColGroup copyAndSet(IColIndex colIndexes, IDictionary newDictionary) {
        return ColGroupDDCFOR.create(colIndexes, newDictionary, _data, getCachedCounts(), _reference);
    }

    /**
     * Appends {@code g} row-wise, but only when it is a DDCFOR group with equal column indexes,
     * reference and dictionary. Otherwise returns null.
     */
    @Override
    public AColGroup append(AColGroup g) {
        if (g instanceof ColGroupDDCFOR && g.getColIndices().equals(_colIndexes)) {
            final ColGroupDDCFOR other = (ColGroupDDCFOR) g;
            if (Arrays.equals(_reference, other._reference) && other._dict.equals(_dict)) {
                final AMapToData nd = _data.append(other._data);
                return ColGroupDDCFOR.create(_colIndexes, _dict, nd, null, _reference);
            }
        }
        return null;
    }

    @Override
    public AColGroup appendNInternal(AColGroup[] g, int blen, int rlen) {
        throw new NotImplementedException();
    }

    @Override
    public ICLAScheme getCompressionScheme() {
        throw new NotImplementedException();
    }

    @Override
    public AColGroup recompress() {
        throw new NotImplementedException();
    }

    @Override
    public CompressedSizeInfoColGroup getCompressionInfo(int nRow) {
        final IEncode enc = getEncoding();
        final EstimationFactors ef = new EstimationFactors(getNumValues(), _data.size(), _data.size(), _dict.getSparsity());
        return new CompressedSizeInfoColGroup(_colIndexes, ef, estimateInMemorySize(), getCompType(), enc);
    }

    @Override
    public IEncode getEncoding() {
        return EncodingFactory.create(_data);
    }

    /** True only if {@code that} is a DDCFOR group sharing the very same mapping instance. */
    @Override
    public boolean sameIndexStructure(AColGroupCompressed that) {
        return that instanceof ColGroupDDCFOR && ((ColGroupDDCFOR) that)._data == _data;
    }

    @Override
    protected AColGroup fixColIndexes(IColIndex newColIndex, int[] reordering) {
        throw new NotImplementedException();
    }

    @Override
    public String toString() {
        final StringBuilder sb = new StringBuilder();
        sb.append(super.toString());
        sb.append(String.format("\n%15s ", "Data: "));
        sb.append(_data);
        sb.append(String.format("\n%15s", "Reference:"));
        sb.append(Arrays.toString(_reference));
        return sb.toString();
    }
}

