/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.spark.functions;

import org.apache.spark.api.java.function.Function;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.functionobjects.CM;
import org.apache.sysds.runtime.functionobjects.KahanPlus;
import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.KahanObject;
import org.apache.sysds.runtime.matrix.data.WeightedCell;
import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
import org.apache.sysds.runtime.matrix.operators.CMOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;

public class PerformGroupByAggInReducer
implements Function<Iterable<WeightedCell>, WeightedCell> {
    private static final long serialVersionUID = 8160556441153227417L;
    Operator op;

    public PerformGroupByAggInReducer(Operator op) {
        this.op = op;
    }

    public WeightedCell call(Iterable<WeightedCell> kv) throws Exception {
        WeightedCell outCell = new WeightedCell();
        CM_COV_Object cmObj = new CM_COV_Object();
        if (this.op instanceof CMOperator) {
            cmObj.reset();
            CM lcmFn = CM.getCMFnObject(((CMOperator)this.op).aggOpType);
            if (((CMOperator)this.op).isPartialAggregateOperator()) {
                throw new DMLRuntimeException("Incorrect usage, should have used PerformGroupByAggInCombiner");
            }
            for (WeightedCell value : kv) {
                lcmFn.execute(cmObj, value.getValue(), value.getWeight());
            }
            outCell.setValue(cmObj.getRequiredResult(this.op));
            outCell.setWeight(1.0);
        } else if (this.op instanceof AggregateOperator) {
            AggregateOperator aggop = (AggregateOperator)this.op;
            if (aggop.existsCorrection()) {
                KahanObject buffer = new KahanObject(aggop.initialValue, 0.0);
                KahanPlus.getKahanPlusFnObject();
                for (WeightedCell value : kv) {
                    aggop.increOp.fn.execute((Data)buffer, value.getValue() * value.getWeight());
                }
                outCell.setValue(buffer._sum);
                outCell.setWeight(1.0);
            } else {
                double v = aggop.initialValue;
                for (WeightedCell value : kv) {
                    v = aggop.increOp.fn.execute(v, value.getValue() * value.getWeight());
                }
                outCell.setValue(v);
                outCell.setWeight(1.0);
            }
        } else {
            throw new DMLRuntimeException("Unsupported operator in grouped aggregate instruction:" + this.op);
        }
        return outCell;
    }
}

