/*
 * Decompiled with CFR 0.152.
 */
package com.amazon.randomcutforest.parkservices;

import com.amazon.randomcutforest.CommonUtils;
import com.amazon.randomcutforest.parkservices.ForecastDescriptor;
import com.amazon.randomcutforest.parkservices.RCFCaster;
import com.amazon.randomcutforest.parkservices.calibration.Calibration;
import com.amazon.randomcutforest.returntypes.DiVector;
import com.amazon.randomcutforest.returntypes.RangeVector;
import java.util.Arrays;
import java.util.function.BiFunction;
import lombok.Generated;

public class ErrorHandler {
    public static int MAX_ERROR_HORIZON = 1024;
    int sequenceIndex;
    double percentile;
    int forecastHorizon;
    int errorHorizon;
    protected RangeVector[] pastForecasts;
    protected float[][] actuals;
    RangeVector errorDistribution;
    DiVector errorRMSE;
    float[] errorMean;
    float[] intervalPrecision;
    float[] lastDeviations;
    RangeVector multipliers;
    RangeVector adders;

    public ErrorHandler(RCFCaster.Builder builder) {
        CommonUtils.checkArgument((builder.forecastHorizon > 0 ? 1 : 0) != 0, (String)"has to be positive");
        CommonUtils.checkArgument((builder.errorHorizon >= builder.forecastHorizon ? 1 : 0) != 0, (String)"intervalPrecision horizon should be at least as large as forecast horizon");
        CommonUtils.checkArgument((builder.errorHorizon <= MAX_ERROR_HORIZON ? 1 : 0) != 0, (String)"reduce error horizon of change MAX");
        this.forecastHorizon = builder.forecastHorizon;
        this.errorHorizon = builder.errorHorizon;
        int inputLength = builder.dimensions / builder.shingleSize;
        int length = inputLength * this.forecastHorizon;
        this.percentile = builder.percentile;
        this.pastForecasts = new RangeVector[this.errorHorizon + this.forecastHorizon];
        for (int i = 0; i < this.errorHorizon + this.forecastHorizon; ++i) {
            this.pastForecasts[i] = new RangeVector(length);
        }
        this.actuals = new float[this.errorHorizon + this.forecastHorizon][inputLength];
        this.sequenceIndex = 0;
        this.errorMean = new float[length];
        this.errorRMSE = new DiVector(length);
        this.lastDeviations = new float[inputLength];
        this.multipliers = new RangeVector(length);
        Arrays.fill(this.multipliers.upper, 1.0f);
        Arrays.fill(this.multipliers.values, 1.0f);
        Arrays.fill(this.multipliers.lower, 1.0f);
        this.adders = new RangeVector(length);
        this.intervalPrecision = new float[length];
        this.errorDistribution = new RangeVector(length);
    }

    public ErrorHandler(int errorHorizon, int forecastHorizon, int sequenceIndex, double percentile, int inputLength, float[] actualsFlattened, float[] pastForecastsFlattened, float[] lastDeviations, float[] auxilliary) {
        CommonUtils.checkArgument((forecastHorizon > 0 ? 1 : 0) != 0, (String)" incorrect forecast horizon");
        CommonUtils.checkArgument((errorHorizon >= forecastHorizon ? 1 : 0) != 0, (String)"incorrect error horizon");
        CommonUtils.checkArgument((actualsFlattened != null || pastForecastsFlattened == null ? 1 : 0) != 0, (String)" actuals and forecasts are a mismatch");
        CommonUtils.checkArgument((inputLength > 0 ? 1 : 0) != 0, (String)"incorrect parameters");
        CommonUtils.checkArgument((sequenceIndex >= 0 ? 1 : 0) != 0, (String)"cannot be negative");
        CommonUtils.checkArgument((Math.abs(percentile - 0.25) < 0.24 ? 1 : 0) != 0, (String)"has to be between (0,0.5) ");
        this.sequenceIndex = sequenceIndex;
        this.errorHorizon = errorHorizon;
        this.percentile = percentile;
        this.forecastHorizon = forecastHorizon;
        int currentLength = actualsFlattened == null ? 0 : actualsFlattened.length;
        CommonUtils.checkArgument((currentLength % inputLength == 0 ? 1 : 0) != 0, (String)"actuals array is incorrect");
        int forecastLength = pastForecastsFlattened == null ? 0 : pastForecastsFlattened.length;
        int arrayLength = Math.max(forecastHorizon + errorHorizon, currentLength / inputLength);
        this.pastForecasts = new RangeVector[arrayLength];
        this.actuals = new float[arrayLength][inputLength];
        int length = forecastHorizon * inputLength;
        CommonUtils.checkArgument((forecastLength == currentLength * 3 * forecastHorizon ? 1 : 0) != 0, (String)"misaligned forecasts");
        CommonUtils.checkArgument((lastDeviations.length >= inputLength ? 1 : 0) != 0, (String)"incorrect length");
        this.lastDeviations = Arrays.copyOf(lastDeviations, lastDeviations.length);
        this.errorMean = new float[length];
        this.errorRMSE = new DiVector(length);
        this.intervalPrecision = new float[length];
        this.adders = new RangeVector(length);
        this.multipliers = new RangeVector(length);
        this.errorDistribution = new RangeVector(length);
        if (pastForecastsFlattened != null) {
            for (int i = 0; i < arrayLength; ++i) {
                float[] values = Arrays.copyOfRange(pastForecastsFlattened, i * 3 * length, (i * 3 + 1) * length);
                float[] upper = Arrays.copyOfRange(pastForecastsFlattened, (i * 3 + 1) * length, (i * 3 + 2) * length);
                float[] lower = Arrays.copyOfRange(pastForecastsFlattened, (i * 3 + 2) * length, (i * 3 + 3) * length);
                this.pastForecasts[i] = new RangeVector(values, upper, lower);
                System.arraycopy(actualsFlattened, i * inputLength, this.actuals[i], 0, inputLength);
            }
            this.recomputeErrors();
        }
    }

    public void updateActuals(double[] input, double[] deviations) {
        int arrayLength = this.pastForecasts.length;
        int inputLength = input.length;
        if (this.sequenceIndex > 0) {
            int inputIndex = (this.sequenceIndex + arrayLength - 1) % arrayLength;
            for (int i = 0; i < inputLength; ++i) {
                this.actuals[inputIndex][i] = (float)input[i];
            }
        }
        ++this.sequenceIndex;
        this.recomputeErrors();
        this.lastDeviations = CommonUtils.toFloatArray((double[])deviations);
    }

    public void augmentDescriptor(ForecastDescriptor descriptor) {
        descriptor.setErrorMean(this.errorMean);
        descriptor.setErrorRMSE(this.errorRMSE);
        descriptor.setObservedErrorDistribution(this.errorDistribution);
        descriptor.setIntervalPrecision(this.intervalPrecision);
    }

    public void updateForecasts(RangeVector vector) {
        int arrayLength = this.pastForecasts.length;
        int storedForecastIndex = (this.sequenceIndex + arrayLength - 1) % arrayLength;
        int length = this.pastForecasts[0].values.length;
        System.arraycopy(vector.values, 0, this.pastForecasts[storedForecastIndex].values, 0, length);
        System.arraycopy(vector.upper, 0, this.pastForecasts[storedForecastIndex].upper, 0, length);
        System.arraycopy(vector.lower, 0, this.pastForecasts[storedForecastIndex].lower, 0, length);
    }

    public RangeVector getErrorDistribution() {
        return new RangeVector(this.errorDistribution);
    }

    public float[] getErrorMean() {
        return Arrays.copyOf(this.errorMean, this.errorMean.length);
    }

    public DiVector getErrorRMSE() {
        return new DiVector(this.errorRMSE);
    }

    public float[] getIntervalPrecision() {
        return Arrays.copyOf(this.intervalPrecision, this.intervalPrecision.length);
    }

    public RangeVector getMultipliers() {
        return new RangeVector(this.multipliers);
    }

    public RangeVector getAdders() {
        return new RangeVector(this.adders);
    }

    protected double[] getErrorVector(int len, int leadtime, int inputCoordinate, int position, BiFunction<Float, Float, Float> error) {
        int arrayLength = this.pastForecasts.length;
        int errorIndex = (this.sequenceIndex - 1 + arrayLength) % arrayLength;
        double[] copy = new double[len];
        for (int k = 0; k < len; ++k) {
            int pastIndex = (errorIndex - leadtime - k + arrayLength) % arrayLength;
            int index = (errorIndex - k - 1 + arrayLength) % arrayLength;
            copy[k] = error.apply(Float.valueOf(this.actuals[index][inputCoordinate]), Float.valueOf(this.pastForecasts[pastIndex].values[position])).floatValue();
        }
        return copy;
    }

    int length(int sequenceIndex, int errorHorizon, int index) {
        return sequenceIndex > errorHorizon + index + 1 ? errorHorizon : sequenceIndex - index - 1;
    }

    protected void recomputeErrors() {
        int inputLength = this.actuals[0].length;
        int arrayLength = this.pastForecasts.length;
        int inputIndex = (this.sequenceIndex - 2 + arrayLength) % arrayLength;
        double[] medianError = new double[this.errorHorizon];
        Arrays.fill(this.intervalPrecision, 0.0f);
        for (int i = 0; i < this.forecastHorizon; ++i) {
            int len = this.length(this.sequenceIndex, this.errorHorizon, i);
            for (int j = 0; j < inputLength; ++j) {
                int pos = i * inputLength + j;
                if (len > 0) {
                    double positiveSum = 0.0;
                    int positiveCount = 0;
                    double negativeSum = 0.0;
                    double positiveSqSum = 0.0;
                    double negativeSqSum = 0.0;
                    for (int k = 0; k < len; ++k) {
                        double error;
                        int pastIndex = (inputIndex - i - k + arrayLength) % arrayLength;
                        int index = (inputIndex - k + arrayLength) % arrayLength;
                        medianError[k] = error = (double)(this.actuals[index][j] - this.pastForecasts[pastIndex].values[pos]);
                        int n = pos;
                        this.intervalPrecision[n] = this.intervalPrecision[n] + (this.pastForecasts[pastIndex].upper[pos] >= this.actuals[index][j] && this.actuals[index][j] >= this.pastForecasts[pastIndex].lower[pos] ? 1.0f : 0.0f);
                        if (error >= 0.0) {
                            positiveSum += error;
                            positiveSqSum += error * error;
                            ++positiveCount;
                            continue;
                        }
                        negativeSum += error;
                        negativeSqSum += error * error;
                    }
                    this.errorMean[pos] = (float)(positiveSum + negativeSum) / (float)len;
                    this.errorRMSE.high[pos] = positiveCount > 0 ? Math.sqrt(positiveSqSum / (double)positiveCount) : 0.0;
                    double d = this.errorRMSE.low[pos] = positiveCount < len ? -Math.sqrt(negativeSqSum / (double)(len - positiveCount)) : 0.0;
                    if ((double)len * this.percentile >= 1.0) {
                        Arrays.sort(medianError, 0, len);
                        this.errorDistribution.values[pos] = this.interpolatedMedian(medianError, len);
                        this.errorDistribution.upper[pos] = this.interpolatedUpperRank(medianError, len, (double)len * this.percentile);
                        this.errorDistribution.lower[pos] = this.interpolatedLowerRank(medianError, (double)len * this.percentile);
                    }
                    this.intervalPrecision[pos] = this.intervalPrecision[pos] / (float)len;
                    continue;
                }
                this.errorMean[pos] = 0.0f;
                this.errorRMSE.low[pos] = 0.0;
                this.errorRMSE.high[pos] = 0.0;
                this.errorDistribution.lower[pos] = 0.0f;
                this.errorDistribution.upper[pos] = 0.0f;
                this.errorDistribution.values[pos] = 0.0f;
                this.adders.values[pos] = 0.0f;
                this.adders.lower[pos] = 0.0f;
                this.adders.upper[pos] = 0.0f;
                this.intervalPrecision[pos] = 0.0f;
            }
        }
    }

    protected void calibrate(Calibration calibration, RangeVector ranges) {
        int inputLength = this.actuals[0].length;
        CommonUtils.checkArgument((inputLength * this.forecastHorizon == ranges.values.length ? 1 : 0) != 0, (String)"mismatched lengths");
        for (int i = 0; i < this.forecastHorizon; ++i) {
            int len = this.length(this.sequenceIndex, this.errorHorizon, i);
            for (int j = 0; j < inputLength; ++j) {
                int pos = i * inputLength + j;
                if (len <= 0 || calibration == Calibration.NONE) continue;
                if ((double)len * this.percentile < 1.0) {
                    double deviation = this.lastDeviations[j];
                    ranges.upper[pos] = Math.max(ranges.upper[pos], ranges.values[pos] + (float)(1.3 * deviation));
                    ranges.lower[pos] = Math.min(ranges.lower[pos], ranges.values[pos] - (float)(1.3 * deviation));
                    continue;
                }
                if (calibration == Calibration.SIMPLE) {
                    this.adjust(pos, ranges, this.errorDistribution);
                }
                if (calibration != Calibration.MINIMAL) continue;
                this.adjustMinimal(pos, ranges, this.errorDistribution);
            }
        }
    }

    protected float interpolatedMedian(double[] ascendingArray, int len) {
        CommonUtils.checkArgument((ascendingArray != null ? 1 : 0) != 0, (String)" cannot be null");
        CommonUtils.checkArgument((ascendingArray.length >= len ? 1 : 0) != 0, (String)"incorrect length parameter");
        float lower = (float)(len % 2 == 0 ? ascendingArray[len / 2 - 1] : (ascendingArray[len / 2] + ascendingArray[len / 2 - 1]) / 2.0);
        float upper = (float)(len % 2 == 0 ? ascendingArray[len / 2] : (ascendingArray[len / 2] + ascendingArray[len / 2 + 1]) / 2.0);
        if (lower <= 0.0f && 0.0f <= upper) {
            return 0.0f;
        }
        return (upper + lower) / 2.0f;
    }

    float interpolatedLowerRank(double[] ascendingArray, double fracRank) {
        int rank = (int)Math.floor(fracRank);
        return (float)(ascendingArray[rank - 1] + (fracRank - (double)rank) * (ascendingArray[rank] - ascendingArray[rank - 1]));
    }

    float interpolatedUpperRank(double[] ascendingArray, int len, double fracRank) {
        int rank = (int)Math.floor(fracRank);
        return (float)(ascendingArray[len - rank] + (fracRank - (double)rank) * (ascendingArray[len - rank - 1] - ascendingArray[len - rank]));
    }

    void adjust(int pos, RangeVector rangeVector, RangeVector other) {
        CommonUtils.checkArgument((other.values.length == rangeVector.values.length ? 1 : 0) != 0, (String)" mismatch in lengths");
        CommonUtils.checkArgument((pos >= 0 ? 1 : 0) != 0, (String)" cannot be negative");
        CommonUtils.checkArgument((pos < other.values.length ? 1 : 0) != 0, (String)" cannot be this large");
        int n = pos;
        rangeVector.values[n] = rangeVector.values[n] + other.values[pos];
        rangeVector.upper[pos] = Math.max(rangeVector.values[pos], rangeVector.upper[pos] + other.upper[pos]);
        rangeVector.lower[pos] = Math.min(rangeVector.values[pos], rangeVector.lower[pos] + other.lower[pos]);
    }

    void adjustMinimal(int pos, RangeVector rangeVector, RangeVector other) {
        CommonUtils.checkArgument((other.values.length == rangeVector.values.length ? 1 : 0) != 0, (String)" mismatch in lengths");
        CommonUtils.checkArgument((pos >= 0 ? 1 : 0) != 0, (String)" cannot be negative");
        CommonUtils.checkArgument((pos < other.values.length ? 1 : 0) != 0, (String)" cannot be this large");
        float oldVal = rangeVector.values[pos];
        int n = pos;
        rangeVector.values[n] = rangeVector.values[n] + other.values[pos];
        rangeVector.upper[pos] = Math.max(rangeVector.values[pos], oldVal + other.upper[pos]);
        rangeVector.lower[pos] = Math.min(rangeVector.values[pos], oldVal + other.lower[pos]);
    }

    @Generated
    public int getSequenceIndex() {
        return this.sequenceIndex;
    }

    @Generated
    public double getPercentile() {
        return this.percentile;
    }

    @Generated
    public int getForecastHorizon() {
        return this.forecastHorizon;
    }

    @Generated
    public int getErrorHorizon() {
        return this.errorHorizon;
    }

    @Generated
    public RangeVector[] getPastForecasts() {
        return this.pastForecasts;
    }

    @Generated
    public float[][] getActuals() {
        return this.actuals;
    }

    @Generated
    public float[] getLastDeviations() {
        return this.lastDeviations;
    }

    @Generated
    public void setSequenceIndex(int sequenceIndex) {
        this.sequenceIndex = sequenceIndex;
    }

    @Generated
    public void setPercentile(double percentile) {
        this.percentile = percentile;
    }

    @Generated
    public void setForecastHorizon(int forecastHorizon) {
        this.forecastHorizon = forecastHorizon;
    }

    @Generated
    public void setErrorHorizon(int errorHorizon) {
        this.errorHorizon = errorHorizon;
    }

    @Generated
    public void setPastForecasts(RangeVector[] pastForecasts) {
        this.pastForecasts = pastForecasts;
    }

    @Generated
    public void setActuals(float[][] actuals) {
        this.actuals = actuals;
    }

    @Generated
    public void setErrorDistribution(RangeVector errorDistribution) {
        this.errorDistribution = errorDistribution;
    }

    @Generated
    public void setErrorRMSE(DiVector errorRMSE) {
        this.errorRMSE = errorRMSE;
    }

    @Generated
    public void setErrorMean(float[] errorMean) {
        this.errorMean = errorMean;
    }

    @Generated
    public void setIntervalPrecision(float[] intervalPrecision) {
        this.intervalPrecision = intervalPrecision;
    }

    @Generated
    public void setLastDeviations(float[] lastDeviations) {
        this.lastDeviations = lastDeviations;
    }

    @Generated
    public void setMultipliers(RangeVector multipliers) {
        this.multipliers = multipliers;
    }

    @Generated
    public void setAdders(RangeVector adders) {
        this.adders = adders;
    }
}

