/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.index.query;

import com.google.common.annotations.VisibleForTesting;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.Callable;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopDocsCollector;
import org.apache.lucene.search.Weight;
import org.opensearch.common.Nullable;
import org.opensearch.common.StopWatch;
import org.opensearch.knn.index.query.ExactSearcher;
import org.opensearch.knn.index.query.common.QueryUtils;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.profile.KNNProfileUtil;
import org.opensearch.knn.profile.query.KNNQueryTimingType;
import org.opensearch.search.profile.ContextualProfileBreakdown;
import org.opensearch.search.profile.query.QueryProfiler;

public class RescoreKNNVectorQuery
extends Query {
    @Generated
    private static final Logger log = LogManager.getLogger(RescoreKNNVectorQuery.class);
    private final Query innerQuery;
    private final String field;
    private final int k;
    private final float[] queryVector;
    private final int shardId;
    private final ExactSearcher exactSearcher;

    public RescoreKNNVectorQuery(Query innerQuery, String field, int k, float[] queryVector, int shardId) {
        this.innerQuery = innerQuery;
        this.field = field;
        this.k = k;
        this.queryVector = queryVector;
        this.shardId = shardId;
        this.exactSearcher = new ExactSearcher(ModelDao.OpenSearchKNNModelDao.getInstance());
    }

    @VisibleForTesting
    public RescoreKNNVectorQuery(Query innerQuery, String field, int k, float[] queryVector, int shardId, ExactSearcher searcher) {
        this.innerQuery = innerQuery;
        this.field = field;
        this.k = k;
        this.queryVector = queryVector;
        this.shardId = shardId;
        this.exactSearcher = searcher;
    }

    public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
        Query rewrittenInnerQuery = searcher.rewrite(this.innerQuery);
        Weight weight = searcher.createWeight(rewrittenInnerQuery, scoreMode, boost);
        StopWatch stopWatch = this.startStopWatch();
        TopDocs[] perLeafResults = this.doRescore(searcher, weight);
        this.stopStopWatchAndLog(stopWatch);
        TopDocs topK = TopDocs.merge((int)this.k, (TopDocs[])perLeafResults);
        if (topK.scoreDocs.length == 0) {
            return new MatchNoDocsQuery().createWeight(searcher, scoreMode, boost);
        }
        return QueryUtils.getInstance().createDocAndScoreQuery(searcher.getIndexReader(), topK).createWeight(searcher, scoreMode, boost);
    }

    private TopDocs[] doRescore(IndexSearcher indexSearcher, Weight weight) throws IOException {
        List leafReaderContexts = indexSearcher.getIndexReader().leaves();
        ArrayList<Callable<TopDocs>> rescoreTasks = new ArrayList<Callable<TopDocs>>(leafReaderContexts.size());
        QueryProfiler profiler = KNNProfileUtil.getProfiler(indexSearcher);
        ContextualProfileBreakdown profile = profiler != null ? (ContextualProfileBreakdown)profiler.getProfileBreakdown((Query)this) : null;
        for (LeafReaderContext leafReaderContext : leafReaderContexts) {
            rescoreTasks.add(() -> this.searchLeaf(this.exactSearcher, weight, this.k, leafReaderContext, profile));
        }
        return (TopDocs[])indexSearcher.getTaskExecutor().invokeAll(rescoreTasks).toArray(TopDocs[]::new);
    }

    private TopDocs searchLeaf(ExactSearcher searcher, Weight weight, int k, LeafReaderContext leafReaderContext, ContextualProfileBreakdown profile) throws IOException {
        Scorer scorer = weight.scorer(leafReaderContext);
        if (scorer == null) {
            return TopDocsCollector.EMPTY_TOPDOCS;
        }
        DocIdSetIterator iterator = scorer.iterator();
        ExactSearcher.ExactSearcherContext exactSearcherContext = ExactSearcher.ExactSearcherContext.builder().matchedDocsIterator(iterator).numberOfMatchedDocs(iterator.cost()).useQuantizedVectorsForSearch(false).k(k).field(this.field).floatQueryVector(this.queryVector).build();
        TopDocs results = (TopDocs)KNNProfileUtil.profileBreakdown(profile, leafReaderContext, KNNQueryTimingType.EXACT_SEARCH, () -> searcher.searchLeaf(leafReaderContext, exactSearcherContext));
        if (leafReaderContext.docBase > 0) {
            for (ScoreDoc scoreDoc : results.scoreDocs) {
                scoreDoc.doc += leafReaderContext.docBase;
            }
        }
        return results;
    }

    private StopWatch startStopWatch() {
        if (log.isDebugEnabled()) {
            return new StopWatch().start();
        }
        return null;
    }

    private void stopStopWatchAndLog(@Nullable StopWatch stopWatch) {
        if (log.isDebugEnabled() && stopWatch != null) {
            stopWatch.stop();
            log.debug("[{}] shard: [{}], field: [{}], time in nanos:[{}] ", (Object)((Object)((Object)this)).getClass().getSimpleName(), (Object)this.shardId, (Object)this.field, (Object)stopWatch.totalTime().nanos());
        }
    }

    public String toString(String field) {
        return ((Object)((Object)this)).getClass().getSimpleName() + "innerQuery=" + String.valueOf(this.innerQuery) + "field=" + field + ", vector=" + String.valueOf(this.queryVector) + ", k=" + this.k + ", shardId=" + this.shardId + "]";
    }

    public void visit(QueryVisitor visitor) {
        visitor.visitLeaf((Query)this);
    }

    public boolean equals(Object obj) {
        if (!this.sameClassAs(obj)) {
            return false;
        }
        RescoreKNNVectorQuery other = (RescoreKNNVectorQuery)((Object)obj);
        return Objects.equals(this.innerQuery, other.innerQuery) && Objects.equals(this.queryVector, other.queryVector) && Objects.equals(this.field, other.field) && this.k == other.k && this.shardId == other.shardId;
    }

    public int hashCode() {
        return Objects.hash(this.innerQuery, this.queryVector, this.field, this.k, this.shardId);
    }
}

