/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.matrix.data;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.concurrent.Callable;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.matrix.data.DnnParameters;
import org.apache.sysds.runtime.matrix.data.LibMatrixDNNHelper;
import org.apache.sysds.runtime.matrix.data.LibMatrixDNNIm2Col;
import org.apache.sysds.runtime.matrix.data.LibMatrixDNNRotate180;
import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.utils.NativeHelper;
import org.apache.sysds.utils.stats.NativeStatistics;

public class LibMatrixDNNConv2d {
    /**
     * Creates the tasks that together compute the forward conv2d, each task covering a
     * contiguous index range within {@code [0, params.N)}.
     *
     * <p>The range size is {@code ceil(N / k / 2)}, where {@code k} is the constrained thread
     * count, which yields roughly two tasks per thread. When the transposed im2col variant is
     * chosen, {@code params.input2} is replaced by its transpose once, before any task is
     * created, so that all tasks use the same transposed block.
     *
     * @param params convolution parameters; {@code params.input2} may be reassigned as described
     * @return the tasks in ascending range order
     * @throws DMLRuntimeException if {@code params.N > 0} and input1 is in dense format without
     *         an allocated dense block (and the native path is not selected)
     */
    public static ArrayList<Callable<Long>> getConv2dWorkers(DnnParameters params) {
        final ArrayList<Callable<Long>> tasks = new ArrayList<Callable<Long>>();
        final int numThreads = OptimizerUtils.getConstrainedNumThreads(params.numThreads);
        final int chunk = (int)Math.ceil((double)params.N / (double)numThreads / 2.0);
        final MatrixBlock image = params.input1;
        final boolean isEmptyDenseInput = !image.isInSparseFormat() && image.denseBlock == null;
        final boolean isTransPref = image.sparse && !params.input2.sparse && !params.output.sparse && MatrixBlock.evalSparseFormatInMemory(image.clen, image.rlen, image.nonZeros);
        // The transposed variant is only usable when there is actual input data.
        final boolean useTrans = !isEmptyDenseInput && isTransPref;
        final boolean applyNative = LibMatrixDNNConv2d.isEligibleForConv2dSparse(params) && !useTrans;
        if (applyNative) {
            NativeStatistics.incrementNumSparseConv2dCalls();
        }
        if (!applyNative && useTrans) {
            // Transpose the second input once here so every task shares the result.
            params.input2 = LibMatrixReorg.transpose(params.input2, new MatrixBlock(params.input2.clen, params.input2.rlen, false), numThreads);
        }
        for (int lower = 0; lower < params.N; lower += chunk) {
            final int upper = Math.min(lower + chunk, params.N);
            final Callable<Long> task;
            if (applyNative) {
                task = new SparseNativeConv2d(lower, upper, params);
            } else if (useTrans) {
                task = new LoopedIm2ColConv2dTransAllChan(lower, upper, params);
            } else if (!isEmptyDenseInput) {
                task = new LoopedIm2ColConv2dAllChan(lower, upper, params);
            } else {
                throw new DMLRuntimeException("Unsupported operator");
            }
            tasks.add(task);
        }
        return tasks;
    }

    /**
     * Creates the tasks that together compute the conv2d backward-filter result, each task
     * covering a contiguous index range of size {@code ceil(N / k)} within {@code [0, params.N)},
     * where {@code k} is the constrained thread count.
     *
     * <p>Per range, the variant is chosen in this order: the native sparse-dense task, the
     * transposed task (when input2 is sparse and input1 has the higher sparsity value), or the
     * default task.
     *
     * @param params convolution parameters
     * @return the tasks in ascending range order
     * @throws DMLRuntimeException if {@code params.N > 0}, none of the first two variants applies,
     *         and either input is in dense format without an allocated dense block
     */
    public static ArrayList<Callable<Long>> getConv2dBackwardFilterWorkers(DnnParameters params) {
        final ArrayList<Callable<Long>> tasks = new ArrayList<Callable<Long>>();
        final int numThreads = OptimizerUtils.getConstrainedNumThreads(params.numThreads);
        final int chunk = (int)Math.ceil((double)params.N / (double)numThreads);
        final boolean isEmptyDenseInput = !params.input1.isInSparseFormat() && params.input1.denseBlock == null || !params.input2.isInSparseFormat() && params.input2.denseBlock == null;
        final boolean applyNative = LibMatrixDNNConv2d.isEligibleForConv2dBackwardFilterSparseDense(params) && !params.input2.isInSparseFormat();
        if (applyNative) {
            NativeStatistics.incrementNumSparseConv2dBwdFilterCalls();
        }
        for (int lower = 0; lower < params.N; lower += chunk) {
            final int upper = Math.min(lower + chunk, params.N);
            final Callable<Long> task;
            if (applyNative) {
                task = new SparseNativeConv2dBackwardFilterDense(lower, upper, params);
            } else if (params.input2.sparse && params.input1.getSparsity() > params.input2.getSparsity()) {
                task = new Conv2dBackwardFilterTrans(lower, upper, params);
            } else if (!isEmptyDenseInput) {
                task = new Conv2dBackwardFilter(lower, upper, params);
            } else {
                throw new DMLRuntimeException("Unsupported operator");
            }
            tasks.add(task);
        }
        return tasks;
    }

    /**
     * Creates the tasks that together compute the conv2d backward-data result, each task
     * covering a contiguous index range of size {@code ceil(N / k)} within {@code [0, params.N)},
     * where {@code k} is the constrained thread count.
     *
     * @param params convolution parameters
     * @return the tasks in ascending range order
     * @throws DMLRuntimeException if {@code params.N > 0}, the native path is not selected, and
     *         either input is in dense format without an allocated dense block
     */
    public static ArrayList<Callable<Long>> getConv2dBackwardDataWorkers(DnnParameters params) {
        final ArrayList<Callable<Long>> tasks = new ArrayList<Callable<Long>>();
        final int numThreads = OptimizerUtils.getConstrainedNumThreads(params.numThreads);
        final int chunk = (int)Math.ceil((double)params.N / (double)numThreads);
        final boolean isEmptyDenseInput = !params.input1.isInSparseFormat() && params.input1.denseBlock == null || !params.input2.isInSparseFormat() && params.input2.denseBlock == null;
        final boolean applyNative = LibMatrixDNNConv2d.isEligibleForConv2dBackwardDataDense(params) && !params.input2.isInSparseFormat();
        if (applyNative) {
            NativeStatistics.incrementNumSparseConv2dBwdDataCalls();
        }
        for (int lower = 0; lower < params.N; lower += chunk) {
            final int upper = Math.min(lower + chunk, params.N);
            final Callable<Long> task;
            if (applyNative) {
                task = new SparseNativeConv2dBackwardDataDense(lower, upper, params);
            } else if (!isEmptyDenseInput) {
                task = new Conv2dBackwardData(lower, upper, params);
            } else {
                throw new DMLRuntimeException("Unsupported operator");
            }
            tasks.add(task);
        }
        return tasks;
    }

    /**
     * Hands {@code a} and the output's dense values to {@code LibMatrixMult.vectAdd} for
     * {@code a.length} elements starting at offset 0 in both arrays, while holding the
     * monitor of the output's dense block so concurrent tasks do not interleave.
     *
     * @param a partial result of a single task
     * @param params parameters whose {@code output} receives the values
     */
    private static void inplaceAdd(double[] a, DnnParameters params) {
        synchronized (params.output.denseBlock) {
            final double[] target = params.output.getDenseBlockValues();
            LibMatrixMult.vectAdd(a, target, 0, 0, a.length);
        }
    }

    /**
     * Adds the transpose of {@code a} into the output's dense values while holding the
     * monitor of the output's dense block.
     *
     * <p>{@code a} is read as a row-major {@code CRS x K} array and the output is written as a
     * row-major {@code K x CRS} array, i.e. {@code out[i + j*CRS] += a[i*K + j]}. The traversal
     * is tiled in 128 x 128 blocks.
     *
     * @param a partial result of a single task, indexed as {@code CRS x K}
     * @param params parameters providing {@code C}, {@code R}, {@code S}, {@code K} and the output
     */
    private static void inplaceTransAdd(double[] a, DnnParameters params) {
        synchronized (params.output.denseBlock) {
            final double[] out = params.output.getDenseBlockValues();
            final int crs = params.C * params.R * params.S;
            final int numFilters = params.K;
            final int tile = 128;
            for (int rowStart = 0; rowStart < crs; rowStart += tile) {
                for (int colStart = 0; colStart < numFilters; colStart += tile) {
                    final int rowEnd = Math.min(rowStart + tile, crs);
                    final int colEnd = Math.min(colStart + tile, numFilters);
                    for (int row = rowStart; row < rowEnd; row++) {
                        final int srcRowBase = row * numFilters;
                        for (int col = colStart; col < colEnd; col++) {
                            out[row + col * crs] += a[srcRowBase + col];
                        }
                    }
                }
            }
        }
    }

    /**
     * Copies row {@code n} of {@code input} into {@code ret} as a dense vector.
     *
     * <p>For a sparse input, {@code ret} is zero-filled first and then the row's non-zeros are
     * scattered into it; for a dense input the row is copied directly from the dense values.
     *
     * @param input source matrix
     * @param n row index to extract
     * @param ret destination; its length must equal the number of columns of {@code input}
     * @throws DMLRuntimeException if {@code ret.length} differs from the column count
     */
    private static void getRowInDenseFormat(MatrixBlock input, int n, double[] ret) {
        final int numCols = input.getNumColumns();
        if (numCols != ret.length) {
            throw new DMLRuntimeException("Invalid parameters");
        }
        if (!input.isInSparseFormat()) {
            System.arraycopy(input.getDenseBlockValues(), n * numCols, ret, 0, numCols);
            return;
        }
        Arrays.fill(ret, 0.0);
        final SparseBlock rows = input.sparseBlock;
        if (rows.isEmpty(n)) {
            return;
        }
        final int start = rows.pos(n);
        final int end = start + rows.size(n);
        final int[] colIdx = rows.indexes(n);
        final double[] vals = rows.values(n);
        for (int pos = start; pos < end; pos++) {
            ret[colIdx[pos]] = vals[pos];
        }
    }

    /**
     * Applies the per-filter bias to row {@code r} of the output array.
     *
     * <p>Row {@code r} starts at {@code r * K2 * PQ}; for each {@code k} in {@code [0, K2)} the
     * segment of length {@code PQ} starting at {@code r*K2*PQ + k*PQ} is passed together with
     * {@code bias[k]} to {@code LibMatrixMult.vectAddInPlace}.
     *
     * @param r row index in the output
     * @param out dense output values
     * @param bias bias values, one per filter
     * @param K2 number of filters
     * @param PQ length of each per-filter segment
     */
    private static void addBias(int r, double[] out, double[] bias, int K2, int PQ) {
        final int rowStart = r * K2 * PQ;
        for (int filter = 0; filter < K2; filter++) {
            LibMatrixMult.vectAddInPlace(bias[filter], out, rowStart + filter * PQ, PQ);
        }
    }

    /**
     * Eligibility check for the native sparse-dense backward-filter task. It currently always
     * returns {@code false} without inspecting {@code params}, so callers that gate on it
     * never select that task.
     *
     * @param params convolution parameters (unused)
     * @return {@code false}
     */
    private static boolean isEligibleForConv2dBackwardFilterSparseDense(DnnParameters params) {
        return false;
    }

    /**
     * Eligibility check for the native sparse forward conv2d task. It currently always
     * returns {@code false} without inspecting {@code params}, so callers that gate on it
     * never select that task.
     *
     * @param params convolution parameters (unused)
     * @return {@code false}
     */
    private static boolean isEligibleForConv2dSparse(DnnParameters params) {
        return false;
    }

    /**
     * Eligibility check for the native dense backward-data task. It currently always
     * returns {@code false} without inspecting {@code params}, so callers that gate on it
     * never select that task.
     *
     * @param params convolution parameters (unused)
     * @return {@code false}
     */
    private static boolean isEligibleForConv2dBackwardDataDense(DnnParameters params) {
        return false;
    }

    /**
     * Backward-filter task for the index range {@code [rl, ru)} that works on transposed
     * intermediates.
     *
     * <p>For each index it runs the rotate-180 worker on input2, fills a {@code PQ x CRS} im2col
     * block from input1, multiplies the {@code K x PQ} rotated block with it into a
     * {@code K x CRS} product, and sums the products into a task-local buffer. The buffer is
     * merged into the shared output once at the end via {@code inplaceAdd}.
     */
    private static class Conv2dBackwardFilterTrans
    implements Callable<Long> {
        private final int _rl;
        private final int _ru;
        private final DnnParameters _params;

        public Conv2dBackwardFilterTrans(int rl, int ru, DnnParameters params) {
            _rl = rl;
            _ru = ru;
            _params = params;
        }

        /**
         * Runs the task.
         *
         * @return always {@code 0}
         */
        @Override
        public Long call() throws Exception {
            final DnnParameters p = _params;
            final int pq = p.P * p.Q;
            final int numFilters = p.K;
            final int crs = p.C * p.R * p.S;
            final MatrixBlock dout = p.input2;
            final MatrixBlock cols = new MatrixBlock(pq, crs, p.input1.sparse).allocateBlock();
            LibMatrixDNNIm2Col.preallocateSparseOutput(p.input1, cols);
            final MatrixBlock rotated = new MatrixBlock(numFilters, pq, dout.sparse).allocateBlock();
            final MatrixBlock product = new MatrixBlock(numFilters, crs, false).allocateBlock();
            final LibMatrixDNNRotate180.Rotate180Worker rotator = LibMatrixDNNRotate180.Rotate180Worker.getWorker(dout, rotated, p, true, true);
            final double[] partial = new double[crs * p.K];
            for (int n = _rl; n < _ru; n++) {
                rotator.execute(n, 0);
                LibMatrixDNNIm2Col.im2col(p.input1, cols, n, p, true);
                product.reset(numFilters, crs, false);
                LibMatrixDNNHelper.singleThreadedMatMult(rotated, cols, product, !rotated.sparse, !cols.sparse, p);
                // An empty product contributes nothing to the sum.
                if (!product.isEmptyBlock()) {
                    LibMatrixMult.vectAdd(product.getDenseBlockValues(), partial, 0, 0, numFilters * crs);
                }
            }
            LibMatrixDNNConv2d.inplaceAdd(partial, p);
            return 0L;
        }
    }

    /**
     * Default backward-filter task for the index range {@code [rl, ru)}.
     *
     * <p>For each index it runs the rotate-180 worker on input2, fills a {@code CRS x PQ} im2col
     * block from input1, multiplies it with the {@code PQ x K} rotated block into a
     * {@code CRS x K} product, and sums the products into a task-local buffer. The buffer is
     * merged, transposed, into the shared output once at the end via {@code inplaceTransAdd}.
     */
    private static class Conv2dBackwardFilter
    implements Callable<Long> {
        private final int _rl;
        private final int _ru;
        private final DnnParameters _params;

        public Conv2dBackwardFilter(int rl, int ru, DnnParameters params) {
            _rl = rl;
            _ru = ru;
            _params = params;
        }

        /**
         * Runs the task.
         *
         * @return always {@code 0}
         */
        @Override
        public Long call() throws Exception {
            final DnnParameters p = _params;
            final int pq = p.P * p.Q;
            final int numFilters = p.K;
            final int crs = p.C * p.R * p.S;
            final MatrixBlock dout = p.input2;
            final MatrixBlock cols = new MatrixBlock(crs, pq, p.input1.sparse).allocateBlock();
            LibMatrixDNNIm2Col.preallocateSparseOutput(p.input1, cols);
            final MatrixBlock rotated = new MatrixBlock(pq, numFilters, dout.sparse);
            final MatrixBlock product = new MatrixBlock(crs, numFilters, false);
            rotated.allocateBlock();
            final LibMatrixDNNRotate180.Rotate180Worker rotator = LibMatrixDNNRotate180.Rotate180Worker.getWorker(dout, rotated, p, true, false);
            final double[] partial = new double[crs * p.K];
            for (int n = _rl; n < _ru; n++) {
                rotator.execute(n, 0);
                LibMatrixDNNIm2Col.im2col(p.input1, cols, n, p, false);
                product.reset(crs, numFilters, false);
                LibMatrixDNNHelper.singleThreadedMatMult(cols, rotated, product, !cols.sparse, !rotated.sparse, p);
                // An empty product contributes nothing to the sum.
                if (!product.isEmptyBlock()) {
                    LibMatrixMult.vectAdd(product.getDenseBlockValues(), partial, 0, 0, numFilters * crs);
                }
            }
            LibMatrixDNNConv2d.inplaceTransAdd(partial, p);
            return 0L;
        }
    }

    /**
     * Backward-filter task for the index range {@code [rl, ru)} that delegates the per-row
     * computation to {@code NativeHelper.conv2dBackwardFilterSparseDense}.
     *
     * <p>input1 is accessed through its sparse block; rows that are empty there are skipped.
     * For the remaining rows the rotate-180 worker fills a dense {@code PQ x K} block from
     * input2, which is passed to the native call together with the sparse row and a task-local
     * buffer. The buffer is merged, transposed, into the shared output once at the end via
     * {@code inplaceTransAdd}.
     */
    private static class SparseNativeConv2dBackwardFilterDense
    implements Callable<Long> {
        public final int _rl;
        public final int _ru;
        private final DnnParameters _params;

        public SparseNativeConv2dBackwardFilterDense(int rl, int ru, DnnParameters params) {
            _rl = rl;
            _ru = ru;
            _params = params;
        }

        /**
         * Runs the task.
         *
         * @return always {@code 0}
         */
        @Override
        public Long call() throws Exception {
            final DnnParameters p = _params;
            final int crs = p.C * p.R * p.S;
            final int pq = p.P * p.Q;
            final int numFilters = p.K;
            final MatrixBlock rotated = new MatrixBlock(pq, numFilters, false);
            rotated.allocateBlock();
            final LibMatrixDNNRotate180.Rotate180Worker rotator = LibMatrixDNNRotate180.Rotate180Worker.getWorker(p.input2, rotated, p, true, false);
            final double[] rotatedVals = rotated.getDenseBlockValues();
            final double[] partial = new double[crs * p.K];
            for (int n = _rl; n < _ru; n++) {
                if (!p.input1.getSparseBlock().isEmpty(n)) {
                    rotator.execute(n, 0);
                    final int start = p.input1.getSparseBlock().pos(n);
                    final int len = p.input1.getSparseBlock().size(n);
                    final int[] colIdx = p.input1.getSparseBlock().indexes(n);
                    final double[] vals = p.input1.getSparseBlock().values(n);
                    NativeHelper.conv2dBackwardFilterSparseDense(start, len, colIdx, vals, rotatedVals, partial, 1, p.C, p.H, p.W, p.K, p.R, p.S, p.stride_h, p.stride_w, p.pad_h, p.pad_w, p.P, p.Q, 1);
                }
            }
            LibMatrixDNNConv2d.inplaceTransAdd(partial, p);
            return 0L;
        }
    }

    /**
     * Backward-data task for the index range {@code [rl, ru)}.
     *
     * <p>For each index it runs the rotate-180 worker on input2 to fill a {@code PQ x K} block,
     * multiplies that block with input1 (the filter) into a {@code PQ x CRS} product, and passes
     * the product to {@code LibMatrixDNNIm2Col.col2imOverSingleImage} for that index.
     */
    private static class Conv2dBackwardData
    implements Callable<Long> {
        public final int _rl;
        public final int _ru;
        private final DnnParameters _params;

        public Conv2dBackwardData(int rl, int ru, DnnParameters params) {
            _rl = rl;
            _ru = ru;
            _params = params;
        }

        /**
         * Runs the task.
         *
         * @return the value of {@code output.recomputeNonZeros(rl, ru - 1)}
         */
        @Override
        public Long call() throws Exception {
            final DnnParameters p = _params;
            final int pq = p.P * p.Q;
            final int numFilters = p.K;
            final int crs = p.C * p.R * p.S;
            final MatrixBlock filter = p.input1;
            final MatrixBlock dout = p.input2;
            final MatrixBlock rotated = new MatrixBlock(pq, numFilters, dout.sparse);
            final MatrixBlock product = new MatrixBlock(pq, crs, false);
            rotated.allocateBlock();
            final LibMatrixDNNRotate180.Rotate180Worker rotator = LibMatrixDNNRotate180.Rotate180Worker.getWorker(dout, rotated, p, true, false);
            for (int n = _rl; n < _ru; n++) {
                rotator.execute(n, 0);
                product.reset(pq, crs, false);
                LibMatrixDNNHelper.singleThreadedMatMult(rotated, filter, product, !rotated.sparse, false, p);
                LibMatrixDNNIm2Col.col2imOverSingleImage(n, product, p);
            }
            return p.output.recomputeNonZeros(_rl, _ru - 1);
        }
    }

    /**
     * Backward-data task for the index range {@code [rl, ru)} that delegates the per-row
     * computation to {@code NativeHelper.conv2dBackwardDataDense}.
     *
     * <p>Each row of input2 is first materialised as a dense vector, the native call writes into
     * a reusable {@code C*H*W} buffer (zeroed again before every row after the first), and the
     * buffer is then copied into the output's dense values at offset {@code n * C*H*W}.
     */
    private static class SparseNativeConv2dBackwardDataDense
    implements Callable<Long> {
        public final int _rl;
        public final int _ru;
        private final DnnParameters _params;

        public SparseNativeConv2dBackwardDataDense(int rl, int ru, DnnParameters params) {
            _rl = rl;
            _ru = ru;
            _params = params;
        }

        /**
         * Runs the task.
         *
         * @return the value of {@code output.recomputeNonZeros(rl, ru - 1)}
         */
        @Override
        public Long call() throws Exception {
            final DnnParameters p = _params;
            final int chw = p.C * p.H * p.W;
            final double[] rowResult = new double[chw];
            final double[] filterVals = p.input1.getDenseBlockValues();
            final double[] denseDoutRow = new double[p.P * p.Q * p.K];
            for (int n = _rl; n < _ru; n++) {
                LibMatrixDNNConv2d.getRowInDenseFormat(p.input2, n, denseDoutRow);
                // The freshly allocated buffer is already zero for the first row.
                if (n != _rl) {
                    Arrays.fill(rowResult, 0.0);
                }
                NativeHelper.conv2dBackwardDataDense(filterVals, denseDoutRow, rowResult, 1, p.C, p.H, p.W, p.K, p.R, p.S, p.stride_h, p.stride_w, p.pad_h, p.pad_w, p.P, p.Q, 1);
                System.arraycopy(rowResult, 0, p.output.getDenseBlockValues(), n * chw, chw);
            }
            return p.output.recomputeNonZeros(_rl, _ru - 1);
        }
    }

    /**
     * Forward conv2d task for the index range {@code [rl, ru)} that delegates the per-row
     * computation to {@code NativeHelper.conv2dSparse}.
     *
     * <p>input1 is accessed through its sparse block; rows that are empty there are skipped
     * without touching the output. For the remaining rows the native call writes into a
     * reusable {@code K*P*Q} buffer that is then copied into the output's dense values at
     * offset {@code n * K*P*Q}.
     */
    private static class SparseNativeConv2d
    implements Callable<Long> {
        public final int _rl;
        public final int _ru;
        private final DnnParameters _params;

        public SparseNativeConv2d(int rl, int ru, DnnParameters params) {
            _rl = rl;
            _ru = ru;
            _params = params;
        }

        /**
         * Runs the task.
         *
         * @return the value of {@code output.recomputeNonZeros(rl, ru - 1)}
         */
        @Override
        public Long call() throws Exception {
            final DnnParameters p = _params;
            final int kpq = p.K * p.P * p.Q;
            // NOTE(review): the buffer is reused across rows without being cleared; this
            // presumes the native call overwrites it completely - confirm.
            final double[] rowResult = new double[kpq];
            for (int n = _rl; n < _ru; n++) {
                if (!p.input1.getSparseBlock().isEmpty(n)) {
                    final int start = p.input1.getSparseBlock().pos(n);
                    final int len = p.input1.getSparseBlock().size(n);
                    final int[] colIdx = p.input1.getSparseBlock().indexes(n);
                    final double[] vals = p.input1.getSparseBlock().values(n);
                    NativeHelper.conv2dSparse(start, len, colIdx, vals, p.input2.getDenseBlockValues(), rowResult, 1, p.C, p.H, p.W, p.K, p.R, p.S, p.stride_h, p.stride_w, p.pad_h, p.pad_w, p.P, p.Q, 1);
                    System.arraycopy(rowResult, 0, p.output.getDenseBlockValues(), n * kpq, kpq);
                }
            }
            return p.output.recomputeNonZeros(_rl, _ru - 1);
        }
    }

    /**
     * Forward conv2d task for the index range {@code [rl, ru)} that works on a transposed
     * im2col block.
     *
     * <p>For each index it fills a {@code PQ x CRS} im2col block from input1, multiplies it with
     * input2 into a {@code PQ x K} product, and writes the transpose of that product into the
     * output's dense values starting at {@code n * K * PQ}. If a bias is present it is added
     * afterwards.
     */
    private static class LoopedIm2ColConv2dTransAllChan
    extends LoopedIm2ColConv2dAllChan {
        public LoopedIm2ColConv2dTransAllChan(int rl, int ru, DnnParameters params) {
            super(rl, ru, params);
        }

        /**
         * Runs the task.
         *
         * @return the value of {@code output.recomputeNonZeros(rl, ru - 1)}
         */
        @Override
        public Long call() throws Exception {
            final DnnParameters p = _params;
            final int pq = p.P * p.Q;
            final int numFilters = p.K;
            final int crs = p.C * p.R * p.S;
            final MatrixBlock cols = new MatrixBlock(pq, crs, p.input1.sparse).allocateBlock();
            LibMatrixDNNIm2Col.preallocateSparseOutput(p.input1, cols);
            final MatrixBlock product = new MatrixBlock(pq, numFilters, false);
            for (int n = _rl; n < _ru; n++) {
                LibMatrixDNNIm2Col.im2col(p.input1, cols, n, p, true);
                product.reset(product.rlen, product.clen, false);
                LibMatrixDNNHelper.singleThreadedMatMult(cols, p.input2, product, false, false, p);
                LoopedIm2ColConv2dTransAllChan.partialCopyTrans(product, p.output, n * numFilters * pq, numFilters, pq);
                if (p.bias != null) {
                    LibMatrixDNNConv2d.addBias(n, p.output.getDenseBlockValues(), p.bias.getDenseBlockValues(), numFilters, pq);
                }
            }
            return p.output.recomputeNonZeros(_rl, _ru - 1);
        }

        /**
         * Writes the transpose of {@code src} ({@code PQ x K2}) into the dense values of
         * {@code dest}, starting at {@code destPos} and laid out as {@code K2 x PQ}. Nothing is
         * written when {@code src} is empty; for a sparse {@code src} only its non-zeros are
         * written.
         *
         * @param src product block to transpose
         * @param dest block whose dense values receive the result
         * @param destPos offset of the first destination cell
         * @param K2 number of columns of {@code src}
         * @param PQ number of rows of {@code src}
         */
        private static void partialCopyTrans(MatrixBlock src, MatrixBlock dest, int destPos, int K2, int PQ) {
            if (src.isEmptyBlock()) {
                return;
            }
            final double[] target = dest.getDenseBlockValues();
            if (src.isInSparseFormat()) {
                final SparseBlock rows = src.sparseBlock;
                final int numRows = src.getNumRows();
                for (int row = 0; row < numRows; row++) {
                    if (rows.isEmpty(row)) {
                        continue;
                    }
                    final int start = rows.pos(row);
                    final int end = start + rows.size(row);
                    final int[] colIdx = rows.indexes(row);
                    final double[] vals = rows.values(row);
                    // Source cell (row, col) lands at destination cell (col, row).
                    for (int pos = start; pos < end; pos++) {
                        target[destPos + row + colIdx[pos] * PQ] = vals[pos];
                    }
                }
                return;
            }
            final double[] source = src.getDenseBlockValues();
            final int tile = 128;
            for (int rowStart = 0; rowStart < PQ; rowStart += tile) {
                for (int colStart = 0; colStart < K2; colStart += tile) {
                    final int rowEnd = Math.min(rowStart + tile, PQ);
                    final int colEnd = Math.min(colStart + tile, K2);
                    for (int row = rowStart; row < rowEnd; row++) {
                        LibMatrixReorg.transposeRow(source, target, row * K2 + colStart, destPos + colStart * PQ + row, PQ, colEnd - colStart);
                    }
                }
            }
        }
    }

    /**
     * Forward conv2d task for the index range {@code [rl, ru)}.
     *
     * <p>For each index it fills a {@code CRS x PQ} im2col block from input1, multiplies input2
     * with it into a {@code K x PQ} product, copies the product into row {@code n} of the
     * output, and adds the bias if one is present.
     */
    private static class LoopedIm2ColConv2dAllChan
    implements Callable<Long> {
        protected final int _rl;
        protected final int _ru;
        protected final DnnParameters _params;

        public LoopedIm2ColConv2dAllChan(int rl, int ru, DnnParameters params) {
            this._rl = rl;
            this._ru = ru;
            this._params = params;
        }

        /**
         * Runs the task.
         *
         * @return the value of {@code output.recomputeNonZeros(rl, ru - 1)}
         */
        @Override
        public Long call() throws Exception {
            int PQ = this._params.P * this._params.Q;
            int K2 = this._params.K;
            int CRS = this._params.C * this._params.R * this._params.S;
            MatrixBlock outIm2col = new MatrixBlock(CRS, PQ, this._params.input1.sparse).allocateBlock();
            LibMatrixDNNIm2Col.preallocateSparseOutput(this._params.input1, outIm2col);
            MatrixBlock outMM = new MatrixBlock(K2, PQ, this._params.output.sparse);
            for (int n = this._rl; n < this._ru; ++n) {
                LibMatrixDNNIm2Col.im2col(this._params.input1, outIm2col, n, this._params, false);
                outMM.reset(outMM.rlen, outMM.clen, this._params.output.sparse);
                LibMatrixDNNHelper.singleThreadedMatMult(this._params.input2, outIm2col, outMM, false, true, this._params);
                LoopedIm2ColConv2dAllChan.partialCopy1(outMM, this._params.output, n, K2, PQ);
                if (this._params.bias == null) continue;
                LibMatrixDNNConv2d.addBias(n, this._params.output.getDenseBlockValues(), this._params.bias.getDenseBlockValues(), K2, PQ);
            }
            return this._params.output.recomputeNonZeros(this._rl, this._ru - 1);
        }

        /**
         * Copies the {@code K2 x PQ} block {@code src} into row {@code r} of {@code dest}, where
         * a dense destination row occupies the cells {@code [r*K2*PQ, (r+1)*K2*PQ)}. Nothing is
         * written when {@code src} is empty.
         *
         * @param src product block for one index
         * @param dest output block
         * @param r destination row index
         * @param K2 number of rows of {@code src}
         * @param PQ number of columns of {@code src}
         */
        private static void partialCopy1(MatrixBlock src, MatrixBlock dest, int r, int K2, int PQ) {
            if (src.isEmptyBlock()) {
                return;
            }
            if (src.sparse) {
                SparseBlock srcBlock = src.sparseBlock;
                SparseBlock sdestBlock = dest.sparseBlock;
                double[] ddestBlock = dest.getDenseBlockValues();
                // First cell of row r in the dense destination; the same base offset is used
                // by the dense-to-dense copy below.
                int destRowStart = r * K2 * PQ;
                for (int k = 0; k < src.getNumRows(); ++k) {
                    if (srcBlock.isEmpty(k)) continue;
                    int apos = srcBlock.pos(k);
                    int alen = srcBlock.size(k);
                    int[] aix = srcBlock.indexes(k);
                    double[] avals = srcBlock.values(k);
                    if (dest.sparse) {
                        // NOTE(review): every source row k targets the same column range
                        // [0, K2*PQ) of row r, and aix holds indexes relative to source row k.
                        // This looks like it should address [k*PQ, (k+1)*PQ) instead - confirm
                        // against the SparseBlock.setIndexRange contract before changing.
                        sdestBlock.setIndexRange(r, 0, K2 * PQ, avals, aix, apos, alen);
                        continue;
                    }
                    // Source row k maps to the k-th PQ-sized segment of destination row r.
                    int desPosK = destRowStart + k * PQ;
                    for (int j = apos; j < apos + alen; ++j) {
                        ddestBlock[desPosK + aix[j]] = avals[j];
                    }
                }
            } else if (dest.sparse) {
                dest.getSparseBlock().setIndexRange(r, 0, K2 * PQ, src.getDenseBlockValues(), 0, K2 * PQ);
            } else {
                System.arraycopy(src.getDenseBlockValues(), 0, dest.getDenseBlockValues(), r * K2 * PQ, K2 * PQ);
            }
        }
    }
}

