/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.privacy.propagation;

import org.apache.sysds.lops.MMTSJ;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.cp.AggregateBinaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.MMChainCPInstruction;
import org.apache.sysds.runtime.instructions.cp.MMTSJCPInstruction;
import org.apache.sysds.runtime.meta.DataCharacteristics;

/**
 * Operator classification used by privacy propagation for matrix-multiplication style
 * instructions.
 *
 * <p>Each {@code getAggregationType} overload looks up the {@link DataCharacteristics} of the
 * instruction's first input. It reports {@link #NonAggregate} when the dimension(s) relevant to
 * that instruction equal exactly 1, and {@link #Aggregate} otherwise. Any other dimension value,
 * including whatever marker {@link DataCharacteristics} may use for unknown sizes, therefore
 * yields {@link #Aggregate}.
 */
public enum OperatorType {
    Aggregate,
    NonAggregate;

    /**
     * Classifies an MMChain instruction.
     *
     * @param inst the MMChain instruction; only its first input is inspected
     * @param ec   execution context used to look up that input's data characteristics
     * @return {@link #NonAggregate} if the first input has exactly one row and exactly one
     *         column, otherwise {@link #Aggregate}
     */
    public static OperatorType getAggregationType(MMChainCPInstruction inst, ExecutionContext ec) {
        DataCharacteristics firstInput = ec.getDataCharacteristics(inst.getInputs()[0].getName());
        boolean isOneByOne = firstInput.getRows() == 1L && firstInput.getCols() == 1L;
        return nonAggregateWhen(isOneByOne);
    }

    /**
     * Classifies an MMTSJ instruction.
     *
     * @param inst the MMTSJ instruction; its first input and its MMTSJ type are inspected
     * @param ec   execution context used to look up the first input's data characteristics
     * @return for type {@code LEFT}: {@link #NonAggregate} iff the first input has exactly one
     *         row; for every other type: {@link #NonAggregate} iff it has exactly one column;
     *         {@link #Aggregate} in all remaining cases
     */
    public static OperatorType getAggregationType(MMTSJCPInstruction inst, ExecutionContext ec) {
        DataCharacteristics firstInput = ec.getDataCharacteristics(inst.getInputs()[0].getName());
        // Only one dimension matters per type: rows for LEFT, columns for anything else.
        // NOTE(review): presumably this is the dimension summed over by the self-product
        // (t(X)%*%X vs. X%*%t(X)) -- confirm against the MMTSJ lop definition.
        long relevantExtent = (inst.getMMTSJType() == MMTSJ.MMTSJType.LEFT)
                ? firstInput.getRows()
                : firstInput.getCols();
        return nonAggregateWhen(relevantExtent == 1L);
    }

    /**
     * Classifies an aggregate-binary instruction.
     *
     * @param inst the aggregate-binary instruction; its {@code input1} and its
     *             {@code transposeLeft} flag are inspected
     * @param ec   execution context used to look up {@code input1}'s data characteristics
     * @return when {@code transposeLeft} is set: {@link #NonAggregate} iff {@code input1} has
     *         exactly one row; otherwise {@link #NonAggregate} iff it has exactly one column;
     *         {@link #Aggregate} in all remaining cases
     */
    public static OperatorType getAggregationType(AggregateBinaryCPInstruction inst, ExecutionContext ec) {
        DataCharacteristics leftInput = ec.getDataCharacteristics(inst.input1.getName());
        // NOTE(review): the checked extent looks like the inner (shared) dimension of the
        // left operand, which swaps from columns to rows when it is transposed -- confirm.
        long relevantExtent = inst.transposeLeft ? leftInput.getRows() : leftInput.getCols();
        return nonAggregateWhen(relevantExtent == 1L);
    }

    /**
     * Maps a boolean test to an operator type.
     *
     * @param condition result of the per-instruction dimension check
     * @return {@link #NonAggregate} if {@code condition} holds, otherwise {@link #Aggregate}
     */
    private static OperatorType nonAggregateWhen(boolean condition) {
        return condition ? NonAggregate : Aggregate;
    }
}

