/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.fed;

import java.util.Arrays;
import java.util.concurrent.Future;
import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.instructions.fed.UnaryFEDInstruction;
import org.apache.sysds.runtime.matrix.data.LibCommonsMath;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.UnaryOperator;

/**
 * Federated instruction for unary matrix operations such as {@code exp} or {@code log}, and for
 * the cumulative aggregates {@code ucumk+}, {@code ucum*}, {@code ucummin}, {@code ucummax} and
 * {@code ucumk+*}.
 *
 * <p>Most operations are forwarded unchanged to every federated worker. Cumulative aggregates on a
 * row-partitioned input depend on all preceding partitions, so they run in several rounds:
 * <ol>
 * <li>the coordinator collects per-partition partial results;</li>
 * <li>it derives a carry-in value for the first row of every partition after the first;</li>
 * <li>it merges that value into the worker-local data;</li>
 * <li>each worker then runs the cumulative operation locally.</li>
 * </ol>
 */
public class UnaryMatrixFEDInstruction
extends UnaryFEDInstruction {
    protected UnaryMatrixFEDInstruction(Operator op, CPOperand in, CPOperand out, String opcode, String instr) {
        super(FEDInstruction.FEDType.Unary, op, in, out, opcode, instr);
    }

    /**
     * Checks whether this instruction type handles the given opcode.
     *
     * @param opcode the instruction opcode
     * @return {@code true} if the opcode is NOT a unary operation supported by {@link LibCommonsMath}
     */
    public static boolean isValidOpcode(String opcode) {
        return !LibCommonsMath.isSupportedUnaryOperation(opcode);
    }

    /**
     * Parses a federated unary matrix instruction string.
     *
     * <p>{@code exp}, {@code log} and the {@code ucum*} opcodes, when given with five parts, carry
     * two extra fields. {@code parts[3]} (parsed as int) and {@code parts[4]} (parsed as boolean)
     * are passed to the {@link UnaryOperator}; they are presumably the thread count and the
     * in-place flag (TODO confirm). All other instructions are parsed via
     * {@code parseUnaryInstruction}.
     *
     * @param str the instruction string
     * @return the parsed instruction
     */
    public static UnaryMatrixFEDInstruction parseInstruction(String str) {
        CPOperand in = new CPOperand("", Types.ValueType.UNKNOWN, Types.DataType.UNKNOWN);
        CPOperand out = new CPOperand("", Types.ValueType.UNKNOWN, Types.DataType.UNKNOWN);
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        if (parts.length == 5 && (opcode.equalsIgnoreCase("exp") || opcode.equalsIgnoreCase("log") || opcode.startsWith("ucum"))) {
            in.split(parts[1]);
            out.split(parts[2]);
            Builtin func = Builtin.getBuiltinFnObject(opcode);
            // NOTE(review): "sigmoid" can never reach this point because of the enclosing condition.
            if (Arrays.asList("ucumk+", "ucum*", "ucumk+*", "ucummin", "ucummax", "exp", "log", "sigmoid").contains(opcode)) {
                UnaryOperator op = new UnaryOperator(func, Integer.parseInt(parts[3]), Boolean.parseBoolean(parts[4]));
                return new UnaryMatrixFEDInstruction(op, in, out, opcode, str);
            }
            return new UnaryMatrixFEDInstruction(null, in, out, opcode, str);
        }
        opcode = UnaryMatrixFEDInstruction.parseUnaryInstruction(str, in, out);
        return new UnaryMatrixFEDInstruction(InstructionUtils.parseUnaryOperator(opcode), in, out, opcode, str);
    }

    /**
     * Executes the instruction on the federated input.
     *
     * <p>Cumulative opcodes on a row-partitioned input take the multi-round path in
     * {@link #processCumulativeInstruction}. Everything else is sent unchanged to all workers,
     * and the output inherits the input's federation mapping under the new variable ID.
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        MatrixObject mo1 = ec.getMatrixObject(this.input1);
        if (this.getOpcode().startsWith("ucum") && mo1.isFederated(FederationMap.FType.ROW)) {
            this.processCumulativeInstruction(ec, mo1);
        } else {
            FederatedRequest fr1 = FederationUtils.callInstruction(this.instString, this.output, new CPOperand[]{this.input1}, new long[]{mo1.getFedMapping().getID()});
            mo1.getFedMapping().execute(this.getTID(), true, fr1);
            this.setOutputFedMapping(ec, mo1, fr1.getID());
        }
    }

    /**
     * Computes a cumulative aggregate over a row-partitioned federated matrix.
     *
     * <p>{@code ucumk+*} uses a dedicated carry computation ({@link #getScalars}). The other
     * cumulative opcodes collect column aggregates ({@code uac*}) per partition, cumulate them on
     * the coordinator, and fold the result into the first row of every later partition with the
     * matching binary operation. In both cases the final local cumulative run happens in
     * {@link #processCumulative}.
     *
     * @param ec  the execution context
     * @param mo1 the row-partitioned federated input
     */
    public void processCumulativeInstruction(ExecutionContext ec, MatrixObject mo1) {
        MatrixObject out;
        String opcode = this.getOpcode();
        if (opcode.equalsIgnoreCase("ucumk+*")) {
            // run ucumk+* locally on every partition and fetch the partial results
            FederatedRequest fr1 = FederationUtils.callInstruction(this.instString, this.output, new CPOperand[]{this.input1}, new long[]{mo1.getFedMapping().getID()});
            FederatedRequest fr2 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr1.getID());
            Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(this.getTID(), true, fr1, fr2);
            out = this.setOutputFedMapping(ec, mo1, fr1.getID());
            MatrixBlock scalingValues = this.getScalars(mo1, tmp);
            this.setScalingValues(ec, mo1, out, scalingValues);
        } else {
            // e.g. ucumk+ -> uack+ (column aggregate) and "+" (binary op used to merge the carry)
            String colAgg = opcode.replace("ucum", "uac");
            String agg2 = opcode.replace(opcode.contains("ucumk") ? "ucumk" : "ucum", "");
            double init = getNeutralElement(opcode);
            Future<FederatedResponse>[] tmp = this.modifyAndGetInstruction(colAgg, mo1);
            MatrixBlock scalingValues = UnaryMatrixFEDInstruction.getResultBlock(tmp, (int)mo1.getNumColumns(), opcode, init);
            out = ec.getMatrixObject(this.output);
            // cumulative aggregates preserve the input shape; set it explicitly, as the
            // ucumk+* branch does through setOutputFedMapping
            out.getDataCharacteristics().set(mo1.getNumRows(), mo1.getNumColumns(), (int)mo1.getBlocksize());
            this.setScalingValues(agg2, ec, mo1, out, scalingValues, init);
        }
        this.processCumulative(out);
    }

    /**
     * Returns the identity element of the binary operation behind a cumulative opcode.
     *
     * <p>Every element of the input is combined with this value, so it must be a true identity.
     * {@code ±Double.MAX_VALUE} would clamp infinite inputs for min/max.
     */
    private static double getNeutralElement(String opcode) {
        if (opcode.equalsIgnoreCase("ucumk+")) {
            return 0.0;
        }
        if (opcode.equalsIgnoreCase("ucum*")) {
            return 1.0;
        }
        if (opcode.equalsIgnoreCase("ucummin")) {
            return Double.POSITIVE_INFINITY;
        }
        return Double.NEGATIVE_INFINITY;
    }

    /**
     * Runs this instruction with its opcode replaced by {@code newInst} on every partition and
     * fetches the per-partition results.
     *
     * <p>NOTE(review): the intermediate worker variable created here is not removed in this class.
     *
     * @return one future per federated partition, each holding the partition's result block
     */
    private Future<FederatedResponse>[] modifyAndGetInstruction(String newInst, MatrixObject mo1) {
        String modifiedInstString = InstructionUtils.replaceOperand(this.instString, 1, newInst);
        FederatedRequest fr1 = FederationUtils.callInstruction(modifiedInstString, this.output, new CPOperand[]{this.input1}, new long[]{mo1.getFedMapping().getID()});
        FederatedRequest fr2 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr1.getID());
        return mo1.getFedMapping().execute(this.getTID(), true, fr1, fr2);
    }

    /**
     * Runs the original cumulative instruction on the already carry-adjusted output, worker-locally.
     *
     * <p>It then fixes the output metadata: {@code ucumk+*} produces a single column, so both the
     * data characteristics and every federated range are narrowed to one column.
     */
    private void processCumulative(MatrixObject out) {
        // read from the output variable (the carry-adjusted data) instead of the original input
        String modifiedInstString = InstructionUtils.replaceOperand(this.instString, 2, InstructionUtils.createOperand(this.output));
        FederatedRequest fr4 = FederationUtils.callInstruction(modifiedInstString, this.output, out.getFedMapping().getID(), new CPOperand[]{this.output}, new long[]{out.getFedMapping().getID()});
        out.getFedMapping().execute(this.getTID(), true, fr4);
        out.setFedMapping(out.getFedMapping().copyWithNewID(fr4.getID()));
        if (this.getOpcode().equalsIgnoreCase("ucumk+*")) {
            out.getDataCharacteristics().set(out.getNumRows(), 1L, (int)out.getBlocksize());
            for (int i = 0; i < out.getFedMapping().getFederatedRanges().length; ++i) {
                out.getFedMapping().getFederatedRanges()[i].setEndDim(1, 1L);
            }
        } else {
            out.getDataCharacteristics().set(out.getNumRows(), out.getNumColumns(), (int)out.getBlocksize());
        }
    }

    /**
     * Builds the carry-in matrix for the non-{@code ucumk+*} cumulative opcodes.
     *
     * <p>Row 0 holds {@code init}. Row {@code i+1} holds the column aggregate of partition
     * {@code i}; the last partition's aggregate is not needed. Applying the cumulative operation
     * {@code opcode} makes row {@code i} the aggregate of all partitions before partition
     * {@code i}.
     */
    private static MatrixBlock getResultBlock(Future<FederatedResponse>[] tmp, int cols, String opcode, double init) {
        MatrixBlock res = new MatrixBlock(tmp.length, cols, init);
        for (int i = 0; i < tmp.length - 1; ++i) {
            try {
                res.copy(i + 1, i + 1, 0, cols - 1, (MatrixBlock)tmp[i].get().getData()[0], true);
            }
            catch (Exception e) {
                throw new DMLRuntimeException("Federated Get data failed with exception on UnaryMatrixFEDInstruction", e);
            }
        }
        return res.unaryOperations(new UnaryOperator(Builtin.getBuiltinFnObject(opcode)), new MatrixBlock());
    }

    /**
     * Computes, for {@code ucumk+*}, the corrected value of column 0 in the first row of each partition.
     *
     * <p>Each row {@code i} of {@code prod} holds the last local {@code ucumk+*} result of
     * partition {@code i} (column 0) and the product of its column-1 values (column 1). Applying
     * {@code ucumk+*} to {@code prod} yields the global result at the end of every partition.
     * Shifted down by one, that is the carry {@code c} into each partition. The returned row
     * {@code i} is {@code first[i,0] + first[i,1] * c[i]}.
     */
    private MatrixBlock getScalars(MatrixObject mo1, Future<FederatedResponse>[] tmp) {
        MatrixBlock[] aggRes = this.getAggMatrices(mo1);
        MatrixBlock prod = aggRes[0];
        MatrixBlock firstValues = aggRes[1];
        for (int i = 0; i < tmp.length; ++i) {
            try {
                MatrixBlock curr = (MatrixBlock)tmp[i].get().getData()[0];
                prod.setValue(i, 0, curr.getValue(curr.getNumRows() - 1, 0));
            }
            catch (Exception e) {
                throw new DMLRuntimeException("Federated Get data failed with exception on UnaryMatrixFEDInstruction", e);
            }
        }
        // carry into each partition; the first partition has no predecessor, so its carry is 0
        MatrixBlock a = new MatrixBlock(tmp.length, 1, 0.0);
        if (tmp.length > 1) {
            // guarded: with a single partition the slice/copy ranges below would be empty/invalid
            MatrixBlock carried = prod.unaryOperations(new UnaryOperator(Builtin.getBuiltinFnObject("ucumk+*")), new MatrixBlock());
            a.copy(1, a.getNumRows() - 1, 0, 0, carried.slice(0, prod.getNumRows() - 2), true);
        }
        MatrixBlock B = firstValues.slice(0, firstValues.getNumRows() - 1, 1, 1).binaryOperations(InstructionUtils.parseBinaryOperator("*"), a, new MatrixBlock());
        return B.binaryOperationsInPlace(InstructionUtils.parseBinaryOperator("+"), firstValues.slice(0, firstValues.getNumRows() - 1, 0, 0));
    }

    /**
     * Runs {@code ucum*} on every partition and collects two {@code k x 2} blocks.
     *
     * <p>{@code prod} has column 1 filled with the last-row value of column 1, i.e. the product
     * of that partition's column-1 values; column 0 is filled by the caller. {@code firstValues}
     * holds the first row of each partition.
     *
     * @return {@code {prod, firstValues}}
     */
    private MatrixBlock[] getAggMatrices(MatrixObject mo1) {
        Future<FederatedResponse>[] tmp = this.modifyAndGetInstruction("ucum*", mo1);
        MatrixBlock prod = new MatrixBlock(tmp.length, 2, 0.0);
        MatrixBlock firstValues = new MatrixBlock(tmp.length, 2, 0.0);
        for (int i = 0; i < tmp.length; ++i) {
            try {
                MatrixBlock curr = (MatrixBlock)tmp[i].get().getData()[0];
                prod.setValue(i, 1, curr.getValue(curr.getNumRows() - 1, 1));
                firstValues.copy(i, i, 0, 1, curr.slice(0, 0), true);
            }
            catch (Exception e) {
                throw new DMLRuntimeException("Federated Get data failed with exception on UnaryMatrixFEDInstruction", e);
            }
        }
        return new MatrixBlock[]{prod, firstValues};
    }

    /**
     * Replaces, for {@code ucumk+*}, column 0 of the first row of every partition after the first
     * with the corrected values.
     *
     * <p>A condition matrix (all ones, zero at the replaced cells) and a value matrix are
     * broadcast sliced to the workers. They are combined with the input through the ternary
     * instruction built by {@code constructTernaryString} (presumably
     * {@code ifelse(cond, input, values)} — TODO confirm). The output is remapped to the result.
     */
    private void setScalingValues(ExecutionContext ec, MatrixObject mo1, MatrixObject out, MatrixBlock scalingValues) {
        MatrixBlock condition = new MatrixBlock((int)mo1.getNumRows(), (int)mo1.getNumColumns(), 1.0);
        MatrixBlock mb2 = new MatrixBlock((int)mo1.getNumRows(), (int)mo1.getNumColumns(), 0.0);
        for (int i = 0; i < scalingValues.getNumRows() - 1; ++i) {
            int step = (int)mo1.getFedMapping().getFederatedRanges()[i + 1].getBeginDims()[0];
            condition.setValue(step, 0, 0.0);
            mb2.setValue(step, 0, scalingValues.getValue(i + 1, 0));
        }
        MatrixObject cond = ExecutionContext.createMatrixObject(condition);
        long condID = FederationUtils.getNextFedDataID();
        ec.setVariable(String.valueOf(condID), cond);
        MatrixObject mo2 = ExecutionContext.createMatrixObject(mb2);
        long varID2 = FederationUtils.getNextFedDataID();
        ec.setVariable(String.valueOf(varID2), mo2);
        CPOperand opCond = new CPOperand(String.valueOf(condID), Types.ValueType.FP64, Types.DataType.MATRIX);
        CPOperand op2 = new CPOperand(String.valueOf(varID2), Types.ValueType.FP64, Types.DataType.MATRIX);
        String ternaryInstString = InstructionUtils.constructTernaryString(this.instString, opCond, this.input1, op2, this.output);
        FederatedRequest[] fr1 = mo1.getFedMapping().broadcastSliced(cond, false);
        FederatedRequest[] fr2 = mo1.getFedMapping().broadcastSliced(mo2, false);
        FederatedRequest fr3 = FederationUtils.callInstruction(ternaryInstString, this.output, new CPOperand[]{this.input1, opCond, op2}, new long[]{mo1.getFedMapping().getID(), fr1[0].getID(), fr2[0].getID()});
        mo1.getFedMapping().execute(this.getTID(), true, fr1, fr2, new FederatedRequest[]{fr3});
        out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr3.getID()));
        // the temporary coordinator variables are only needed for the broadcast above
        ec.removeVariable(opCond.getName());
        ec.removeVariable(op2.getName());
    }

    /**
     * Merges the carry-in values into the input for the non-{@code ucumk+*} cumulative opcodes.
     *
     * <p>A matrix filled with {@code init} is built on the coordinator. The first row of every
     * partition after the first is replaced by the matching row of {@code scalingValues}. The
     * matrix is broadcast sliced and combined element-wise with the input using the binary
     * opcode {@code opcode}; since {@code init} is the operation's identity, only the first rows
     * change.
     */
    private void setScalingValues(String opcode, ExecutionContext ec, MatrixObject mo1, MatrixObject out, MatrixBlock scalingValues, double init) {
        MatrixBlock mb2 = new MatrixBlock((int)mo1.getNumRows(), (int)mo1.getNumColumns(), init);
        for (int i = 1; i < scalingValues.getNumRows(); ++i) {
            int step = (int)mo1.getFedMapping().getFederatedRanges()[i].getBeginDims()[0];
            mb2.copy(step, step, 0, (int)(mo1.getNumColumns() - 1L), scalingValues.slice(i, i), true);
        }
        MatrixObject mo2 = ExecutionContext.createMatrixObject(mb2);
        long varID2 = FederationUtils.getNextFedDataID();
        ec.setVariable(String.valueOf(varID2), mo2);
        CPOperand op2 = new CPOperand(String.valueOf(varID2), Types.ValueType.FP64, Types.DataType.MATRIX);
        String modifiedInstString = InstructionUtils.constructBinaryInstString(this.instString, opcode, this.input1, op2, this.output);
        FederatedRequest[] fr1 = mo1.getFedMapping().broadcastSliced(mo2, false);
        FederatedRequest fr2 = FederationUtils.callInstruction(modifiedInstString, this.output, new CPOperand[]{this.input1, op2}, new long[]{mo1.getFedMapping().getID(), fr1[0].getID()});
        mo1.getFedMapping().execute(this.getTID(), true, fr1, new FederatedRequest[]{fr2});
        out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr2.getID()));
        ec.removeVariable(op2.getName());
    }

    /**
     * Copies the data characteristics of {@code fedMapObj} to the output variable and gives the
     * output a copy of its federation mapping under {@code fedOutputID}.
     *
     * @return the output matrix object
     */
    private MatrixObject setOutputFedMapping(ExecutionContext ec, MatrixObject fedMapObj, long fedOutputID) {
        MatrixObject out = ec.getMatrixObject(this.output);
        out.getDataCharacteristics().set(fedMapObj.getDataCharacteristics());
        out.setFedMapping(fedMapObj.getFedMapping().copyWithNewID(fedOutputID));
        return out;
    }
}

