/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress.colgroup;

import org.apache.commons.lang.NotImplementedException;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.AColGroupCompressed;
import org.apache.sysds.runtime.compress.colgroup.AColGroupValue;
import org.apache.sysds.runtime.compress.colgroup.ColGroupDDC;
import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty;
import org.apache.sysds.runtime.compress.colgroup.ColGroupSDCSingleZeros;
import org.apache.sysds.runtime.compress.colgroup.ColGroupSDCZeros;
import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed;
import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.DictLibMatrixMult;
import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;

public abstract class APreAgg
extends AColGroupValue {
    private static final long serialVersionUID = 3250955207277128281L;
    private static boolean loggedWarningForDirect = false;

    /**
     * Creates a pre-aggregating column group from a row count only.
     *
     * <p>All state is initialised by the superclass constructor; nothing is set up here.
     *
     * @param numRows value forwarded unchanged to the {@code AColGroupValue} constructor
     */
    protected APreAgg(int numRows) {
        super(numRows);
    }

    /**
     * Creates a pre-aggregating column group from its column indexes, row count, dictionary and counts.
     *
     * <p>All four arguments are forwarded unchanged to the superclass constructor.
     *
     * @param colIndices   column indexes handed to the superclass
     * @param numRows      row count handed to the superclass
     * @param dict         dictionary handed to the superclass
     * @param cachedCounts counts handed to the superclass (presumably per-dictionary-entry
     *                     occurrence counts — confirm against {@code AColGroupValue})
     */
    protected APreAgg(int[] colIndices, int numRows, ADictionary dict, int[] cachedCounts) {
        super(colIndices, numRows, dict, cachedCounts);
    }

    /**
     * Dispatches the tsmm contribution of this group combined with {@code other} to the
     * implementation matching the runtime type of {@code other}.
     *
     * <ul>
     * <li>{@code ColGroupEmpty}: nothing is added to {@code result}.</li>
     * <li>{@code APreAgg}: handled by {@code tsmmAPreAgg}.</li>
     * <li>{@code ColGroupUncompressed}: handled by {@code tsmmColGroupUncompressed}.</li>
     * </ul>
     *
     * @param other  the column group to combine with this one
     * @param result the matrix the contribution is accumulated into
     * @throws DMLCompressionException if {@code other} is none of the supported types
     */
    @Override
    public void tsmmAColGroup(AColGroup other, MatrixBlock result) {
        // An empty group contributes nothing.
        if (other instanceof ColGroupEmpty) {
            return;
        }
        if (other instanceof APreAgg) {
            tsmmAPreAgg((APreAgg) other, result);
            return;
        }
        if (other instanceof ColGroupUncompressed) {
            tsmmColGroupUncompressed((ColGroupUncompressed) other, result);
            return;
        }
        final String typeName = other.getClass().getSimpleName();
        throw new DMLCompressionException("Unsupported column group type " + typeName);
    }

    /**
     * Dispatches a left multiplication by the column group {@code lhs} to the implementation
     * matching the runtime type of {@code lhs}.
     *
     * <ul>
     * <li>{@code ColGroupEmpty}: nothing is added to {@code result}.</li>
     * <li>{@code APreAgg}: handled by {@code leftMultByColGroupValue}.</li>
     * <li>{@code ColGroupUncompressed}: handled by {@code leftMultByUncompressedColGroup}.</li>
     * </ul>
     *
     * @param lhs    the left-hand side column group
     * @param result the matrix the product is accumulated into
     * @throws DMLCompressionException if {@code lhs} is none of the supported types
     */
    @Override
    public final void leftMultByAColGroup(AColGroup lhs, MatrixBlock result) {
        // An empty left-hand side contributes nothing.
        if (lhs instanceof ColGroupEmpty) {
            return;
        }
        if (lhs instanceof APreAgg) {
            leftMultByColGroupValue((APreAgg) lhs, result);
            return;
        }
        if (lhs instanceof ColGroupUncompressed) {
            leftMultByUncompressedColGroup((ColGroupUncompressed) lhs, result);
            return;
        }
        final String typeName = lhs.getClass().getSimpleName();
        throw new DMLCompressionException("Not supported left multiplication with A ColGroup of type: " + typeName);
    }

    /**
     * Left-multiplies rows {@code [rl, ru)} of {@code matrix} with this column group and adds
     * the product into {@code result}.
     *
     * <p>The rows are first pre-aggregated into a dense {@code (ru - rl) x getNumValues()} block,
     * which is then multiplied with the dictionary in matrix-block form. The product is added to
     * {@code result} at this group's column indexes.
     *
     * @param matrix the left-hand side matrix; an empty matrix makes this a no-op
     * @param result the dense output the product is accumulated into
     * @param rl     first row (inclusive) of {@code matrix} to process
     * @param ru     last row (exclusive) of {@code matrix} to process
     */
    @Deprecated
    private final void leftMultByMatrix(MatrixBlock matrix, MatrixBlock result, int rl, int ru) {
        if (matrix.isEmpty()) {
            return;
        }
        final int groupCols = _colIndexes.length;
        final int distinctValues = getNumValues();

        // Step 1: pre-aggregate the selected rows, one output column per dictionary entry.
        final MatrixBlock aggregated = new MatrixBlock(ru - rl, distinctValues, false);
        aggregated.allocateDenseBlock();
        preAggregate(matrix, aggregated.getDenseBlockValues(), rl, ru);
        aggregated.recomputeNonZeros();

        // Step 2: multiply the pre-aggregate with the dictionary as a matrix block.
        final MatrixBlock product = new MatrixBlock(aggregated.getNumRows(), groupCols, false);
        forceMatrixBlockDictionary();
        final MatrixBlock dictBlock = _dict.getMBDict(groupCols).getMatrixBlock();
        LibMatrixMult.matrixMult(aggregated, dictBlock, product);

        // Step 3: scatter the product into the result at this group's columns.
        addMatrixToResult(product, result, _colIndexes, rl, ru);
    }

    /**
     * Pre-aggregates {@code that} group into a new dictionary holding
     * {@code that._colIndexes.length * getNumValues()} values.
     *
     * <p>The work is delegated to the abstract {@code preAggregateThat*Structure} method matching
     * the runtime type of {@code that}. The filled dictionary is returned in the form produced by
     * {@code getMBDict}; callers in this class treat a {@code null} return as "nothing to add".
     *
     * @param that the group to pre-aggregate
     * @return the pre-aggregated dictionary as returned by {@code Dictionary.getMBDict}
     * @throws NotImplementedException if {@code that} is not a DDC, SDCSingleZeros or SDCZeros group
     */
    public final ADictionary preAggregateThatIndexStructure(APreAgg that) {
        final int thatCols = that._colIndexes.length;
        final Dictionary aggregated = new Dictionary(new double[thatCols * getNumValues()]);

        if (that instanceof ColGroupDDC) {
            preAggregateThatDDCStructure((ColGroupDDC) that, aggregated);
        } else if (that instanceof ColGroupSDCSingleZeros) {
            preAggregateThatSDCSingleZerosStructure((ColGroupSDCSingleZeros) that, aggregated);
        } else if (that instanceof ColGroupSDCZeros) {
            preAggregateThatSDCZerosStructure((ColGroupSDCZeros) that, aggregated);
        } else {
            final String thisType = getClass().getSimpleName();
            final String thatType = that.getClass().getSimpleName();
            throw new NotImplementedException(
                "Not supported pre aggregate using index structure of :" + thatType + " in " + thisType);
        }

        return aggregated.getMBDict(thatCols);
    }

    /**
     * Pre-aggregates rows {@code [rl, ru)} of {@code m} into {@code preAgg}, choosing the dense
     * or sparse implementation from the storage format of {@code m}.
     *
     * @param m      the matrix to pre-aggregate
     * @param preAgg the output array the aggregate is written into
     * @param rl     lower row bound passed to the implementation
     * @param ru     upper row bound passed to the implementation
     */
    public void preAggregate(MatrixBlock m, double[] preAgg, int rl, int ru) {
        if (!m.isInSparseFormat()) {
            // Dense input: aggregate over the full column range of m.
            preAggregateDense(m, preAgg, rl, ru, 0, m.getNumColumns());
            return;
        }
        preAggregateSparse(m.getSparseBlock(), preAgg, rl, ru);
    }

    /**
     * Pre-aggregates a dense matrix into an output array.
     *
     * <p>{@code preAggregate} calls this as
     * {@code preAggregateDense(m, preAgg, rl, ru, 0, m.getNumColumns())}.
     *
     * @param var1 the dense input matrix
     * @param var2 the output array receiving the pre-aggregate
     * @param var3 lower row bound ({@code rl} at the call site)
     * @param var4 upper row bound ({@code ru} at the call site)
     * @param var5 presumably the lower column bound (0 at the call site) — confirm in implementations
     * @param var6 presumably the upper column bound (the column count of the input at the call
     *             site) — confirm in implementations
     */
    public abstract void preAggregateDense(MatrixBlock var1, double[] var2, int var3, int var4, int var5, int var6);

    /**
     * Pre-aggregates a sparse block into an output array.
     *
     * <p>{@code preAggregate} calls this as
     * {@code preAggregateSparse(m.getSparseBlock(), preAgg, rl, ru)}.
     *
     * @param var1 the sparse block of the input matrix
     * @param var2 the output array receiving the pre-aggregate
     * @param var3 lower row bound ({@code rl} at the call site)
     * @param var4 upper row bound ({@code ru} at the call site)
     */
    public abstract void preAggregateSparse(SparseBlock var1, double[] var2, int var3, int var4);

    /**
     * Pre-aggregates a DDC group into the given dictionary; invoked by
     * {@code preAggregateThatIndexStructure} when its argument is a {@code ColGroupDDC}.
     *
     * @param var1 the DDC group to pre-aggregate
     * @param var2 the output dictionary, allocated by the caller with
     *             {@code var1._colIndexes.length * getNumValues()} values
     */
    protected abstract void preAggregateThatDDCStructure(ColGroupDDC var1, Dictionary var2);

    /**
     * Pre-aggregates an SDCZeros group into the given dictionary; invoked by
     * {@code preAggregateThatIndexStructure} when its argument is a {@code ColGroupSDCZeros}.
     *
     * @param var1 the SDCZeros group to pre-aggregate
     * @param var2 the output dictionary, allocated by the caller with
     *             {@code var1._colIndexes.length * getNumValues()} values
     */
    protected abstract void preAggregateThatSDCZerosStructure(ColGroupSDCZeros var1, Dictionary var2);

    /**
     * Pre-aggregates an SDCSingleZeros group into the given dictionary; invoked by
     * {@code preAggregateThatIndexStructure} when its argument is a {@code ColGroupSDCSingleZeros}.
     *
     * @param var1 the SDCSingleZeros group to pre-aggregate
     * @param var2 the output dictionary, allocated by the caller with
     *             {@code var1._colIndexes.length * getNumValues()} values
     */
    protected abstract void preAggregateThatSDCSingleZerosStructure(ColGroupSDCSingleZeros var1, Dictionary var2);

    /**
     * Reports whether the given group has the same index structure as this one.
     *
     * <p>When this returns {@code true}, {@code tsmmAPreAgg} and {@code leftMultByColGroupValue}
     * skip pre-aggregation and combine the two dictionaries directly, scaled by this group's
     * counts.
     *
     * @param var1 the group to compare against
     * @return {@code true} if the index structures are considered the same
     */
    protected abstract boolean sameIndexStructure(AColGroupCompressed var1);

    /**
     * Returns the size of this group's pre-aggregate, which this base implementation defines
     * as {@code getNumValues()}.
     *
     * @return the value of {@code getNumValues()}
     */
    public int getPreAggregateSize() {
        return this.getNumValues();
    }

    /**
     * Adds the tsmm contribution of this group combined with the pre-aggregatable group
     * {@code lg} to the upper triangle of {@code result}.
     *
     * <p>If both groups share the same index structure, their dictionaries are combined
     * directly, scaled by this group's counts. Otherwise one side is pre-aggregated onto the
     * other's index structure, chosen by {@code shouldPreAggregateLeft}, and the pre-aggregate
     * is multiplied with the other dictionary. A {@code null} pre-aggregate adds nothing.
     *
     * <p>A warning is logged at most once, guarded by a static flag, the first time the cost
     * model says a direct multiply would be cheaper; no direct multiply is implemented.
     *
     * @param lg     the other group (left side of the product)
     * @param result the matrix whose upper triangle receives the contribution
     */
    private void tsmmAPreAgg(APreAgg lg, MatrixBlock result) {
        final int[] rightIdx = _colIndexes;
        final int[] leftIdx = lg._colIndexes;

        // Identical index structure: no pre-aggregation needed.
        if (sameIndexStructure(lg)) {
            DictLibMatrixMult.TSMMToUpperTriangleScaling(lg._dict, _dict, leftIdx, rightIdx, getCounts(), result);
            return;
        }

        final boolean preAggLeft = shouldPreAggregateLeft(lg);
        if (!loggedWarningForDirect && shouldDirectMultiply(lg, leftIdx.length, rightIdx.length, preAggLeft)) {
            loggedWarningForDirect = true;
            LOG.warn((Object)"Not implemented direct tsmm colgroup");
        }

        if (preAggLeft) {
            final ADictionary leftPreAgg = preAggregateThatIndexStructure(lg);
            if (leftPreAgg == null) {
                return;
            }
            DictLibMatrixMult.TSMMToUpperTriangle(leftPreAgg, _dict, leftIdx, rightIdx, result);
            return;
        }

        final ADictionary rightPreAgg = lg.preAggregateThatIndexStructure(this);
        if (rightPreAgg == null) {
            return;
        }
        DictLibMatrixMult.TSMMToUpperTriangle(lg._dict, rightPreAgg, leftIdx, rightIdx, result);
    }

    /**
     * Estimates whether multiplying the two groups directly would need fewer floating point
     * operations than going through a pre-aggregate.
     *
     * <p>The direct cost is {@code 2 * commonDim * nColL * nColR}, where {@code commonDim} is the
     * smaller of the two groups' {@code numRowsToMultiply()}. The pre-aggregate cost is the sum
     * of three terms, using the column count of the pre-aggregated side ({@code nColL} when
     * {@code leftPreAgg}, otherwise {@code nColR}) and the number of distinct values of the
     * other side: {@code cols * nVal}, {@code cols * commonDim} and {@code nColR * nColL * nVal}.
     *
     * <p>Every product is evaluated in {@code long}. The operands are {@code int}s, so
     * multiplying first and widening afterwards would silently wrap for large groups and
     * distort the comparison.
     *
     * @param lg         the other (left) group
     * @param nColL      number of columns in the left group
     * @param nColR      number of columns in the right (this) group
     * @param leftPreAgg whether the left group would be the one pre-aggregated
     * @return {@code true} if the direct multiply is estimated to be strictly cheaper
     */
    private boolean shouldDirectMultiply(APreAgg lg, int nColL, int nColR, boolean leftPreAgg) {
        final int lMRows = lg.numRowsToMultiply();
        final int rMRows = this.numRowsToMultiply();
        final long commonDim = Math.min(lMRows, rMRows);
        final long directFLOPS = commonDim * (long) nColL * (long) nColR * 2L;

        // Widen before multiplying so none of the cost terms can overflow in int arithmetic.
        final long nVal = leftPreAgg ? this.getNumValues() : lg.getNumValues();
        final long preAggCols = leftPreAgg ? nColL : nColR;

        long preAggFLOPS = 0L;
        preAggFLOPS += preAggCols * nVal;
        preAggFLOPS += preAggCols * commonDim;
        preAggFLOPS += (long) nColR * (long) nColL * nVal;

        return directFLOPS < preAggFLOPS;
    }

    /**
     * Left-multiplies this group by the pre-aggregatable group {@code lhs} and adds the product
     * into {@code result}.
     *
     * <p>Four cases are distinguished:
     * <ol>
     * <li>same index structure and the very same dictionary instance: tsmm on that dictionary,
     * scaled by this group's counts;</li>
     * <li>same index structure, different dictionaries: the two dictionaries are multiplied,
     * scaled by this group's counts;</li>
     * <li>{@code shouldPreAggregateLeft(lhs)} holds: {@code lhs} pre-aggregates this group and
     * the left dictionary is multiplied with that pre-aggregate;</li>
     * <li>otherwise: this group pre-aggregates {@code lhs} and the pre-aggregate is multiplied
     * with the right dictionary.</li>
     * </ol>
     * A {@code null} pre-aggregate adds nothing.
     *
     * @param lhs    the left-hand side group
     * @param result the matrix the product is accumulated into
     */
    private void leftMultByColGroupValue(APreAgg lhs, MatrixBlock result) {
        final int[] rightIdx = _colIndexes;
        final int[] leftIdx = lhs._colIndexes;
        final ADictionary rDict = _dict;
        final ADictionary lDict = lhs._dict;

        if (sameIndexStructure(lhs)) {
            // Reference equality on purpose: only the identical dictionary object takes the tsmm path.
            if (rDict == lDict) {
                DictLibMatrixMult.TSMMDictionaryWithScaling(rDict, getCounts(), leftIdx, rightIdx, result);
            } else {
                DictLibMatrixMult.MMDictsWithScaling(lDict, rDict, leftIdx, rightIdx, result, getCounts());
            }
            return;
        }

        if (shouldPreAggregateLeft(lhs)) {
            final ADictionary lhsPA = lhs.preAggregateThatIndexStructure(this);
            if (lhsPA == null) {
                return;
            }
            DictLibMatrixMult.MMDicts(lDict, lhsPA, leftIdx, rightIdx, result);
            return;
        }

        final ADictionary rhsPA = preAggregateThatIndexStructure(lhs);
        if (rhsPA == null) {
            return;
        }
        DictLibMatrixMult.MMDicts(rhsPA, rDict, leftIdx, rightIdx, result);
    }

    /**
     * Left-multiplies this group by an uncompressed column group and adds the product into
     * {@code result}.
     *
     * <p>The uncompressed data is transposed and each transposed row is pre-aggregated against
     * this group. The pre-aggregate is multiplied with the dictionary in matrix-block form, and
     * the product is added to {@code result} using the column indexes of {@code lhs} as result
     * rows.
     *
     * @param lhs    the uncompressed left-hand side group; empty data makes this a no-op
     * @param result the dense output the product is accumulated into
     */
    private void leftMultByUncompressedColGroup(ColGroupUncompressed lhs, MatrixBlock result) {
        if (lhs.getData().isEmpty()) {
            return;
        }
        LOG.warn((Object)"Transpose of uncompressed to fit to template need t(a) %*% b");

        final int parallelism = InfrastructureAnalyzer.getLocalParallelism();
        final MatrixBlock transposed = LibMatrixReorg.transpose(lhs.getData(), parallelism);
        final int transposedRows = transposed.getNumRows();

        // Pre-aggregate every transposed row, one output column per dictionary entry.
        final MatrixBlock aggregated = new MatrixBlock(transposedRows, getNumValues(), false);
        aggregated.allocateDenseBlock();
        preAggregate(transposed, aggregated.getDenseBlockValues(), 0, transposed.getNumRows());
        aggregated.recomputeNonZeros();

        // Multiply the pre-aggregate with the dictionary.
        final MatrixBlock product = new MatrixBlock(aggregated.getNumRows(), _colIndexes.length, false);
        final MatrixBlock dictBlock = _dict.getMBDict(getNumCols()).getMatrixBlock();
        LibMatrixMult.matrixMult(aggregated, dictBlock, product);

        // Rows of the product correspond to the columns of lhs.
        addMatrixToResult(product, result, lhs._colIndexes);
    }

    /**
     * Adds {@code tmp} into the dense values of {@code result}, mapping row {@code r} of
     * {@code tmp} to result row {@code rowIndexes[r]} and column {@code c} of {@code tmp} to
     * result column {@code _colIndexes[c]}.
     *
     * <p>Both sparse and dense {@code tmp} are supported. In the sparse case, rows without
     * entries are skipped before any row accessor is called; they contribute nothing to the sum.
     * NOTE(review): this assumes the {@code SparseBlock} row accessors expect a prior
     * {@code isEmpty(r)} check and that a row reported as empty has no entries — confirm against
     * the {@code SparseBlock} API.
     *
     * @param tmp        the matrix to add; an empty matrix makes this a no-op
     * @param result     the dense output that is updated in place
     * @param rowIndexes for each row of {@code tmp}, the target row in {@code result}
     */
    private void addMatrixToResult(MatrixBlock tmp, MatrixBlock result, int[] rowIndexes) {
        if (tmp.isEmpty()) {
            return;
        }
        final double[] retV = result.getDenseBlockValues();
        final int nColRet = result.getNumColumns();
        final int[] colIndexes = this._colIndexes;

        if (tmp.isInSparseFormat()) {
            final SparseBlock sb = tmp.getSparseBlock();
            for (int row = 0; row < rowIndexes.length; ++row) {
                // Skip rows without entries before touching the per-row accessors.
                if (sb.isEmpty(row)) {
                    continue;
                }
                final int apos = sb.pos(row);
                final int alen = sb.size(row);
                final int[] aix = sb.indexes(row);
                final double[] avals = sb.values(row);
                final int offR = rowIndexes[row] * nColRet;
                for (int i = apos; i < apos + alen; ++i) {
                    retV[offR + colIndexes[aix[i]]] += avals[i];
                }
            }
            return;
        }

        final double[] tmpV = tmp.getDenseBlockValues();
        final int nCol = colIndexes.length;
        // offT walks tmp row by row; the loop indexes tmp as a contiguous row-major array.
        for (int row = 0, offT = 0; row < rowIndexes.length; ++row, offT += nCol) {
            final int offR = rowIndexes[row] * nColRet;
            for (int col = 0; col < nCol; ++col) {
                retV[offR + colIndexes[col]] += tmpV[offT + col];
            }
        }
    }

    /**
     * Adds the tsmm contribution of this group combined with an uncompressed group to the
     * upper triangle of {@code result}.
     *
     * <p>The uncompressed data is transposed and left-multiplied with this group into a
     * temporary dense block as wide as {@code result}. Each value of that block belonging to one
     * of this group's columns is then added to the upper triangle of {@code result}.
     *
     * @param other  the uncompressed column group
     * @param result the matrix whose upper triangle receives the contribution
     */
    private void tsmmColGroupUncompressed(ColGroupUncompressed other, MatrixBlock result) {
        LOG.warn((Object)"Inefficient multiplication with uncompressed column group");

        final int resultCols = result.getNumColumns();
        final MatrixBlock otherT = LibMatrixReorg.transpose(other.getData());
        final int otherTRows = otherT.getNumRows();

        // Temporary product: one row per column of the uncompressed group.
        final MatrixBlock product = new MatrixBlock(otherTRows, resultCols, false);
        product.allocateDenseBlock();
        leftMultByMatrix(otherT, product, 0, otherTRows);

        final double[] productV = product.getDenseBlockValues();
        final double[] resultV = result.getDenseBlockValues();
        final int[] otherCols = other._colIndexes;
        final int[] thisCols = _colIndexes;

        for (int i = 0; i < otherCols.length; i++) {
            final int otherCol = otherCols[i];
            final int rowOffset = i * resultCols;
            for (int j = 0; j < thisCols.length; j++) {
                final int thisCol = thisCols[j];
                DictLibMatrixMult.addToUpperTriangle(resultCols, thisCol, otherCol, resultV, productV[rowOffset + thisCol]);
            }
        }
    }

    /**
     * Compares the dense dictionary sizes of both groups.
     *
     * <p>The cost of a side is its number of distinct values times its number of columns. The
     * products are computed in {@code long}, because two {@code int} factors can overflow
     * {@code int} arithmetic and flip the comparison.
     *
     * @param lhs the left-hand side group
     * @return {@code true} if this (right) group's cost is strictly smaller than that of
     *         {@code lhs}
     */
    private boolean shouldPreAggregateLeft(APreAgg lhs) {
        final int nvL = lhs.getNumValues();
        final int nvR = this.getNumValues();
        // Widen before multiplying so large groups cannot wrap around.
        final long costRightDense = (long) nvR * this._colIndexes.length;
        final long costLeftDense = (long) nvL * lhs._colIndexes.length;
        return costRightDense < costLeftDense;
    }

    /**
     * Multiplies a pre-aggregate with this group's dictionary and adds the product into rows
     * {@code [rl, ru)} of {@code ret} at this group's column indexes.
     *
     * <p>The multiplication runs on copies of {@code preAgg} and {@code tmpRes}, made with
     * {@code MatrixBlock.copy}, not on the blocks passed in.
     *
     * @param preAgg the pre-aggregated left-hand side
     * @param tmpRes a block used as template for the intermediate product
     * @param ret    the dense output the product is accumulated into
     * @param k      parallelization degree handed to {@code LibMatrixMult.matrixMult}
     * @param rl     first result row (inclusive)
     * @param ru     last result row (exclusive)
     * @throws DMLCompressionException wrapping any exception raised by the multiply or the add
     */
    public void mmWithDictionary(MatrixBlock preAgg, MatrixBlock tmpRes, MatrixBlock ret, int k, int rl, int ru) {
        // Work on copies of the two blocks.
        final MatrixBlock lhsCopy = new MatrixBlock();
        lhsCopy.copy(preAgg);
        final MatrixBlock productCopy = new MatrixBlock();
        productCopy.copy(tmpRes);

        final MatrixBlock dictBlock = getDictionary().getMBDict(_colIndexes.length).getMatrixBlock();

        try {
            LibMatrixMult.matrixMult(lhsCopy, dictBlock, productCopy, k);
            addMatrixToResult(productCopy, ret, _colIndexes, rl, ru);
        }
        catch (Exception e) {
            // NOTE(review): the message prints the caller's tmpRes, not the working copy — confirm this is intended.
            final String details = "Failed matrix multiply with preAggregate: \n" + lhsCopy + "\n" + dictBlock + "\n" + tmpRes;
            throw new DMLCompressionException(details, e);
        }
    }

    /**
     * Adds {@code tmp} into rows {@code [rl, ru)} of the dense values of {@code result}.
     *
     * <p>Row {@code t} of {@code tmp} (starting at 0) is added to result row {@code rl + t}, and
     * column {@code c} of {@code tmp} is added to result column {@code colIndexes[c]}.
     *
     * <p>Both sparse and dense {@code tmp} are supported. In the sparse case, rows without
     * entries are skipped before any row accessor is called; they contribute nothing to the sum.
     * NOTE(review): this assumes the {@code SparseBlock} row accessors expect a prior
     * {@code isEmpty(r)} check and that a row reported as empty has no entries — confirm against
     * the {@code SparseBlock} API.
     *
     * @param tmp        the matrix to add; an empty matrix makes this a no-op
     * @param result     the dense output that is updated in place
     * @param colIndexes for each column of {@code tmp}, the target column in {@code result}
     * @param rl         first result row (inclusive)
     * @param ru         last result row (exclusive)
     */
    private static void addMatrixToResult(MatrixBlock tmp, MatrixBlock result, int[] colIndexes, int rl, int ru) {
        if (tmp.isEmpty()) {
            return;
        }
        final double[] retV = result.getDenseBlockValues();
        final int nColRet = result.getNumColumns();

        if (tmp.isInSparseFormat()) {
            final SparseBlock sb = tmp.getSparseBlock();
            // offT is the row in tmp that corresponds to result row "row".
            for (int row = rl, offT = 0; row < ru; ++row, ++offT) {
                // Skip rows without entries before touching the per-row accessors.
                if (sb.isEmpty(offT)) {
                    continue;
                }
                final int apos = sb.pos(offT);
                final int alen = sb.size(offT);
                final int[] aix = sb.indexes(offT);
                final double[] avals = sb.values(offT);
                final int offR = row * nColRet;
                for (int i = apos; i < apos + alen; ++i) {
                    retV[offR + colIndexes[aix[i]]] += avals[i];
                }
            }
            return;
        }

        final double[] tmpV = tmp.getDenseBlockValues();
        final int nCol = colIndexes.length;
        // offT is the offset of the current tmp row; the loop indexes tmp as a contiguous row-major array.
        for (int row = rl, offT = 0; row < ru; ++row, offT += nCol) {
            final int offR = row * nColRet;
            for (int col = 0; col < nCol; ++col) {
                retV[offR + colIndexes[col]] += tmpV[offT + col];
            }
        }
    }

    /**
     * Returns the number of rows this group contributes to the cost estimate in
     * {@code shouldDirectMultiply}, which takes the minimum of both groups' values as the
     * common dimension of a direct multiplication.
     *
     * @return the row count used for the FLOP estimate
     */
    protected abstract int numRowsToMultiply();
}

