/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.pytorch.process;

import java.time.Instant;
import java.util.Iterator;
import java.util.LongSummaryStatistics;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.Consumer;
import java.util.function.LongSupplier;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.core.Strings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.xpack.core.ml.utils.Intervals;
import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchProcess;
import org.elasticsearch.xpack.ml.inference.pytorch.results.AckResult;
import org.elasticsearch.xpack.ml.inference.pytorch.results.ErrorResult;
import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResult;
import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchResult;
import org.elasticsearch.xpack.ml.inference.pytorch.results.ThreadSettings;

/**
 * Reads the result stream of a {@link PyTorchProcess} and routes each result to the listener
 * registered for its request id, while also keeping inference timing statistics.
 *
 * <p>Pending listeners are held in a {@link ConcurrentMap}. The statistics fields are only read
 * or written while holding this object's monitor (see {@link #updateStats},
 * {@link #getResultStats} and {@link #processErrorResult}).
 */
public class PyTorchResultProcessor {
    private static final Logger logger = LogManager.getLogger(PyTorchResultProcessor.class);

    /** Length of one throughput reporting period. Not final — presumably so tests can shorten it. */
    static long REPORTING_PERIOD_MS = TimeValue.timeValueMinutes(1L).millis();

    private final ConcurrentMap<String, PendingResult> pendingResults = new ConcurrentHashMap<>();
    private final String modelId;
    private final Consumer<ThreadSettings> threadSettingsConsumer;
    /** Set by {@link #stop()}; decides how a failure of the result stream is reported. */
    private volatile boolean isStopping;

    // ---- statistics: guarded by this object's monitor ----
    private final LongSummaryStatistics timingStats;
    private final LongSummaryStatistics timingStatsExcludingCacheHits;
    private int errorCount;
    private long cacheHitCount;
    private long peakThroughput;
    // Accumulators for the reporting period holding the most recent result (ends at currentPeriodEndTimeMs).
    private LongSummaryStatistics lastPeriodSummaryStats;
    private long lastPeriodCacheHitCount;
    // Snapshot of the period preceding that one, taken on rollover; null if no such snapshot exists.
    private RecentStats lastPeriodStats;
    private long currentPeriodEndTimeMs;
    private long lastResultTimeMs;
    private final long startTime;
    private final LongSupplier currentTimeMsSupplier;

    /** Counted down exactly once, when {@link #process} returns or throws. */
    private final CountDownLatch processorCompletionLatch = new CountDownLatch(1);

    public PyTorchResultProcessor(String modelId, Consumer<ThreadSettings> threadSettingsConsumer) {
        this(modelId, threadSettingsConsumer, System::currentTimeMillis);
    }

    /**
     * @param modelId                model id used as a prefix in log messages; must not be null
     * @param threadSettingsConsumer receives every thread-settings result; must not be null
     * @param currentTimeSupplier    clock in epoch milliseconds, used for the period statistics
     */
    PyTorchResultProcessor(String modelId, Consumer<ThreadSettings> threadSettingsConsumer, LongSupplier currentTimeSupplier) {
        this.modelId = Objects.requireNonNull(modelId);
        this.timingStats = new LongSummaryStatistics();
        this.timingStatsExcludingCacheHits = new LongSummaryStatistics();
        this.lastPeriodSummaryStats = new LongSummaryStatistics();
        this.threadSettingsConsumer = Objects.requireNonNull(threadSettingsConsumer);
        this.currentTimeMsSupplier = currentTimeSupplier;
        this.startTime = currentTimeSupplier.getAsLong();
        this.currentPeriodEndTimeMs = startTime + REPORTING_PERIOD_MS;
    }

    /**
     * Registers {@code listener} to receive the result whose request id is {@code requestId}.
     * If a listener is already registered for that id, the existing one is kept and
     * {@code listener} is not stored (so this processor will never notify it).
     */
    public void registerRequest(String requestId, ActionListener<PyTorchResult> listener) {
        pendingResults.computeIfAbsent(requestId, k -> new PendingResult(listener));
    }

    /** Drops the listener registered for {@code requestId} without calling it. */
    public void ignoreResponseWithoutNotifying(String requestId) {
        pendingResults.remove(requestId);
    }

    /**
     * Consumes every result from {@code process} until the stream ends or fails.
     *
     * <p>If reading fails, all pending listeners receive an error result. When processing ends
     * for any reason, any listeners still pending receive a "process is stopping" error, and the
     * completion latch used by {@link #awaitCompletion} is released.
     */
    public void process(PyTorchProcess process) {
        try {
            Iterator<PyTorchResult> iterator = process.readResults();
            while (iterator.hasNext()) {
                PyTorchResult result = iterator.next();
                if (result.inferenceResult() != null) {
                    processInferenceResult(result);
                } else if (result.threadSettings() != null) {
                    threadSettingsConsumer.accept(result.threadSettings());
                    processThreadSettings(result);
                } else if (result.ackResult() != null) {
                    processAcknowledgement(result);
                } else if (result.errorResult() != null) {
                    processErrorResult(result);
                } else {
                    // None of the known payloads is present on this result.
                    handleUnknownResultType(result);
                }
            }
        } catch (Exception e) {
            // Read the volatile flag once so the log decision and the message agree.
            boolean stopping = isStopping;
            if (stopping == false) {
                logger.error(() -> "[" + modelId + "] Error processing results", e);
            }
            ErrorResult errorResult = new ErrorResult(
                stopping
                    ? "inference canceled as process is stopping"
                    : "inference native process died unexpectedly with failure [" + e.getMessage() + "]"
            );
            notifyAndClearPendingResults(errorResult);
        } finally {
            try {
                notifyAndClearPendingResults(new ErrorResult("inference canceled as process is stopping"));
            } finally {
                // Release awaitCompletion() even if a listener threw while being notified.
                processorCompletionLatch.countDown();
            }
        }
        logger.debug(() -> "[" + modelId + "] Results processing finished");
    }

    /**
     * Removes every pending listener and responds to it with {@code errorResult}.
     * Each entry is removed individually, so a request registered concurrently is either
     * notified here or left in the map — never silently discarded.
     */
    private void notifyAndClearPendingResults(ErrorResult errorResult) {
        int pendingCount = pendingResults.size();
        if (pendingCount > 0) {
            logger.warn(Strings.format("[%s] clearing [%d] requests pending results", modelId, pendingCount));
        }
        for (String requestId : pendingResults.keySet()) {
            PendingResult pendingResult = pendingResults.remove(requestId);
            if (pendingResult != null) {
                pendingResult.listener.onResponse(errorResponse(requestId, errorResult));
            }
        }
    }

    /** Builds a result for {@code requestId} that carries only {@code errorResult}. */
    private static PyTorchResult errorResponse(String requestId, ErrorResult errorResult) {
        return new PyTorchResult(requestId, null, null, null, null, null, errorResult);
    }

    /**
     * Removes the listener registered for {@code result}'s request id and passes it the result.
     * If nothing is registered, logs at debug level using {@code resultType} in the message.
     */
    private void respondToPendingResult(PyTorchResult result, String resultType) {
        PendingResult pendingResult = pendingResults.remove(result.requestId());
        if (pendingResult == null) {
            logger.debug(() -> Strings.format("[%s] no pending result for %s [%s]", modelId, resultType, result.requestId()));
        } else {
            pendingResult.listener.onResponse(result);
        }
    }

    /** Delivers an inference result to its listener. */
    void processInferenceResult(PyTorchResult result) {
        PyTorchInferenceResult inferenceResult = result.inferenceResult();
        assert inferenceResult != null;
        logger.debug(() -> Strings.format("[%s] Parsed inference result with id [%s]", modelId, result.requestId()));
        respondToPendingResult(result, "inference");
    }

    /** Delivers a thread-settings result to its listener, if any. */
    void processThreadSettings(PyTorchResult result) {
        ThreadSettings threadSettings = result.threadSettings();
        assert threadSettings != null;
        logger.debug(() -> Strings.format("[%s] Parsed thread settings result with id [%s]", modelId, result.requestId()));
        respondToPendingResult(result, "thread settings");
    }

    /** Delivers an acknowledgement result to its listener, if any. */
    void processAcknowledgement(PyTorchResult result) {
        AckResult ack = result.ackResult();
        assert ack != null;
        logger.debug(() -> Strings.format("[%s] Parsed ack result with id [%s]", modelId, result.requestId()));
        respondToPendingResult(result, "ack");
    }

    /** Counts the error in the statistics and delivers it to its listener, if any. */
    void processErrorResult(PyTorchResult result) {
        ErrorResult errorResult = result.errorResult();
        assert errorResult != null;
        // errorCount is read under this monitor by getResultStats().
        synchronized (this) {
            errorCount++;
        }
        logger.debug(() -> Strings.format("[%s] Parsed error with id [%s]", modelId, result.requestId()));
        respondToPendingResult(result, "error");
    }

    /**
     * Handles a result carrying none of the known payloads. If it has a request id with a
     * registered listener, that listener receives an error result describing the problem;
     * otherwise the problem is only logged.
     */
    void handleUnknownResultType(PyTorchResult result) {
        if (result.requestId() == null) {
            logger.error(() -> Strings.format("[%s] cannot process unknown result type [%s]", modelId, result));
            return;
        }
        PendingResult pendingResult = pendingResults.remove(result.requestId());
        if (pendingResult == null) {
            logger.error(() -> Strings.format("[%s] no pending result listener for unknown result type [%s]", modelId, result));
        } else {
            String msg = Strings.format("[%s] pending result listener cannot handle unknown result type [%s]", modelId, result);
            logger.error(msg);
            pendingResult.listener.onResponse(errorResponse(result.requestId(), new ErrorResult(msg)));
        }
    }

    /**
     * Returns a snapshot of the statistics.
     *
     * <p>The recent stats describe the reporting period before the one that contains the
     * current time. The in-progress period is excluded because it is incomplete. If that
     * previous period saw no results, an all-zero {@link RecentStats} with a null average is
     * returned.
     */
    public synchronized ResultStats getResultStats() {
        long currentMs = currentTimeMsSupplier.getAsLong();
        long currentPeriodStartTimeMs = startTime + Intervals.alignToFloor(currentMs - startTime, REPORTING_PERIOD_MS);

        RecentStats rs = null;
        if (lastResultTimeMs >= currentPeriodStartTimeMs) {
            // A result in this period already rolled the previous period into lastPeriodStats.
            rs = lastPeriodStats;
        } else if (lastResultTimeMs >= currentPeriodStartTimeMs - REPORTING_PERIOD_MS) {
            // The previous period is finished but no later result has rolled it over yet.
            rs = new RecentStats(lastPeriodSummaryStats.getCount(), lastPeriodSummaryStats.getAverage(), lastPeriodCacheHitCount);
            peakThroughput = Math.max(peakThroughput, lastPeriodSummaryStats.getCount());
        }
        if (rs == null) {
            rs = new RecentStats(0L, null, 0L);
        }

        return new ResultStats(
            cloneSummaryStats(timingStats),
            cloneSummaryStats(timingStatsExcludingCacheHits),
            errorCount,
            cacheHitCount,
            pendingResults.size(),
            lastResultTimeMs > 0L ? Instant.ofEpochMilli(lastResultTimeMs) : null,
            peakThroughput,
            rs
        );
    }

    /** Copies {@code stats} so the caller's snapshot is unaffected by later updates. */
    private static LongSummaryStatistics cloneSummaryStats(LongSummaryStatistics stats) {
        return new LongSummaryStatistics(stats.getCount(), stats.getMin(), stats.getMax(), stats.getSum());
    }

    /**
     * Records the timing of one inference result.
     *
     * <p>A missing {@code timeMs} trips an assertion when assertions are enabled; otherwise it is
     * treated as 0. Cache hits are counted but excluded from
     * {@code timingStatsExcludingCacheHits}.
     */
    public synchronized void updateStats(PyTorchResult result) {
        Long timeMs = result.timeMs();
        if (timeMs == null) {
            assert false : "time_ms should be set for an inference result";
            timeMs = 0L;
        }
        boolean isCacheHit = Boolean.TRUE.equals(result.isCacheHit());
        timingStats.accept(timeMs);

        lastResultTimeMs = currentTimeMsSupplier.getAsLong();
        if (lastResultTimeMs > currentPeriodEndTimeMs) {
            // This result falls in a later period: close off the accumulated one.
            peakThroughput = Math.max(peakThroughput, lastPeriodSummaryStats.getCount());
            if (lastResultTimeMs > currentPeriodEndTimeMs + REPORTING_PERIOD_MS) {
                // At least one whole period passed with no results, so there is nothing to report for it.
                lastPeriodStats = null;
            } else {
                lastPeriodStats = new RecentStats(
                    lastPeriodSummaryStats.getCount(),
                    lastPeriodSummaryStats.getAverage(),
                    lastPeriodCacheHitCount
                );
            }
            lastPeriodCacheHitCount = 0L;
            lastPeriodSummaryStats = new LongSummaryStatistics();
            lastPeriodSummaryStats.accept(timeMs);
            currentPeriodEndTimeMs = startTime + Intervals.alignToCeil(lastResultTimeMs - startTime, REPORTING_PERIOD_MS);
        } else {
            lastPeriodSummaryStats.accept(timeMs);
        }

        if (isCacheHit) {
            cacheHitCount++;
            lastPeriodCacheHitCount++;
        } else {
            timingStatsExcludingCacheHits.accept(timeMs);
        }
    }

    /** Marks the processor as stopping, so a subsequent stream failure is reported as a cancellation. */
    public void stop() {
        isStopping = true;
    }

    /**
     * Waits for {@link #process} to finish.
     *
     * @throws TimeoutException if processing has not finished within {@code timeout}
     * If interrupted, this method restores the interrupt flag, logs, and returns without throwing.
     */
    public void awaitCompletion(long timeout, TimeUnit unit) throws TimeoutException {
        try {
            if (processorCompletionLatch.await(timeout, unit) == false) {
                throw new TimeoutException(
                    Strings.format("Timed out waiting for pytorch results processor to complete for model id %s", modelId)
                );
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            logger.info(Strings.format("[%s] Interrupted waiting for pytorch results processor to complete", modelId));
        }
    }

    /** Holds the listener waiting for a single request's result. */
    public static class PendingResult {
        public final ActionListener<PyTorchResult> listener;

        public PendingResult(ActionListener<PyTorchResult> listener) {
            this.listener = Objects.requireNonNull(listener);
        }
    }

    /** Statistics for one reporting period; {@code avgInferenceTime} is null when no period data exists. */
    public record RecentStats(long requestsProcessed, Double avgInferenceTime, long cacheHitCount) {}

    /** Snapshot returned by {@link PyTorchResultProcessor#getResultStats()}. */
    public record ResultStats(
        LongSummaryStatistics timingStats,
        LongSummaryStatistics timingStatsExcludingCacheHits,
        int errorCount,
        long cacheHitCount,
        int numberOfPendingResults,
        Instant lastUsed,
        long peakThroughput,
        RecentStats recentStats
    ) {}
}

