/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.sql.expression.aggregation;

import java.time.Instant;
import java.time.LocalTime;
import java.time.temporal.ChronoUnit;
import java.util.List;
import java.util.Locale;
import org.opensearch.sql.data.model.ExprDateValue;
import org.opensearch.sql.data.model.ExprDatetimeValue;
import org.opensearch.sql.data.model.ExprDoubleValue;
import org.opensearch.sql.data.model.ExprIntegerValue;
import org.opensearch.sql.data.model.ExprNullValue;
import org.opensearch.sql.data.model.ExprTimeValue;
import org.opensearch.sql.data.model.ExprTimestampValue;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.data.type.ExprCoreType;
import org.opensearch.sql.expression.DSL;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.aggregation.AggregationState;
import org.opensearch.sql.expression.aggregation.Aggregator;
import org.opensearch.sql.expression.function.BuiltinFunctionName;
import org.opensearch.sql.utils.ExpressionUtils;

/**
 * Aggregator for the {@code AVG} aggregation function.
 *
 * <p>The accumulation strategy is chosen from the declared return type (see {@link #create()}).
 * For {@code DOUBLE}, values are summed with {@code DSL.add} and divided by the row count. For
 * {@code DATE}, {@code DATETIME} and {@code TIMESTAMP}, each value is converted to epoch
 * milliseconds. For {@code TIME}, each value is converted to milliseconds since midnight. Those
 * millisecond values are averaged and the average is converted back to the return type. A state
 * that has seen no values yields {@code NULL}.
 */
public class AvgAggregator
extends Aggregator<AvgState> {
    /** Declared return type; also selects which {@link AvgState} {@link #create()} builds. */
    private final ExprCoreType dataType;

    /**
     * Creates an {@code AVG} aggregator.
     *
     * @param arguments the argument expression(s) passed to the {@code Aggregator} base class
     * @param returnType the result type, which also determines how values are accumulated
     */
    public AvgAggregator(List<Expression> arguments, ExprCoreType returnType) {
        super(BuiltinFunctionName.AVG.getName(), arguments, returnType);
        this.dataType = returnType;
    }

    /**
     * Creates a fresh, empty accumulation state matching the return type.
     *
     * @return a new state with a zero count and zero total
     * @throws IllegalArgumentException if {@code AVG} is not supported for the return type
     */
    @Override
    public AvgState create() {
        switch (this.dataType) {
            case DATE:
                return new DateAvgState();
            case DATETIME:
                return new DateTimeAvgState();
            case TIMESTAMP:
                return new TimestampAvgState();
            case TIME:
                return new TimeAvgState();
            case DOUBLE:
                return new DoubleAvgState();
            default:
                throw new IllegalArgumentException(String.format(Locale.ROOT,
                        "avg aggregation over %s type is not supported", this.dataType));
        }
    }

    /**
     * Folds one value into {@code state} by delegating to the state itself.
     *
     * @param value the value to accumulate
     * @param state the state to update
     * @return the (mutated) state
     */
    @Override
    protected AvgState iterate(ExprValue value, AvgState state) {
        return state.iterate(value);
    }

    @Override
    public String toString() {
        return String.format(Locale.ROOT, "avg(%s)", ExpressionUtils.format(this.getArguments()));
    }

    /** Averages {@code TIME} values as milliseconds elapsed since {@link LocalTime#MIN} (midnight). */
    protected static class TimeAvgState
    extends AvgState {
        protected TimeAvgState() {
        }

        @Override
        public ExprValue result() {
            if (this.isEmpty()) {
                return ExprNullValue.of();
            }
            return new ExprTimeValue(LocalTime.MIN.plus(this.averageMillis(), ChronoUnit.MILLIS));
        }

        @Override
        protected AvgState iterate(ExprValue value) {
            // ChronoUnit.MILLIS.between counts whole milliseconds, so sub-millisecond precision is dropped.
            this.total = DSL.add(DSL.literal(this.total),
                    DSL.literal(ChronoUnit.MILLIS.between(LocalTime.MIN, value.timeValue()))).valueOf();
            return super.iterate(value);
        }
    }

    /** Averages {@code TIMESTAMP} values via their epoch-millisecond representation. */
    protected static class TimestampAvgState
    extends AvgState {
        protected TimestampAvgState() {
        }

        @Override
        public ExprValue result() {
            if (this.isEmpty()) {
                return ExprNullValue.of();
            }
            return new ExprTimestampValue(Instant.ofEpochMilli(this.averageMillis()));
        }

        @Override
        protected AvgState iterate(ExprValue value) {
            this.total = DSL.add(DSL.literal(this.total),
                    DSL.literal(value.timestampValue().toEpochMilli())).valueOf();
            return super.iterate(value);
        }
    }

    /**
     * Averages {@code DATETIME} values via {@code timestampValue()} epoch milliseconds, converting
     * the average back through an {@link ExprTimestampValue}.
     */
    protected static class DateTimeAvgState
    extends AvgState {
        protected DateTimeAvgState() {
        }

        @Override
        public ExprValue result() {
            if (this.isEmpty()) {
                return ExprNullValue.of();
            }
            ExprTimestampValue average = new ExprTimestampValue(Instant.ofEpochMilli(this.averageMillis()));
            return new ExprDatetimeValue(average.datetimeValue());
        }

        @Override
        protected AvgState iterate(ExprValue value) {
            this.total = DSL.add(DSL.literal(this.total),
                    DSL.literal(value.timestampValue().toEpochMilli())).valueOf();
            return super.iterate(value);
        }
    }

    /**
     * Averages {@code DATE} values via {@code timestampValue()} epoch milliseconds, converting the
     * average back through an {@link ExprTimestampValue} and keeping only its date part.
     */
    protected static class DateAvgState
    extends AvgState {
        protected DateAvgState() {
        }

        @Override
        public ExprValue result() {
            if (this.isEmpty()) {
                return ExprNullValue.of();
            }
            ExprTimestampValue average = new ExprTimestampValue(Instant.ofEpochMilli(this.averageMillis()));
            return new ExprDateValue(average.dateValue());
        }

        @Override
        protected AvgState iterate(ExprValue value) {
            this.total = DSL.add(DSL.literal(this.total),
                    DSL.literal(value.timestampValue().toEpochMilli())).valueOf();
            return super.iterate(value);
        }
    }

    /** Averages values by summing them directly with {@code DSL.add} and dividing by the count. */
    protected static class DoubleAvgState
    extends AvgState {
        protected DoubleAvgState() {
        }

        @Override
        public ExprValue result() {
            if (this.isEmpty()) {
                return ExprNullValue.of();
            }
            return DSL.divide(DSL.literal(this.total), DSL.literal(this.count)).valueOf();
        }

        @Override
        protected AvgState iterate(ExprValue value) {
            this.total = DSL.add(DSL.literal(this.total), DSL.literal(value)).valueOf();
            return super.iterate(value);
        }
    }

    /**
     * Running state shared by all {@code AVG} variants: a row {@code count} and a running
     * {@code total}. Subclasses add each value's contribution to {@code total} and then call
     * {@link #iterate(ExprValue)} here to increment the count.
     */
    protected static abstract class AvgState
    implements AggregationState {
        // NOTE(review): count is held as an INTEGER value; whether very large groups overflow it
        // depends on how DSL.add treats integer overflow -- confirm whether a LONG is warranted.
        protected ExprValue count = new ExprIntegerValue(0);
        // Starts as a DOUBLE. If DSL.add keeps the sum in DOUBLE, large sums of epoch milliseconds
        // presumably lose millisecond precision once they exceed 2^53 -- verify if exactness matters.
        protected ExprValue total = new ExprDoubleValue(0.0);

        AvgState() {
        }

        @Override
        public abstract ExprValue result();

        /**
         * Records that one more value has been accumulated.
         *
         * @param value the accumulated value (not read here; subclasses update {@code total})
         * @return this state
         */
        protected AvgState iterate(ExprValue value) {
            this.count = DSL.add(DSL.literal(this.count), DSL.literal(1)).valueOf();
            return this;
        }

        /** Returns {@code true} when no value has been accumulated yet. */
        protected boolean isEmpty() {
            return 0 == this.count.integerValue();
        }

        /**
         * Returns {@code total / count} converted with {@code ExprValue.longValue()}; used by the
         * temporal states, whose totals are sums of millisecond values.
         */
        protected long averageMillis() {
            return DSL.divide(DSL.literal(this.total), DSL.literal(this.count)).valueOf().longValue();
        }
    }
}

