/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.query;

import com.google.common.annotations.VisibleForTesting;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import lombok.Generated;
import lombok.NonNull;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.builder.EqualsBuilder;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.join.ScoreMode;
import org.opensearch.action.IndicesRequest;
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.common.SetOnce;
import org.opensearch.common.collect.Tuple;
import org.opensearch.core.ParseField;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.ParsingException;
import org.opensearch.core.common.io.stream.NamedWriteable;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentLocation;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.index.query.NestedQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryCoordinatorContext;
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.index.query.WithFieldName;
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.knn.index.query.parser.MethodParametersParser;
import org.opensearch.knn.index.query.parser.RescoreParser;
import org.opensearch.knn.index.query.rescore.RescoreContext;
import org.opensearch.neuralsearch.common.MinClusterVersionUtil;
import org.opensearch.neuralsearch.common.VectorUtil;
import org.opensearch.neuralsearch.mapper.SemanticFieldMapper;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.EmbeddingContentType;
import org.opensearch.neuralsearch.processor.InferenceRequest;
import org.opensearch.neuralsearch.processor.MapInferenceRequest;
import org.opensearch.neuralsearch.processor.TextInferenceRequest;
import org.opensearch.neuralsearch.query.AbstractNeuralQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralKNNQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralSparseQueryTwoPhaseInfo;
import org.opensearch.neuralsearch.query.dto.NeuralQueryBuildStage;
import org.opensearch.neuralsearch.query.dto.NeuralQueryTargetFieldConfig;
import org.opensearch.neuralsearch.query.parser.NeuralQueryParser;
import org.opensearch.neuralsearch.stats.events.EventStatName;
import org.opensearch.neuralsearch.stats.events.EventStatsManager;
import org.opensearch.neuralsearch.util.NeuralQueryValidationUtil;
import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;
import org.opensearch.neuralsearch.util.SemanticMappingUtils;
import org.opensearch.neuralsearch.util.TokenWeightUtil;
import org.opensearch.neuralsearch.util.prune.PruneType;
import org.opensearch.neuralsearch.util.prune.PruneUtils;

public class NeuralQueryBuilder
extends AbstractNeuralQueryBuilder<NeuralQueryBuilder>
implements WithFieldName {
    // Logger for this class.
    @Generated
    private static final Logger log = LogManager.getLogger(NeuralQueryBuilder.class);
    // Query name in the DSL; also the outermost object key written by doXContent.
    public static final String NAME = "neural";
    // Optional request parameter naming the search analyzer; parsed into and serialized from searchAnalyzer.
    public static final ParseField SEMANTIC_FIELD_SEARCH_ANALYZER_FIELD = new ParseField("semantic_field_search_analyzer", new String[0]);
    // Request parameter for image input (sent to the model as "inputImage"); its preferred name is also
    // the feature key used for the stream version gate on queryText/queryImage.
    public static final ParseField QUERY_IMAGE_FIELD = new ParseField("query_image", new String[0]);
    // Request parameter for k, forwarded to the generated k-NN query.
    @VisibleForTesting
    static final ParseField K_FIELD = new ParseField("k", new String[0]);
    // NOTE(review): not referenced in this section of the file — confirm where this default is applied.
    public static final int DEFAULT_K = 10;
    // Target field types accepted by this query. NOTE(review): not referenced in this section — confirm usage.
    public static final Set<String> SUPPORTED_TARGET_FIELD_TYPES = Set.of("semantic", "knn_vector");
    // ML Commons client used to run inference; static, so shared by all instances. Installed via initialize().
    private static MLCommonsClientAccessor ML_CLIENT;
    // Embedding field type of the target (e.g. "knn_vector" on the k-NN rewrite path).
    private String embeddingFieldType;
    // Image input sent to the model as "inputImage" during inference.
    private String queryImage;
    // k-NN search parameters forwarded to the generated k-NN query; null when not set by the request.
    private Integer k = null;
    private Float maxDistance = null;
    private Float minScore = null;
    private Boolean expandNested;
    // Supplies the inferred dense vector; on the k-NN rewrite path it is backed by a SetOnce and
    // returns null until asynchronous inference completes.
    @VisibleForTesting
    private Supplier<float[]> vectorSupplier;
    // Filter forwarded to the k-NN query; accumulated via filter(QueryBuilder).
    private QueryBuilder queryfilter;
    // k-NN method parameters and rescore configuration forwarded unchanged to the k-NN query.
    private Map<String, ?> methodParameters;
    private RescoreContext rescoreContext;
    // Per search-model-id inference results used when rewriting against semantic fields:
    // dense vectors for knn_vector embeddings, query tokens for rank_features embeddings.
    private Map<String, Supplier<float[]>> modelIdToVectorSupplierMap;
    private Map<String, Supplier<Map<String, Float>>> modelIdToQueryTokensSupplierMap;
    // State for the two-phase sparse path; only serialized to streams on or after
    // MINIMAL_SUPPORTED_VERSION_SEMANTIC_FIELD_SPARSE_TWO_PHASE.
    private Map<String, Map<String, Float>> modelIdToTwoPhaseSharedQueryToken;
    private Supplier<Map<String, Map<String, Float>>> modelIdToTwoPhaseSharedQueryTokenSupplier;

    /**
     * Installs the ML Commons client that this query uses for model inference.
     *
     * <p>The client is stored in a static field, so it is shared by every {@code NeuralQueryBuilder}.
     *
     * @param mlClient accessor for ML Commons inference APIs
     */
    public static void initialize(MLCommonsClientAccessor mlClient) {
        NeuralQueryBuilder.ML_CLIENT = mlClient;
    }

    /**
     * Validates the parameters of a neural query for the given build stage.
     *
     * <p>When {@code buildStage} is null or {@code FROM_X_CONTENT}, k-NN rules are applied only when the cluster
     * predates semantic-field support. At {@code REWRITE}, the rule set depends on the target:
     * <ul>
     *   <li>k-NN rules for non-semantic fields;</li>
     *   <li>semantic dense rules for a {@code knn_vector} embedding;</li>
     *   <li>semantic sparse rules for a {@code rank_features} embedding.</li>
     * </ul>
     *
     * @param neuralQueryBuilder query to validate
     * @param buildStage stage the query is built in; {@code null} is treated like {@code FROM_X_CONTENT}
     * @param isSemanticField whether the target is a semantic field; {@code null} is treated as {@code false}
     * @param embeddingFieldType embedding type of the semantic field; consulted only at {@code REWRITE}
     * @throws NullPointerException if {@code neuralQueryBuilder} is null
     * @throws IllegalArgumentException if validation errors are found or the embedding field type is unsupported
     */
    private static void validateNeuralQueryBuilder(
        @NonNull NeuralQueryBuilder neuralQueryBuilder,
        NeuralQueryBuildStage buildStage,
        Boolean isSemanticField,
        String embeddingFieldType
    ) {
        Objects.requireNonNull(neuralQueryBuilder, "neuralQueryBuilder is marked non-null but is null");
        // Must be a List<String>: the messages are joined with String.join, which requires CharSequence elements.
        List<String> errors = List.of();
        if (buildStage == null || NeuralQueryBuildStage.FROM_X_CONTENT.equals(buildStage)) {
            if (!MinClusterVersionUtil.isClusterOnOrAfterMinReqVersionForSemanticFieldType()) {
                errors = NeuralQueryValidationUtil.validateNeuralQueryForKnn(neuralQueryBuilder, buildStage);
            }
        } else if (NeuralQueryBuildStage.REWRITE.equals(buildStage)) {
            if (!Boolean.TRUE.equals(isSemanticField)) {
                errors = NeuralQueryValidationUtil.validateNeuralQueryForKnn(neuralQueryBuilder, buildStage);
            } else if ("knn_vector".equals(embeddingFieldType)) {
                errors = NeuralQueryValidationUtil.validateNeuralQueryForSemanticDense(neuralQueryBuilder);
            } else if ("rank_features".equals(embeddingFieldType)) {
                errors = NeuralQueryValidationUtil.validateNeuralQueryForSemanticSparse(neuralQueryBuilder);
            } else {
                throw new IllegalArgumentException(String.format(Locale.ROOT, "Unsupported embedding field type: %s", embeddingFieldType));
            }
        }
        if (!errors.isEmpty()) {
            throw new IllegalArgumentException(String.format(Locale.ROOT, "Invalid neural query: %s", String.join("; ", errors)));
        }
    }

    /**
     * Creates an empty {@link Builder} for assembling a {@code NeuralQueryBuilder}.
     *
     * @return a new builder instance
     */
    public static Builder builder() {
        final Builder freshBuilder = new Builder();
        return freshBuilder;
    }

    /**
     * Deserializes a neural query from a transport stream.
     *
     * <p>Fields are read in exactly the order {@code doWriteTo} writes them, and each optional section is gated
     * on the same version check used on the write side. The two methods must change together.
     *
     * @param in stream positioned after the fields consumed by the superclass constructor
     * @throws IOException if reading from the stream fails
     */
    public NeuralQueryBuilder(StreamInput in) throws IOException {
        super(in);
        this.fieldName = in.readString();
        // Newer clusters send queryText as optional plus an optional queryImage; older ones send a required queryText.
        if (MinClusterVersionUtil.isClusterOnOrAfterMinReqVersion(QUERY_IMAGE_FIELD.getPreferredName())) {
            this.queryText = in.readOptionalString();
            this.queryImage = in.readOptionalString();
        } else {
            this.queryText = in.readString();
        }
        // modelId becomes optional once default dense model ids are supported; k becomes optional with radial search.
        this.modelId = MinClusterVersionUtil.isClusterOnOrAfterMinReqVersionForDefaultDenseModelIdSupport() ? in.readOptionalString() : in.readString();
        this.k = MinClusterVersionUtil.isClusterOnOrAfterMinReqVersionForRadialSearch() ? in.readOptionalInt() : Integer.valueOf(in.readVInt());
        this.queryfilter = (QueryBuilder)in.readOptionalNamedWriteable(QueryBuilder.class);
        if (MinClusterVersionUtil.isClusterOnOrAfterMinReqVersionForRadialSearch()) {
            this.maxDistance = in.readOptionalFloat();
            this.minScore = in.readOptionalFloat();
        }
        if (MinClusterVersionUtil.isClusterOnOrAfterMinReqVersion(KNNQueryBuilder.EXPAND_NESTED_FIELD.getPreferredName())) {
            this.expandNested = in.readOptionalBoolean();
        }
        if (MinClusterVersionUtil.isClusterOnOrAfterMinReqVersion(KNNQueryBuilder.METHOD_PARAMS_FIELD.getPreferredName())) {
            this.methodParameters = MethodParametersParser.streamInput((StreamInput)in, MinClusterVersionUtil::isClusterOnOrAfterMinReqVersion);
        }
        this.rescoreContext = RescoreParser.streamInput((StreamInput)in);
        // The remaining sections are gated on the stream's own version (in.getVersion()),
        // unlike the cluster-wide minimum-version checks above.
        if (in.getVersion().onOrAfter(MinClusterVersionUtil.MINIMAL_SUPPORTED_VERSION_SEMANTIC_FIELD)) {
            this.vectorSupplier = NeuralQueryParser.vectorSupplierStreamInput(in);
            this.queryTokensMapSupplier = NeuralQueryParser.queryTokensMapSupplierStreamInput(in);
            this.modelIdToVectorSupplierMap = NeuralQueryParser.modelIdToVectorSupplierMapStreamInput(in);
            this.modelIdToQueryTokensSupplierMap = NeuralQueryParser.modelIdToQueryTokensSupplierMapStreamInput(in);
            this.searchAnalyzer = in.readOptionalString();
        }
        if (in.getVersion().onOrAfter(MinClusterVersionUtil.MINIMAL_SUPPORTED_VERSION_SEMANTIC_FIELD_SPARSE_TWO_PHASE)) {
            this.neuralSparseQueryTwoPhaseInfo = new NeuralSparseQueryTwoPhaseInfo(in);
            this.modelIdToTwoPhaseSharedQueryToken = NeuralQueryParser.modelIdToTwoPhaseSharedQueryTokenStreamInput(in);
            this.modelIdToTwoPhaseSharedQueryTokenSupplier = NeuralQueryParser.modelIdToTwoPhaseSharedQueryTokenSupplierStreamInput(in);
        }
    }

    /**
     * Serializes the neural-query-specific fields to a transport stream.
     *
     * <p>The write order and version gates must mirror the {@code NeuralQueryBuilder(StreamInput)} constructor exactly.
     *
     * @param out destination stream
     * @throws IOException if writing to the stream fails
     */
    protected void doWriteTo(StreamOutput out) throws IOException {
        out.writeString(this.fieldName);
        if (MinClusterVersionUtil.isClusterOnOrAfterMinReqVersion(QUERY_IMAGE_FIELD.getPreferredName())) {
            out.writeOptionalString(this.queryText);
            out.writeOptionalString(this.queryImage);
        } else {
            // NOTE(review): legacy path writes queryText as required — presumably fails if queryText is null; confirm.
            out.writeString(this.queryText);
        }
        if (MinClusterVersionUtil.isClusterOnOrAfterMinReqVersionForDefaultDenseModelIdSupport()) {
            out.writeOptionalString(this.modelId);
        } else {
            // NOTE(review): legacy path writes modelId as required — presumably fails if modelId is null; confirm.
            out.writeString(this.modelId);
        }
        if (MinClusterVersionUtil.isClusterOnOrAfterMinReqVersionForRadialSearch()) {
            out.writeOptionalInt(this.k);
        } else {
            // k.intValue() throws NullPointerException if k was never set on this legacy path.
            out.writeVInt(this.k.intValue());
        }
        out.writeOptionalNamedWriteable((NamedWriteable)this.queryfilter);
        if (MinClusterVersionUtil.isClusterOnOrAfterMinReqVersionForRadialSearch()) {
            out.writeOptionalFloat(this.maxDistance);
            out.writeOptionalFloat(this.minScore);
        }
        if (MinClusterVersionUtil.isClusterOnOrAfterMinReqVersion(KNNQueryBuilder.EXPAND_NESTED_FIELD.getPreferredName())) {
            out.writeOptionalBoolean(this.expandNested);
        }
        if (MinClusterVersionUtil.isClusterOnOrAfterMinReqVersion(KNNQueryBuilder.METHOD_PARAMS_FIELD.getPreferredName())) {
            MethodParametersParser.streamOutput((StreamOutput)out, this.methodParameters, MinClusterVersionUtil::isClusterOnOrAfterMinReqVersion);
        }
        RescoreParser.streamOutput((StreamOutput)out, (RescoreContext)this.rescoreContext);
        // Sections below are gated on the destination stream's version rather than the cluster-wide minimum.
        if (out.getVersion().onOrAfter(MinClusterVersionUtil.MINIMAL_SUPPORTED_VERSION_SEMANTIC_FIELD)) {
            NeuralQueryParser.vectorSupplierStreamOutput(out, this.vectorSupplier);
            NeuralQueryParser.queryTokensMapSupplierStreamOutput(out, this.queryTokensMapSupplier);
            NeuralQueryParser.modelIdToVectorSupplierMapStreamOutput(out, this.modelIdToVectorSupplierMap);
            NeuralQueryParser.modelIdToQueryTokensSupplierMapStreamOutput(out, this.modelIdToQueryTokensSupplierMap);
            out.writeOptionalString(this.searchAnalyzer);
        }
        if (out.getVersion().onOrAfter(MinClusterVersionUtil.MINIMAL_SUPPORTED_VERSION_SEMANTIC_FIELD_SPARSE_TWO_PHASE)) {
            // NOTE(review): dereferenced unconditionally — assumes neuralSparseQueryTwoPhaseInfo is never null; verify.
            this.neuralSparseQueryTwoPhaseInfo.writeTo(out);
            NeuralQueryParser.modelIdToTwoPhaseSharedQueryTokenStreamOutput(out, this.modelIdToTwoPhaseSharedQueryToken);
            NeuralQueryParser.modelIdToTwoPhaseSharedQueryTokenSupplierStreamOutput(out, this.modelIdToTwoPhaseSharedQueryTokenSupplier);
        }
    }

    /**
     * Adds a filter to this query, combining it with any filter already present.
     *
     * <p>If {@code validateFilterParams} rejects the filter, this query is returned unchanged.
     *
     * @param filterToBeAdded filter to add
     * @return this query
     */
    public QueryBuilder filter(QueryBuilder filterToBeAdded) {
        if (NeuralQueryBuilder.validateFilterParams(filterToBeAdded)) {
            if (this.queryfilter != null) {
                // Merge with the filter that is already set.
                this.queryfilter = this.queryfilter.filter(filterToBeAdded);
            } else {
                this.queryfilter = filterToBeAdded;
            }
        }
        return this;
    }

    /**
     * Renders this query as {@code {"neural": {"<fieldName>": {...}}}}.
     *
     * <p>Only parameters that are set (non-null) are emitted, in a fixed order, followed by boost and query name.
     *
     * @param xContentBuilder destination builder
     * @param params rendering parameters (not consulted here)
     * @throws IOException if writing to the builder fails
     */
    protected void doXContent(XContentBuilder xContentBuilder, ToXContent.Params params) throws IOException {
        xContentBuilder.startObject(NAME);
        xContentBuilder.startObject(this.fieldName);

        // Inference inputs.
        if (this.queryText != null) {
            xContentBuilder.field(QUERY_TEXT_FIELD.getPreferredName(), this.queryText);
        }
        if (this.queryImage != null) {
            xContentBuilder.field(QUERY_IMAGE_FIELD.getPreferredName(), this.queryImage);
        }
        if (this.modelId != null) {
            xContentBuilder.field(MODEL_ID_FIELD.getPreferredName(), this.modelId);
        }

        // k-NN search parameters.
        if (this.k != null) {
            xContentBuilder.field(K_FIELD.getPreferredName(), this.k);
        }
        if (this.queryfilter != null) {
            xContentBuilder.field(KNNQueryBuilder.FILTER_FIELD.getPreferredName(), (ToXContent) this.queryfilter);
        }
        if (this.maxDistance != null) {
            xContentBuilder.field(KNNQueryBuilder.MAX_DISTANCE_FIELD.getPreferredName(), this.maxDistance);
        }
        if (this.minScore != null) {
            xContentBuilder.field(KNNQueryBuilder.MIN_SCORE_FIELD.getPreferredName(), this.minScore);
        }
        if (this.expandNested != null) {
            xContentBuilder.field(KNNQueryBuilder.EXPAND_NESTED_FIELD.getPreferredName(), this.expandNested);
        }
        if (this.methodParameters != null) {
            MethodParametersParser.doXContent(xContentBuilder, this.methodParameters);
        }
        if (this.rescoreContext != null) {
            RescoreParser.doXContent(xContentBuilder, this.rescoreContext);
        }

        // Sparse / semantic parameters.
        if (this.queryTokensMapSupplier != null && this.queryTokensMapSupplier.get() != null) {
            xContentBuilder.field(QUERY_TOKENS_FIELD.getPreferredName(), this.queryTokensMapSupplier.get());
        }
        if (this.searchAnalyzer != null) {
            xContentBuilder.field(SEMANTIC_FIELD_SEARCH_ANALYZER_FIELD.getPreferredName(), this.searchAnalyzer);
        }

        this.printBoostAndQueryName(xContentBuilder);
        xContentBuilder.endObject(); // closes the field object
        xContentBuilder.endObject(); // closes the "neural" object
    }

    /**
     * Parses a {@code neural} query body of the form {@code {"<fieldName>": {...params}}}.
     *
     * <p>Every call increments the {@code NEURAL_QUERY_REQUESTS} event stat. Exactly one target field is allowed.
     *
     * @param parser parser positioned on the START_OBJECT of the query body
     * @return the parsed query, built at stage {@code FROM_X_CONTENT}
     * @throws ParsingException if the body is not an object or names more than one field
     * @throws IOException if reading from the parser fails
     */
    public static NeuralQueryBuilder fromXContent(XContentParser parser) throws IOException {
        EventStatsManager.increment(EventStatName.NEURAL_QUERY_REQUESTS);
        final Builder builder = new Builder();
        final boolean startsWithObject = parser.currentToken() == XContentParser.Token.START_OBJECT;
        if (!startsWithObject) {
            throw new ParsingException(parser.getTokenLocation(), "Token must be START_OBJECT", new Object[0]);
        }

        // Advance to the target field name, then onto its parameter object.
        parser.nextToken();
        builder.fieldName(parser.currentName());
        parser.nextToken();
        NeuralQueryBuilder.parseQueryParams(parser, builder);

        // Anything other than the closing brace means a second field was supplied.
        final XContentParser.Token trailingToken = parser.nextToken();
        if (trailingToken != XContentParser.Token.END_OBJECT) {
            throw new ParsingException(parser.getTokenLocation(), "[neural] query doesn't support multiple fields, found [" + builder.fieldName + "] and [" + parser.currentName() + "]", new Object[0]);
        }

        builder.buildStage(NeuralQueryBuildStage.FROM_X_CONTENT);
        return builder.build();
    }

    /**
     * Parses the parameter object of a {@code neural} query into {@code builder}.
     *
     * <p>Consumes tokens up to and including the matching END_OBJECT.
     * <ul>
     *   <li>Scalar parameters cover the query text/image, model id, k, name, boost, radial-search bounds,
     *       expand_nested and the search analyzer.</li>
     *   <li>Object parameters cover the filter, method parameters, rescore and query tokens.</li>
     * </ul>
     *
     * @param parser parser positioned on the START_OBJECT of the field's parameters
     * @param builder builder receiving the parsed values
     * @throws ParsingException on an unsupported parameter name or an unexpected token
     * @throws IOException if reading from the parser fails
     */
    private static void parseQueryParams(XContentParser parser, Builder builder) throws IOException {
        XContentParser.Token token;
        String currentFieldName = "";
        while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
            if (token == XContentParser.Token.FIELD_NAME) {
                currentFieldName = parser.currentName();
                continue;
            }
            if (token.isValue()) {
                if (QUERY_TEXT_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                    builder.queryText(parser.text());
                } else if (QUERY_IMAGE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                    builder.queryImage(parser.text());
                } else if (MODEL_ID_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                    builder.modelId(parser.text());
                } else if (K_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                    builder.k((Integer) NumberFieldMapper.NumberType.INTEGER.parse(parser.objectBytes(), false));
                } else if (NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                    builder.queryName(parser.text());
                } else if (BOOST_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                    builder.boost(parser.floatValue());
                } else if (KNNQueryBuilder.MAX_DISTANCE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                    builder.maxDistance(Float.valueOf(parser.floatValue()));
                } else if (KNNQueryBuilder.MIN_SCORE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                    builder.minScore(Float.valueOf(parser.floatValue()));
                } else if (KNNQueryBuilder.EXPAND_NESTED_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                    builder.expandNested(parser.booleanValue());
                } else if (SEMANTIC_FIELD_SEARCH_ANALYZER_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                    builder.searchAnalyzer(parser.text());
                } else {
                    throw NeuralQueryBuilder.getUnsupportedFieldException(parser.getTokenLocation(), currentFieldName);
                }
                continue;
            }
            if (token == XContentParser.Token.START_OBJECT) {
                if (KNNQueryBuilder.FILTER_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                    builder.filter(NeuralQueryBuilder.parseInnerQueryBuilder(parser));
                } else if (KNNQueryBuilder.METHOD_PARAMS_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                    builder.methodParameters(MethodParametersParser.fromXContent(parser));
                } else if (KNNQueryBuilder.RESCORE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                    builder.rescoreContext(RescoreParser.fromXContent(parser));
                } else if (QUERY_TOKENS_FIELD.match(currentFieldName, parser.getDeprecationHandler())
                    && MinClusterVersionUtil.isClusterOnOrAfterMinReqVersionForSemanticFieldType()) {
                    // Raw query tokens are only accepted once the cluster supports semantic fields.
                    // Typed map (not raw) so the supplier is a genuine Supplier<Map<String, Float>>.
                    Map<String, Float> queryTokens = parser.map(HashMap::new, XContentParser::floatValue);
                    builder.queryTokensMapSupplier(() -> queryTokens);
                } else {
                    throw NeuralQueryBuilder.getUnsupportedFieldException(parser.getTokenLocation(), currentFieldName);
                }
                continue;
            }
            throw new ParsingException(parser.getTokenLocation(), "[neural] unknown token [" + token + "] after [" + currentFieldName + "]", new Object[0]);
        }
    }

    /**
     * Builds the exception reported when a {@code neural} query contains an unsupported parameter.
     *
     * @param tokenLocation parser location of the offending parameter
     * @param currentFieldName name of the unsupported parameter
     * @return a {@link ParsingException} naming the parameter
     */
    private static ParsingException getUnsupportedFieldException(XContentLocation tokenLocation, String currentFieldName) {
        final String message = String.format(Locale.ROOT, "[neural] query does not support [%s]", currentFieldName);
        return new ParsingException(tokenLocation, message, new Object[0]);
    }

    /**
     * Rewrites this query based on the kind of rewrite context available.
     *
     * <p>The rewrite proceeds as follows:
     * <ul>
     *   <li>If the cluster predates semantic-field support, the query is always rewritten against a k-NN field.</li>
     *   <li>Otherwise the coordinator context is preferred, then the shard context.</li>
     *   <li>If neither context is available, the query is returned unchanged.</li>
     * </ul>
     *
     * @param queryRewriteContext current rewrite context
     * @return the rewritten query, or this query when no rewrite applies
     */
    protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
        final boolean semanticFieldsSupported = MinClusterVersionUtil.isClusterOnOrAfterMinReqVersionForSemanticFieldType();
        if (!semanticFieldsSupported) {
            return this.rewriteQueryAgainstKnnField(queryRewriteContext);
        }
        final QueryCoordinatorContext coordinatorContext = queryRewriteContext.convertToCoordinatorContext();
        if (coordinatorContext != null) {
            return this.rewriteQueryWithCoordinatorContext(coordinatorContext);
        }
        final QueryShardContext shardContext = queryRewriteContext.convertToShardContext();
        if (shardContext == null) {
            return this;
        }
        return this.rewriteQueryWithQueryShardContext(shardContext);
    }

    private QueryBuilder rewriteQueryWithCoordinatorContext(@NonNull QueryCoordinatorContext queryRewriteContext) {
        boolean canSkipInference;
        boolean canRewriteSingleTargetIndex;
        Objects.requireNonNull(queryRewriteContext, "queryRewriteContext is marked non-null but is null");
        IndicesRequest searchRequest = queryRewriteContext.getSearchRequest();
        List<String> remoteIndices = this.getRemoteIndices(searchRequest);
        if (!remoteIndices.isEmpty()) {
            return this.rewriteQueryAgainstKnnField((QueryRewriteContext)queryRewriteContext);
        }
        Map<String, NeuralQueryTargetFieldConfig> indexToTargetFieldConfigMap = this.getIndexToTargetFieldConfigMap(searchRequest);
        NeuralQueryTargetFieldConfig firstTargetFieldConfig = this.getFirstTargetFieldConfig(indexToTargetFieldConfigMap);
        if (firstTargetFieldConfig == null || !Boolean.TRUE.equals(firstTargetFieldConfig.getIsSemanticField())) {
            return this.rewriteQueryAgainstKnnField((QueryRewriteContext)queryRewriteContext);
        }
        NeuralQueryBuilder.validateNeuralQueryBuilder(this, NeuralQueryBuildStage.REWRITE, Boolean.TRUE, firstTargetFieldConfig.getEmbeddingFieldType());
        boolean isModelGeneratedEmbeddingAvailable = this.modelIdToVectorSupplierMap != null || this.modelIdToQueryTokensSupplierMap != null;
        boolean canUseSearchAnalyzerForSingleTargetIndex = this.queryTokensMapSupplier == null && this.modelId == null && this.getSearchAnalyzer(firstTargetFieldConfig) != null;
        boolean bl = canRewriteSingleTargetIndex = isModelGeneratedEmbeddingAvailable || this.queryTokensMapSupplier != null || canUseSearchAnalyzerForSingleTargetIndex || this.isSparseTwoPhaseTwo();
        if (indexToTargetFieldConfigMap.size() == 1 && canRewriteSingleTargetIndex) {
            return this.rewriteQueryForSemanticField(firstTargetFieldConfig);
        }
        boolean canUseAnalyzerForAllTargetIndices = this.searchAnalyzer != null || this.queryTokensMapSupplier == null && this.modelId == null && indexToTargetFieldConfigMap.values().stream().filter((? super T config) -> config.getSemanticFieldSearchAnalyzer() == null).toList().isEmpty();
        boolean bl2 = canSkipInference = isModelGeneratedEmbeddingAvailable || this.queryTokensMapSupplier != null || canUseAnalyzerForAllTargetIndices || this.isSparseTwoPhaseTwo();
        if (canSkipInference) {
            return this;
        }
        Set<String> modelIds = indexToTargetFieldConfigMap.values().stream().filter((? super T config) -> config.getSemanticFieldSearchAnalyzer() == null).map(NeuralQueryTargetFieldConfig::getSearchModelId).collect(Collectors.toSet());
        return this.inferenceForSemanticField((QueryRewriteContext)queryRewriteContext, modelIds, firstTargetFieldConfig.getEmbeddingFieldType());
    }

    /**
     * Rewrites this query into a concrete query against the embedding sub-field of a semantic field.
     *
     * <p>The result depends on the embedding type:
     * <ul>
     *   <li>{@code knn_vector}: a k-NN query built from the dense vector inferred for the search model.</li>
     *   <li>{@code rank_features}: a {@link NeuralSparseQueryBuilder} driven by query tokens (model-generated or
     *       user-provided) or, failing that, by the search analyzer.</li>
     * </ul>
     * When chunking is enabled, the result is wrapped in a {@link NestedQueryBuilder} on the chunks path with
     * {@link ScoreMode#Max}.
     *
     * @param targetFieldConfig resolved configuration of the target semantic field
     * @return the rewritten query
     * @throws RuntimeException if the needed embedding is missing or the embedding field type is unsupported
     * @throws IllegalStateException if neither query tokens nor a search analyzer are available for a sparse field
     */
    private QueryBuilder rewriteQueryForSemanticField(@NonNull NeuralQueryTargetFieldConfig targetFieldConfig) {
        Objects.requireNonNull(targetFieldConfig, "targetFieldConfig is marked non-null but is null");
        String searchModelId = this.getSearchModelId(targetFieldConfig);
        String semanticFieldSearchAnalyzer = this.getSearchAnalyzer(targetFieldConfig);
        Boolean chunkingEnabled = targetFieldConfig.getChunkingEnabled();
        String embeddingFieldType = targetFieldConfig.getEmbeddingFieldType();
        String embeddingFieldPath = targetFieldConfig.getEmbeddingFieldPath();
        String chunksPath = targetFieldConfig.getChunksPath();

        if ("knn_vector".equals(embeddingFieldType)) {
            EventStatsManager.increment(EventStatName.NEURAL_QUERY_AGAINST_SEMANTIC_DENSE_REQUESTS);
            Supplier<float[]> denseVectorSupplier = this.modelIdToVectorSupplierMap == null ? null : this.modelIdToVectorSupplierMap.get(searchModelId);
            // Read the supplier once so the null check and the value used are the same.
            float[] vector = denseVectorSupplier == null ? null : denseVectorSupplier.get();
            if (vector == null) {
                throw new RuntimeException(this.getErrorMessageWithBaseErrorForSemantic("Not able to find the dense embedding when try to rewrite it to the KNN query."));
            }
            QueryBuilder knnQueryBuilder = this.createKNNQueryBuilder(embeddingFieldPath, vector);
            if (Boolean.TRUE.equals(chunkingEnabled)) {
                return new NestedQueryBuilder(chunksPath, knnQueryBuilder, ScoreMode.Max);
            }
            return knnQueryBuilder;
        }

        if ("rank_features".equals(embeddingFieldType)) {
            EventStatsManager.increment(EventStatName.NEURAL_QUERY_AGAINST_SEMANTIC_SPARSE_REQUESTS);
            Supplier<Map<String, Float>> queryTokensSupplier = this.queryTokensMapSupplier;
            // Model tokens are used when a model id is given, or when there are neither raw tokens nor an analyzer.
            boolean useModelGeneratedEmbedding = this.modelId != null
                || (this.queryTokensMapSupplier == null && semanticFieldSearchAnalyzer == null);
            if (useModelGeneratedEmbedding) {
                if (this.isSparseTwoPhaseTwo()) {
                    queryTokensSupplier = () -> this.modelIdToTwoPhaseSharedQueryTokenSupplier.get().get(searchModelId);
                } else {
                    if (this.modelIdToQueryTokensSupplierMap == null || this.modelIdToQueryTokensSupplierMap.get(searchModelId) == null) {
                        throw new RuntimeException(this.getErrorMessageWithBaseErrorForSemantic("Not able to find the sparse embedding when try to rewrite it to neural sparse query."));
                    }
                    queryTokensSupplier = this.modelIdToQueryTokensSupplierMap.get(searchModelId);
                }
            }
            NeuralSparseQueryBuilder neuralSparseQueryBuilder = (NeuralSparseQueryBuilder) ((NeuralSparseQueryBuilder) new NeuralSparseQueryBuilder().fieldName(embeddingFieldPath)).neuralSparseQueryTwoPhaseInfo(this.neuralSparseQueryTwoPhaseInfo);
            if (queryTokensSupplier != null) {
                neuralSparseQueryBuilder = (NeuralSparseQueryBuilder) neuralSparseQueryBuilder.queryTokensMapSupplier(queryTokensSupplier);
            } else if (semanticFieldSearchAnalyzer != null) {
                neuralSparseQueryBuilder = (NeuralSparseQueryBuilder) ((NeuralSparseQueryBuilder) neuralSparseQueryBuilder.searchAnalyzer(semanticFieldSearchAnalyzer)).queryText(this.queryText);
            } else {
                throw new IllegalStateException(this.getErrorMessageWithBaseErrorForSemantic("Not able to find the embedding or tokenizer when try to rewrite it to neural sparse query."));
            }
            if (Boolean.TRUE.equals(chunkingEnabled)) {
                return new NestedQueryBuilder(chunksPath, (QueryBuilder) neuralSparseQueryBuilder, ScoreMode.Max);
            }
            return neuralSparseQueryBuilder;
        }

        throw new RuntimeException(this.getErrorMessageWithBaseErrorForSemantic("Expect the embedding field type to be knn_vector or rank_features but found unsupported embedding field type: " + embeddingFieldType));
    }

    /**
     * Resolves, per target index of the request, the configuration of this query's field, then validates it.
     *
     * @param searchRequest request whose target indices are resolved
     * @return map from index name to the field's target configuration
     * @throws NullPointerException if {@code searchRequest} is null
     */
    private Map<String, NeuralQueryTargetFieldConfig> getIndexToTargetFieldConfigMap(@NonNull IndicesRequest searchRequest) {
        Objects.requireNonNull(searchRequest, "searchRequest is marked non-null but is null");
        final Map<String, NeuralQueryTargetFieldConfig> configByIndex = SemanticMappingUtils.getIndexToTargetFieldConfigMapFromIndexMetadata(
            this.fieldName,
            NeuralSearchClusterUtil.instance().getIndexMetadataList(searchRequest)
        );
        NeuralQueryValidationUtil.validateTargetFieldConfig(this.fieldName, configByIndex);
        return configByIndex;
    }

    private List<String> getRemoteIndices(@NonNull IndicesRequest searchRequest) {
        Objects.requireNonNull(searchRequest, "searchRequest is marked non-null but is null");
        return Arrays.stream(searchRequest.indices()).filter((? super T index) -> index.indexOf(58) >= 0).collect(Collectors.toList());
    }

    /**
     * Reports whether any local target index maps this query's field to a {@code rank_features} embedding.
     *
     * <p>If the request touches any remote index, this returns {@code false} without inspecting mappings.
     *
     * @param searchRequest request whose target indices are inspected
     * @return {@code true} if some target index uses a sparse ({@code rank_features}) embedding
     * @throws NullPointerException if {@code searchRequest} is null
     */
    public boolean isTargetSparseEmbedding(@NonNull IndicesRequest searchRequest) {
        Objects.requireNonNull(searchRequest, "searchRequest is marked non-null but is null");
        final boolean touchesRemoteIndex = !this.getRemoteIndices(searchRequest).isEmpty();
        return !touchesRemoteIndex
            && this.getIndexToTargetFieldConfigMap(searchRequest)
                .values()
                .stream()
                .map(NeuralQueryTargetFieldConfig::getEmbeddingFieldType)
                .anyMatch("rank_features"::equals);
    }

    /**
     * Returns the configuration of the first index in the map's iteration order.
     *
     * @param indexToTargetFieldConfigMap map from index name to target field configuration
     * @return the first entry's configuration, or {@code null} when the map is empty
     * @throws NullPointerException if {@code indexToTargetFieldConfigMap} is null
     */
    private NeuralQueryTargetFieldConfig getFirstTargetFieldConfig(@NonNull Map<String, NeuralQueryTargetFieldConfig> indexToTargetFieldConfigMap) {
        Objects.requireNonNull(indexToTargetFieldConfigMap, "indexToTargetFieldConfigMap is marked non-null but is null");
        if (indexToTargetFieldConfigMap.isEmpty()) {
            return null;
        }
        final String firstIndex = indexToTargetFieldConfigMap.keySet().iterator().next();
        return indexToTargetFieldConfigMap.get(firstIndex);
    }

    /**
     * Rewrites this query for a plain {@code knn_vector} target field.
     *
     * <p>Behavior depends on the vector supplier:
     * <ul>
     *   <li>If a vector supplier exists and has produced a vector, a k-NN query is built from it.</li>
     *   <li>If the supplier exists but has no vector yet, this query is returned unchanged.</li>
     *   <li>If there is no supplier, an asynchronous inference action is registered on the context. The action
     *       sends the query text and/or image to the model, and the returned copy of this query is bound to a
     *       {@link SetOnce} that the action fills.</li>
     * </ul>
     *
     * @param queryRewriteContext context on which the async inference action is registered
     * @return a k-NN query, this query, or a copy awaiting the inferred vector
     */
    private QueryBuilder rewriteQueryAgainstKnnField(QueryRewriteContext queryRewriteContext) {
        Supplier<float[]> existingVectorSupplier = this.vectorSupplier();
        if (existingVectorSupplier != null) {
            // Read once so the null check and the vector used are the same value.
            float[] vector = existingVectorSupplier.get();
            if (vector == null) {
                return this;
            }
            EventStatsManager.increment(EventStatName.NEURAL_QUERY_AGAINST_KNN_REQUESTS);
            return this.createKNNQueryBuilder(this.fieldName(), vector);
        }

        SetOnce<float[]> vectorSetOnce = new SetOnce<>();
        Map<String, String> inferenceInput = new HashMap<>();
        NeuralQueryBuilder neuralQueryBuilder = this.createNeuralQueryBuilder("knn_vector", vectorSetOnce::get, false);
        if (StringUtils.isNotBlank(this.queryText())) {
            inferenceInput.put("inputText", this.queryText());
        }
        if (StringUtils.isNotBlank(this.queryImage())) {
            inferenceInput.put("inputImage", this.queryImage());
        }
        queryRewriteContext.registerAsyncAction((client, actionListener) -> ML_CLIENT.inferenceSentencesMap(
            MapInferenceRequest.builder()
                .modelId(this.modelId())
                .inputObjects(inferenceInput)
                .embeddingContentType(EmbeddingContentType.QUERY)
                .build(),
            ActionListener.wrap(floatList -> {
                vectorSetOnce.set(VectorUtil.vectorAsListToArray(floatList));
                actionListener.onResponse(null);
            }, actionListener::onFailure)
        ));
        return neuralQueryBuilder;
    }

    /**
     * Builds the k-NN query for an already-resolved query vector.
     *
     * <p>When the cluster meets the minimum version for {@code NeuralKNNQueryBuilder}, that wrapper is
     * used (it additionally carries the original query text). Otherwise a plain
     * {@link KNNQueryBuilder} is built with the same search options.
     *
     * @param fieldName the {@code knn_vector} field to search
     * @param vector the query embedding
     * @return the k-NN query builder
     */
    QueryBuilder createKNNQueryBuilder(String fieldName, float[] vector) {
        if (!MinClusterVersionUtil.isClusterOnOrAfterMinReqVersionForNeuralKNNQueryBuilder()) {
            // Older clusters: fall back to the k-NN plugin's own builder.
            return KNNQueryBuilder.builder()
                .fieldName(fieldName)
                .vector(vector)
                .filter(this.queryfilter())
                .maxDistance(this.maxDistance())
                .minScore(this.minScore())
                .expandNested(this.expandNested())
                .k(this.k())
                .methodParameters(this.methodParameters())
                .rescoreContext(this.rescoreContext())
                .build();
        }
        return NeuralKNNQueryBuilder.builder()
            .fieldName(fieldName)
            .vector(vector)
            .filter(this.queryfilter())
            .expandNested(this.expandNested())
            .methodParameters(this.methodParameters())
            .rescoreContext(this.rescoreContext())
            .originalQueryText(this.queryText())
            .k(this.k())
            .maxDistance(this.maxDistance())
            .minScore(this.minScore())
            .build();
    }

    /**
     * Copies this query's options into a new {@link NeuralQueryBuilder} at the REWRITE build stage.
     *
     * <p>The supplier maps and suppliers are copied by reference, so values filled in later by async
     * inference are visible through the returned builder.
     *
     * @param embeddingFieldType type name of the embedding field the new query targets
     * @param vectorSupplier supplier of the query vector (may be unresolved)
     * @param isSemanticField whether the target is a semantic field
     * @return the new, validated builder
     */
    private NeuralQueryBuilder createNeuralQueryBuilder(String embeddingFieldType, Supplier<float[]> vectorSupplier, boolean isSemanticField) {
        NeuralQueryBuilder.Builder copyBuilder = NeuralQueryBuilder.builder()
            .fieldName(this.fieldName())
            .queryText(this.queryText())
            .modelId(this.modelId())
            .embeddingFieldType(embeddingFieldType)
            .queryImage(this.queryImage())
            .k(this.k())
            .maxDistance(this.maxDistance())
            .minScore(this.minScore())
            .expandNested(this.expandNested())
            .vectorSupplier(vectorSupplier)
            .filter(this.queryfilter())
            .methodParameters(this.methodParameters())
            .rescoreContext(this.rescoreContext())
            .isSemanticField(isSemanticField)
            .buildStage(NeuralQueryBuildStage.REWRITE);
        // Inference state shared with this instance.
        copyBuilder = copyBuilder
            .queryTokensMapSupplier(this.queryTokensMapSupplier())
            .modelIdToQueryTokensSupplierMap(this.modelIdToQueryTokensSupplierMap())
            .modelIdToVectorSupplierMap(this.modelIdToVectorSupplierMap())
            .searchAnalyzer(this.searchAnalyzer())
            .neuralSparseQueryTwoPhaseInfo(this.neuralSparseQueryTwoPhaseInfo())
            .modelIdToTwoPhaseSharedQueryToken(this.modelIdToTwoPhaseSharedQueryToken())
            .modelIdToTwoPhaseSharedQueryTokenSupplier(this.modelIdToTwoPhaseSharedQueryTokenSupplier());
        return copyBuilder.build();
    }

    /**
     * Shard-level rewrite: resolves the target field's mapping and dispatches on its type.
     *
     * <ul>
     *   <li>Unmapped field: returns {@code this} unchanged.</li>
     *   <li>{@code knn_vector} field: delegates to {@code rewriteQueryAgainstKnnField}.</li>
     *   <li>{@code semantic} field: builds a {@link NeuralQueryTargetFieldConfig} from the semantic
     *       field parameters, validates this query, then either registers inference or rewrites
     *       against the semantic field.</li>
     *   <li>Anything else: throws a {@link RuntimeException}.</li>
     * </ul>
     *
     * @param shardContext the shard context used to look up field mappings
     * @return the rewritten query builder
     */
    private QueryBuilder rewriteQueryWithQueryShardContext(QueryShardContext shardContext) {
        MappedFieldType targetFieldType = shardContext.fieldMapper(this.fieldName());
        if (targetFieldType == null) {
            return this;
        }
        String targetTypeName = targetFieldType.typeName();
        if ("knn_vector".equals(targetTypeName)) {
            return this.rewriteQueryAgainstKnnField((QueryRewriteContext) shardContext);
        }
        if (!"semantic".equals(targetTypeName)) {
            throw new RuntimeException("Expect the neural query target field to be a semantic field but found: " + targetTypeName);
        }

        SemanticFieldMapper.SemanticFieldType semanticFieldType = (SemanticFieldMapper.SemanticFieldType) targetFieldType;
        Boolean chunkingEnabled = semanticFieldType.getSemanticParameters().isChunkingEnabled();
        String semanticFieldSearchAnalyzer = semanticFieldType.getSemanticParameters().getSemanticFieldSearchAnalyzer();
        NeuralQueryTargetFieldConfig.NeuralQueryTargetFieldConfigBuilder configBuilder = NeuralQueryTargetFieldConfig.builder()
            .isSemanticField(true)
            .searchModelId(this.getSearchModelId(semanticFieldType))
            .chunkingEnabled(chunkingEnabled)
            .semanticFieldSearchAnalyzer(semanticFieldSearchAnalyzer);

        // With chunking, embeddings live under "<semantic info>.chunks"; otherwise directly under the info path.
        String semanticInfoFieldPath = semanticFieldType.getSemanticInfoFieldPath();
        String embeddingParentPath = semanticInfoFieldPath;
        if (Boolean.TRUE.equals(chunkingEnabled)) {
            embeddingParentPath = semanticInfoFieldPath + ".chunks";
            configBuilder.chunksPath(embeddingParentPath);
        }
        String embeddingFieldPath = embeddingParentPath + ".embedding";
        configBuilder.embeddingFieldPath(embeddingFieldPath);

        MappedFieldType embeddingFieldType = shardContext.fieldMapper(embeddingFieldPath);
        if (embeddingFieldType == null) {
            throw new RuntimeException(this.getErrorMessageWithBaseErrorForSemantic("Expect the embedding field exists in the index mapping but not able to find it."));
        }
        String embeddingFieldTypeName = embeddingFieldType.typeName();
        NeuralQueryTargetFieldConfig targetFieldConfig = configBuilder.embeddingFieldType(embeddingFieldTypeName).build();
        NeuralQueryBuilder.validateNeuralQueryBuilder(this, NeuralQueryBuildStage.REWRITE, Boolean.TRUE, embeddingFieldTypeName);

        // Inference is only triggered when nothing upstream has supplied vectors/tokens yet.
        boolean needsInference = this.modelIdToVectorSupplierMap == null
            && this.modelIdToQueryTokensSupplierMap == null
            && this.queryTokensMapSupplier == null
            && this.getSearchAnalyzer(targetFieldConfig) == null
            && !this.isSparseTwoPhaseTwo();
        if (needsInference) {
            return this.inferenceForSemanticField((QueryRewriteContext) shardContext, Set.of(targetFieldConfig.getSearchModelId()), embeddingFieldTypeName);
        }
        return this.rewriteQueryForSemanticField(targetFieldConfig);
    }

    /**
     * Picks the model ID to use at search time for a semantic field.
     *
     * @param semanticFieldType the semantic field's mapped type
     * @return the field's search model ID when configured, otherwise its (ingest) model ID
     */
    private String getSearchModelId(@NonNull SemanticFieldMapper.SemanticFieldType semanticFieldType) {
        Objects.requireNonNull(semanticFieldType, "semanticFieldType is marked non-null but is null");
        String configuredSearchModelId = semanticFieldType.getSemanticParameters().getSearchModelId();
        return configuredSearchModelId != null ? configuredSearchModelId : semanticFieldType.getSemanticParameters().getModelId();
    }

    /**
     * Resolves the model ID for a target field, letting a model ID set on this query override the
     * one from the field configuration.
     *
     * @param config the target field configuration
     * @return this query's model ID if set, otherwise {@code config.getSearchModelId()}
     */
    private String getSearchModelId(@NonNull NeuralQueryTargetFieldConfig config) {
        Objects.requireNonNull(config, "config is marked non-null but is null");
        return this.modelId == null ? config.getSearchModelId() : this.modelId;
    }

    /**
     * Resolves the search analyzer, letting one set on this query override the semantic field's.
     *
     * @param config the target field configuration
     * @return this query's search analyzer if set, otherwise the field's configured analyzer
     */
    private String getSearchAnalyzer(@NonNull NeuralQueryTargetFieldConfig config) {
        Objects.requireNonNull(config, "config is marked non-null but is null");
        return this.searchAnalyzer == null ? config.getSemanticFieldSearchAnalyzer() : this.searchAnalyzer;
    }

    /**
     * Prefixes an error detail with the standard "failed to rewrite against semantic field" context.
     *
     * @param errorMessage the specific failure detail
     * @return the full error message naming this query's field
     */
    private String getErrorMessageWithBaseErrorForSemantic(@NonNull String errorMessage) {
        Objects.requireNonNull(errorMessage, "errorMessage is marked non-null but is null");
        String baseError = "Failed to rewrite the neural query against the semantic field " + this.fieldName + ". ";
        return baseError + errorMessage;
    }

    /**
     * Prepares per-model inference for a semantic-field query and returns the rewritten builder.
     *
     * <p>Steps:
     * <ol>
     *   <li>Allocates the per-model supplier map that matches the embedding field type.</li>
     *   <li>Builds the rewritten query, which captures those map references.</li>
     *   <li>Registers one async inference action per model ID; each action fills the captured maps.</li>
     * </ol>
     *
     * @param queryRewriteContext context used to register async inference actions
     * @param modelIdsFromTargetFields model IDs configured on the target fields; replaced by
     *     {@code Set.of(modelId)} when this query specifies a model ID
     * @param embeddingFieldType {@code knn_vector} or {@code rank_features}
     * @return the rewritten neural query builder
     * @throws RuntimeException if the embedding field type is unsupported
     */
    private QueryBuilder inferenceForSemanticField(@NonNull QueryRewriteContext queryRewriteContext, @NonNull Set<String> modelIdsFromTargetFields, @NonNull String embeddingFieldType) {
        Objects.requireNonNull(queryRewriteContext, "queryRewriteContext is marked non-null but is null");
        Objects.requireNonNull(modelIdsFromTargetFields, "modelIdsFromTargetFields is marked non-null but is null");
        Objects.requireNonNull(embeddingFieldType, "embeddingFieldType is marked non-null but is null");
        Set<String> modelIds = modelIdsFromTargetFields;
        if (this.modelId != null) {
            modelIds = Set.of(this.modelId);
        }
        // These maps must exist before createNeuralQueryBuilder(...) below: the rewritten query copies
        // the references, and the inference helpers then populate these same instances.
        if ("knn_vector".equals(embeddingFieldType)) {
            this.modelIdToVectorSupplierMap = new HashMap<String, Supplier<float[]>>(modelIds.size());
        } else if ("rank_features".equals(embeddingFieldType)) {
            this.modelIdToQueryTokensSupplierMap = new HashMap<String, Supplier<Map<String, Float>>>(modelIds.size());
            if (this.isSparseTwoPhaseOne()) {
                // This map is written by put() from the async inference callbacks, one per model ID.
                // Registered rewrite actions can complete concurrently, so a plain HashMap is unsafe.
                // synchronizedMap keeps HashMap's null-key/null-value semantics for readers.
                this.modelIdToTwoPhaseSharedQueryToken = java.util.Collections.synchronizedMap(new HashMap<String, Map<String, Float>>(modelIds.size()));
            }
        } else {
            throw new RuntimeException(String.format(Locale.ROOT, "Not able to do inference for the neural query against the field %s. Unsupported embedding field type: %s.", this.fieldName, embeddingFieldType));
        }
        NeuralQueryBuilder neuralQueryBuilder = this.createNeuralQueryBuilder(embeddingFieldType, this.vectorSupplier(), true);
        if ("knn_vector".equals(embeddingFieldType)) {
            this.inferenceByDenseModel(modelIds, queryRewriteContext);
        } else if (this.queryTokensMapSupplier == null) {
            // Sparse inference is skipped when raw query tokens were supplied directly.
            this.inferenceBySparseModel(modelIds, queryRewriteContext);
        }
        return neuralQueryBuilder;
    }

    /**
     * Registers one async sparse-inference action per model ID.
     *
     * <p>For each model, a supplier backed by a {@link SetOnce} is stored in
     * {@code modelIdToQueryTokensSupplierMap}. The async action encodes {@code queryText} and fills it.
     * In two-phase phase one, the tokens are split with {@code PruneUtils.splitSparseVector}:
     * {@code v1()} goes into the supplier and {@code v2()} into
     * {@code modelIdToTwoPhaseSharedQueryToken}.
     *
     * @param modelIds sparse model IDs to run
     * @param queryRewriteContext context used to register the async actions
     */
    private void inferenceBySparseModel(@NonNull Set<String> modelIds, @NonNull QueryRewriteContext queryRewriteContext) {
        Objects.requireNonNull(modelIds, "modelIds is marked non-null but is null");
        Objects.requireNonNull(queryRewriteContext, "queryRewriteContext is marked non-null but is null");
        for (String sparseModelId : modelIds) {
            SetOnce<Map<String, Float>> queryTokensHolder = new SetOnce<>();
            this.modelIdToQueryTokensSupplierMap.put(sparseModelId, queryTokensHolder::get);
            queryRewriteContext.registerAsyncAction((client, actionListener) -> ML_CLIENT.inferenceSentencesWithMapResult(
                (TextInferenceRequest)((TextInferenceRequest.TextInferenceRequestBuilder)((InferenceRequest.InferenceRequestBuilder)((TextInferenceRequest.TextInferenceRequestBuilder)TextInferenceRequest.builder().modelId(sparseModelId)).inputTexts(List.of(this.queryText))).embeddingContentType(EmbeddingContentType.QUERY)).build(),
                ActionListener.wrap(mapResultList -> {
                    Map<String, Float> queryTokens = TokenWeightUtil.fetchListOfTokenWeightMap(mapResultList).get(0);
                    if (!this.isSparseTwoPhaseOne()) {
                        queryTokensHolder.set(queryTokens);
                    } else {
                        Tuple<Map<String, Float>, Map<String, Float>> splitTokens = PruneUtils.splitSparseVector(this.neuralSparseQueryTwoPhaseInfo.getTwoPhasePruneType(), this.neuralSparseQueryTwoPhaseInfo.getTwoPhasePruneRatio(), queryTokens);
                        queryTokensHolder.set(splitTokens.v1());
                        this.modelIdToTwoPhaseSharedQueryToken.put(sparseModelId, splitTokens.v2());
                    }
                    actionListener.onResponse(null);
                }, actionListener::onFailure)));
        }
    }

    /**
     * Registers one async dense-inference action per model ID.
     *
     * <p>Each model gets a {@link SetOnce}-backed supplier in {@code modelIdToVectorSupplierMap}.
     * The async action fills it with the embedding returned for the text/image input.
     *
     * @param modelIds dense model IDs to run
     * @param queryRewriteContext context used to register the async actions
     */
    private void inferenceByDenseModel(@NonNull Set<String> modelIds, @NonNull QueryRewriteContext queryRewriteContext) {
        Objects.requireNonNull(modelIds, "modelIds is marked non-null but is null");
        Objects.requireNonNull(queryRewriteContext, "queryRewriteContext is marked non-null but is null");
        final Map<String, String> denseInput = this.getInferenceInputForDenseModel();
        modelIds.forEach(denseModelId -> {
            SetOnce<float[]> vectorHolder = new SetOnce<>();
            this.modelIdToVectorSupplierMap.put(denseModelId, vectorHolder::get);
            queryRewriteContext.registerAsyncAction((client, actionListener) -> ML_CLIENT.inferenceSentencesMap(
                (MapInferenceRequest)((MapInferenceRequest.MapInferenceRequestBuilder)((InferenceRequest.InferenceRequestBuilder)((MapInferenceRequest.MapInferenceRequestBuilder)MapInferenceRequest.builder().modelId(denseModelId)).inputObjects(denseInput)).embeddingContentType(EmbeddingContentType.QUERY)).build(),
                (ActionListener<List<Number>>)ActionListener.wrap(floatList -> {
                    vectorHolder.set(VectorUtil.vectorAsListToArray(floatList));
                    actionListener.onResponse(null);
                }, actionListener::onFailure)));
        });
    }

    /**
     * Builds the input map sent to dense embedding models.
     *
     * @return a mutable map with {@code inputText} and/or {@code inputImage} for whichever of the
     *     query text and query image is non-blank; empty if neither is
     */
    private Map<String, String> getInferenceInputForDenseModel() {
        final Map<String, String> denseModelInput = new HashMap<>();
        final String text = this.queryText();
        if (StringUtils.isNotBlank(text)) {
            denseModelInput.put("inputText", text);
        }
        final String image = this.queryImage();
        if (StringUtils.isNotBlank(image)) {
            denseModelInput.put("inputImage", image);
        }
        return denseModelInput;
    }

    /**
     * A neural query must be rewritten before execution and is never turned into a Lucene query
     * directly. The one exception is when the field has no mapping, which yields a match-nothing query.
     *
     * @param queryShardContext the shard context
     * @return {@link MatchNoDocsQuery} when the field has no mapping
     * @throws UnsupportedOperationException when the field is mapped
     */
    protected Query doToQuery(QueryShardContext queryShardContext) {
        if (queryShardContext.fieldMapper(this.fieldName) != null) {
            throw new UnsupportedOperationException("Query cannot be created by NeuralQueryBuilder directly");
        }
        return new MatchNoDocsQuery();
    }

    /**
     * Compares query options field by field.
     *
     * <p>Vector and query-token suppliers are compared by their current resolved values, not by
     * supplier identity. The per-model supplier maps and embedding field type are not compared.
     *
     * @param obj the other query builder
     * @return whether both builders describe the same query
     */
    protected boolean doEquals(NeuralQueryBuilder obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || this.getClass() != obj.getClass()) {
            return false;
        }
        return new EqualsBuilder()
            .append(this.fieldName, obj.fieldName)
            .append(this.queryText, obj.queryText)
            .append(this.queryImage, obj.queryImage)
            .append(this.modelId, obj.modelId)
            .append(this.searchAnalyzer, obj.searchAnalyzer)
            .append(this.k, obj.k)
            .append(this.maxDistance, obj.maxDistance)
            .append(this.minScore, obj.minScore)
            .append(this.expandNested, obj.expandNested)
            .append(this.getVector(this.vectorSupplier), this.getVector(obj.vectorSupplier))
            .append(this.queryfilter, obj.queryfilter)
            .append(this.methodParameters, obj.methodParameters)
            .append(this.rescoreContext, obj.rescoreContext)
            .append(this.getQueryTokenMap(this.queryTokensMapSupplier), this.getQueryTokenMap(obj.queryTokensMapSupplier))
            .append(this.neuralSparseQueryTwoPhaseInfo, obj.neuralSparseQueryTwoPhaseInfo)
            .isEquals();
    }

    /**
     * Hashes the same fields that {@code doEquals} compares. The resolved vector is hashed by
     * content via {@link Arrays#hashCode(float[])}.
     *
     * @return the hash code
     */
    protected int doHashCode() {
        int vectorHash = Arrays.hashCode(this.getVector(this.vectorSupplier));
        return Objects.hash(
            fieldName,
            queryText,
            queryImage,
            modelId,
            searchAnalyzer,
            k,
            maxDistance,
            minScore,
            expandNested,
            vectorHash,
            queryfilter,
            methodParameters,
            rescoreContext,
            this.getQueryTokenMap(this.queryTokensMapSupplier),
            neuralSparseQueryTwoPhaseInfo
        );
    }

    /** Resolves a vector supplier, treating an absent supplier as a {@code null} vector. */
    private float[] getVector(Supplier<float[]> vectorSupplier) {
        if (vectorSupplier == null) {
            return null;
        }
        return vectorSupplier.get();
    }

    /** Resolves a query-token supplier, treating an absent supplier as a {@code null} token map. */
    private Map<String, Float> getQueryTokenMap(Supplier<Map<String, Float>> queryTokensSupplierMap) {
        if (queryTokensSupplierMap == null) {
            return null;
        }
        return queryTokensSupplierMap.get();
    }

    /** @return the named-writeable name of this query, i.e. {@code NAME} */
    public String getWriteableName() {
        return NeuralQueryBuilder.NAME;
    }

    /**
     * Splits this query for neural-sparse two-phase search.
     *
     * <p>Marks {@code this} as phase one, and returns a new phase-two copy that carries
     * {@code queryName}, {@code fieldName}, {@code queryText}, {@code modelId} and
     * {@code searchAnalyzer}. Other options, including boost, are not copied here.
     *
     * <p>When raw query tokens are present, they are split with {@code PruneUtils.splitSparseVector}:
     * {@code this} keeps {@code v1()} and the copy receives {@code v2()}. NOTE(review): presumably v1
     * holds the high-weight tokens and v2 the pruned remainder; confirm against PruneUtils.
     *
     * <p>Otherwise the copy reads {@code this.modelIdToTwoPhaseSharedQueryToken} lazily through a
     * supplier. That map is filled by the sparse-inference callbacks when {@code this} is rewritten.
     *
     * @param pruneRatio prune ratio passed to both phases
     * @param pruneType prune type passed to both phases
     * @return the phase-two copy
     */
    @Override
    public NeuralQueryBuilder prepareTwoPhaseQuery(float pruneRatio, PruneType pruneType) {
        this.neuralSparseQueryTwoPhaseInfo = new NeuralSparseQueryTwoPhaseInfo(NeuralSparseQueryTwoPhaseInfo.TwoPhaseStatus.PHASE_ONE, pruneRatio, pruneType);
        // Phase-two copy; the nested casts come from fluent setters declared on a supertype.
        NeuralQueryBuilder copy = (NeuralQueryBuilder)((NeuralQueryBuilder)((NeuralQueryBuilder)((NeuralQueryBuilder)((NeuralQueryBuilder)((NeuralQueryBuilder)new NeuralQueryBuilder().queryName(this.queryName)).fieldName(this.fieldName)).queryText(this.queryText)).modelId(this.modelId)).searchAnalyzer(this.searchAnalyzer)).neuralSparseQueryTwoPhaseInfo(new NeuralSparseQueryTwoPhaseInfo(NeuralSparseQueryTwoPhaseInfo.TwoPhaseStatus.PHASE_TWO, pruneRatio, pruneType));
        if (Objects.nonNull(this.queryTokensMapSupplier)) {
            // Raw tokens: split eagerly now and hand each half to its phase.
            Map tokens = (Map)this.queryTokensMapSupplier.get();
            Tuple<Map<String, Float>, Map<String, Float>> splitTokens = PruneUtils.splitSparseVector(pruneType, pruneRatio, tokens);
            this.queryTokensMapSupplier(() -> splitTokens.v1());
            copy.queryTokensMapSupplier(() -> splitTokens.v2());
        } else {
            // No raw tokens yet: the lambda reads this instance's field lazily, after inference has run.
            copy.modelIdToTwoPhaseSharedQueryTokenSupplier(() -> this.modelIdToTwoPhaseSharedQueryToken);
        }
        return copy;
    }

    /** @return the embedding field type name set on this query */
    @Generated
    public String embeddingFieldType() {
        return embeddingFieldType;
    }

    /** @return the query image input set on this query */
    @Generated
    public String queryImage() {
        return queryImage;
    }

    /** @return the {@code k} value set on this query */
    @Generated
    public Integer k() {
        return k;
    }

    /** @return the max-distance value set on this query */
    @Generated
    public Float maxDistance() {
        return maxDistance;
    }

    /** @return the min-score value set on this query */
    @Generated
    public Float minScore() {
        return minScore;
    }

    /** @return the expand-nested flag set on this query */
    @Generated
    public Boolean expandNested() {
        return expandNested;
    }

    /** @return the filter query set on this query */
    @Generated
    public QueryBuilder queryfilter() {
        return queryfilter;
    }

    /** @return the k-NN method parameters set on this query */
    @Generated
    public Map<String, ?> methodParameters() {
        return methodParameters;
    }

    /** @return the rescore context set on this query */
    @Generated
    public RescoreContext rescoreContext() {
        return rescoreContext;
    }

    /** @return the per-model vector supplier map (filled during dense inference) */
    @Generated
    public Map<String, Supplier<float[]>> modelIdToVectorSupplierMap() {
        return modelIdToVectorSupplierMap;
    }

    /** @return the per-model query-token supplier map (filled during sparse inference) */
    @Generated
    public Map<String, Supplier<Map<String, Float>>> modelIdToQueryTokensSupplierMap() {
        return modelIdToQueryTokensSupplierMap;
    }

    /** @return the per-model token map shared with the two-phase phase-two query */
    @Generated
    public Map<String, Map<String, Float>> modelIdToTwoPhaseSharedQueryToken() {
        return modelIdToTwoPhaseSharedQueryToken;
    }

    /** @return the supplier through which a phase-two query reads the shared per-model tokens */
    @Generated
    public Supplier<Map<String, Map<String, Float>>> modelIdToTwoPhaseSharedQueryTokenSupplier() {
        return modelIdToTwoPhaseSharedQueryTokenSupplier;
    }

    /**
     * Sets the embedding field type name.
     *
     * @param embeddingFieldType the embedding field type name
     * @return this instance, for chaining
     */
    @Generated
    public NeuralQueryBuilder embeddingFieldType(final String embeddingFieldType) {
        this.embeddingFieldType = embeddingFieldType;
        return this;
    }

    /**
     * Sets the query image input.
     *
     * @param queryImage the query image input
     * @return this instance, for chaining
     */
    @Generated
    public NeuralQueryBuilder queryImage(final String queryImage) {
        this.queryImage = queryImage;
        return this;
    }

    /**
     * Sets {@code k}.
     *
     * @param k the number of neighbors
     * @return this instance, for chaining
     */
    @Generated
    public NeuralQueryBuilder k(final Integer k) {
        this.k = k;
        return this;
    }

    /**
     * Sets the max distance.
     *
     * @param maxDistance the max distance
     * @return this instance, for chaining
     */
    @Generated
    public NeuralQueryBuilder maxDistance(final Float maxDistance) {
        this.maxDistance = maxDistance;
        return this;
    }

    /**
     * Sets the min score.
     *
     * @param minScore the min score
     * @return this instance, for chaining
     */
    @Generated
    public NeuralQueryBuilder minScore(final Float minScore) {
        this.minScore = minScore;
        return this;
    }

    /**
     * Sets the expand-nested flag.
     *
     * @param expandNested the expand-nested flag
     * @return this instance, for chaining
     */
    @Generated
    public NeuralQueryBuilder expandNested(final Boolean expandNested) {
        this.expandNested = expandNested;
        return this;
    }

    /**
     * Sets the filter query.
     *
     * @param queryfilter the filter query
     * @return this instance, for chaining
     */
    @Generated
    public NeuralQueryBuilder queryfilter(final QueryBuilder queryfilter) {
        this.queryfilter = queryfilter;
        return this;
    }

    /**
     * Sets the k-NN method parameters.
     *
     * @param methodParameters the method parameters
     * @return this instance, for chaining
     */
    @Generated
    public NeuralQueryBuilder methodParameters(final Map<String, ?> methodParameters) {
        this.methodParameters = methodParameters;
        return this;
    }

    /**
     * Sets the rescore context.
     *
     * @param rescoreContext the rescore context
     * @return this instance, for chaining
     */
    @Generated
    public NeuralQueryBuilder rescoreContext(final RescoreContext rescoreContext) {
        this.rescoreContext = rescoreContext;
        return this;
    }

    /**
     * Sets the per-model vector supplier map.
     *
     * @param modelIdToVectorSupplierMap the map (stored by reference)
     * @return this instance, for chaining
     */
    @Generated
    public NeuralQueryBuilder modelIdToVectorSupplierMap(final Map<String, Supplier<float[]>> modelIdToVectorSupplierMap) {
        this.modelIdToVectorSupplierMap = modelIdToVectorSupplierMap;
        return this;
    }

    /**
     * Sets the per-model query-token supplier map.
     *
     * @param modelIdToQueryTokensSupplierMap the map (stored by reference)
     * @return this instance, for chaining
     */
    @Generated
    public NeuralQueryBuilder modelIdToQueryTokensSupplierMap(final Map<String, Supplier<Map<String, Float>>> modelIdToQueryTokensSupplierMap) {
        this.modelIdToQueryTokensSupplierMap = modelIdToQueryTokensSupplierMap;
        return this;
    }

    /**
     * Sets the per-model token map shared with the two-phase phase-two query.
     *
     * @param modelIdToTwoPhaseSharedQueryToken the map (stored by reference)
     * @return this instance, for chaining
     */
    @Generated
    public NeuralQueryBuilder modelIdToTwoPhaseSharedQueryToken(final Map<String, Map<String, Float>> modelIdToTwoPhaseSharedQueryToken) {
        this.modelIdToTwoPhaseSharedQueryToken = modelIdToTwoPhaseSharedQueryToken;
        return this;
    }

    /**
     * Sets the supplier through which this query reads the shared per-model two-phase tokens.
     *
     * @param modelIdToTwoPhaseSharedQueryTokenSupplier the supplier
     * @return this instance, for chaining
     */
    @Generated
    public NeuralQueryBuilder modelIdToTwoPhaseSharedQueryTokenSupplier(final Supplier<Map<String, Map<String, Float>>> modelIdToTwoPhaseSharedQueryTokenSupplier) {
        this.modelIdToTwoPhaseSharedQueryTokenSupplier = modelIdToTwoPhaseSharedQueryTokenSupplier;
        return this;
    }

    /**
     * No-arg constructor, private to this class. Instances are created by {@link Builder#build()} and
     * {@code prepareTwoPhaseQuery}, which populate fields through fluent setters afterwards.
     */
    @Generated
    private NeuralQueryBuilder() {
        // Intentionally empty: all state is assigned via setters after construction.
    }

    /** @return the supplier of the resolved query vector, or {@code null} if none is attached */
    @Generated
    Supplier<float[]> vectorSupplier() {
        return vectorSupplier;
    }

    /**
     * Attaches the supplier of the resolved query vector.
     *
     * @param vectorSupplier the supplier
     * @return this instance, for chaining
     */
    @Generated
    NeuralQueryBuilder vectorSupplier(final Supplier<float[]> vectorSupplier) {
        this.vectorSupplier = vectorSupplier;
        return this;
    }

    /**
     * Fluent builder for {@link NeuralQueryBuilder}.
     *
     * <p>Collects query options, copies them onto a fresh {@link NeuralQueryBuilder} in
     * {@link #build()}, then validates the result for the configured build stage.
     * {@code isSemanticField} and {@code buildStage} are used only for that validation.
     */
    public static class Builder {
        private String fieldName;
        private String queryText;
        private String queryImage;
        private String modelId;
        private String searchAnalyzer;
        private Integer k;
        private Float maxDistance;
        private Float minScore;
        private Boolean expandNested;
        private Supplier<float[]> vectorSupplier;
        private QueryBuilder filter;
        private Map<String, ?> methodParameters;
        private RescoreContext rescoreContext;
        private String queryName;
        private float boost = 1.0f;
        private String embeddingFieldType;
        private Map<String, Supplier<float[]>> modelIdToVectorSupplierMap;
        private Supplier<Map<String, Float>> queryTokensMapSupplier;
        private Map<String, Supplier<Map<String, Float>>> modelIdToQueryTokensSupplierMap;
        private Boolean isSemanticField = Boolean.FALSE;
        private NeuralQueryBuildStage buildStage;
        private NeuralSparseQueryTwoPhaseInfo neuralSparseQueryTwoPhaseInfo;
        private Map<String, Map<String, Float>> modelIdToTwoPhaseSharedQueryToken;
        private Supplier<Map<String, Map<String, Float>>> modelIdToTwoPhaseSharedQueryTokenSupplier;

        /** Sets the target field name; required by {@link #build()}. */
        public Builder fieldName(final String fieldName) {
            this.fieldName = fieldName;
            return this;
        }

        /** Sets the query text. */
        public Builder queryText(final String queryText) {
            this.queryText = queryText;
            return this;
        }

        /** Sets the query image input. */
        public Builder queryImage(final String queryImage) {
            this.queryImage = queryImage;
            return this;
        }

        /** Sets the model ID. */
        public Builder modelId(final String modelId) {
            this.modelId = modelId;
            return this;
        }

        /** Sets {@code k}. */
        public Builder k(final Integer k) {
            this.k = k;
            return this;
        }

        /** Sets the max distance. */
        public Builder maxDistance(final Float maxDistance) {
            this.maxDistance = maxDistance;
            return this;
        }

        /** Sets the min score. */
        public Builder minScore(final Float minScore) {
            this.minScore = minScore;
            return this;
        }

        /** Sets the expand-nested flag. */
        public Builder expandNested(final Boolean expandNested) {
            this.expandNested = expandNested;
            return this;
        }

        /** Sets the supplier of an already-resolved query vector. */
        public Builder vectorSupplier(final Supplier<float[]> vectorSupplier) {
            this.vectorSupplier = vectorSupplier;
            return this;
        }

        /** Sets the filter query. */
        public Builder filter(final QueryBuilder filter) {
            this.filter = filter;
            return this;
        }

        /** Sets the k-NN method parameters. */
        public Builder methodParameters(final Map<String, ?> methodParameters) {
            this.methodParameters = methodParameters;
            return this;
        }

        /** Sets the query name. */
        public Builder queryName(final String queryName) {
            this.queryName = queryName;
            return this;
        }

        /** Sets the boost (defaults to 1.0f). */
        public Builder boost(final float boost) {
            this.boost = boost;
            return this;
        }

        /** Sets the rescore context. */
        public Builder rescoreContext(final RescoreContext rescoreContext) {
            this.rescoreContext = rescoreContext;
            return this;
        }

        /** Sets the embedding field type name; also passed to validation. */
        public Builder embeddingFieldType(final String embeddingFieldType) {
            this.embeddingFieldType = embeddingFieldType;
            return this;
        }

        /** Sets the per-model vector supplier map (stored by reference). */
        public Builder modelIdToVectorSupplierMap(final Map<String, Supplier<float[]>> modelIdToVectorSupplierMap) {
            this.modelIdToVectorSupplierMap = modelIdToVectorSupplierMap;
            return this;
        }

        /** Sets the supplier of raw query tokens. */
        public Builder queryTokensMapSupplier(final Supplier<Map<String, Float>> queryTokensMapSupplier) {
            this.queryTokensMapSupplier = queryTokensMapSupplier;
            return this;
        }

        /** Sets the per-model query-token supplier map (stored by reference). */
        public Builder modelIdToQueryTokensSupplierMap(final Map<String, Supplier<Map<String, Float>>> modelIdToQueryTokensSupplierMap) {
            this.modelIdToQueryTokensSupplierMap = modelIdToQueryTokensSupplierMap;
            return this;
        }

        /** Sets whether the target is a semantic field; used only for validation. */
        public Builder isSemanticField(final Boolean isSemanticField) {
            this.isSemanticField = isSemanticField;
            return this;
        }

        /** Sets the build stage used for validation. */
        public Builder buildStage(final NeuralQueryBuildStage buildStage) {
            this.buildStage = buildStage;
            return this;
        }

        /** Sets the search analyzer. */
        public Builder searchAnalyzer(final String searchAnalyzer) {
            this.searchAnalyzer = searchAnalyzer;
            return this;
        }

        /** Sets the two-phase info; a default instance is used when left unset. */
        public Builder neuralSparseQueryTwoPhaseInfo(final NeuralSparseQueryTwoPhaseInfo neuralSparseQueryTwoPhaseInfo) {
            this.neuralSparseQueryTwoPhaseInfo = neuralSparseQueryTwoPhaseInfo;
            return this;
        }

        /** Sets the per-model two-phase shared token map (stored by reference). */
        public Builder modelIdToTwoPhaseSharedQueryToken(final Map<String, Map<String, Float>> modelIdToTwoPhaseSharedQueryToken) {
            this.modelIdToTwoPhaseSharedQueryToken = modelIdToTwoPhaseSharedQueryToken;
            return this;
        }

        /** Sets the supplier for the per-model two-phase shared tokens. */
        public Builder modelIdToTwoPhaseSharedQueryTokenSupplier(final Supplier<Map<String, Map<String, Float>>> modelIdToTwoPhaseSharedQueryTokenSupplier) {
            this.modelIdToTwoPhaseSharedQueryTokenSupplier = modelIdToTwoPhaseSharedQueryTokenSupplier;
            return this;
        }

        /**
         * Creates the {@link NeuralQueryBuilder} from the collected values and validates it.
         *
         * <p>Fails via {@code requireValue} when no field name was provided. It then fails in
         * {@code validateNeuralQueryBuilder} when the combination of options is invalid for the build stage.
         *
         * @return the populated and validated query builder
         */
        public NeuralQueryBuilder build() {
            NeuralQueryBuilder.requireValue((Object) this.fieldName, (String) "Field name must be provided for neural query");
            NeuralQueryBuilder query = new NeuralQueryBuilder();
            query.fieldName(this.fieldName);
            query.queryText(this.queryText);
            query.modelId(this.modelId);
            query.embeddingFieldType(this.embeddingFieldType);
            query.queryImage(this.queryImage);
            query.k(this.k);
            query.maxDistance(this.maxDistance);
            query.minScore(this.minScore);
            query.expandNested(this.expandNested);
            query.vectorSupplier(this.vectorSupplier);
            query.queryfilter(this.filter);
            query.methodParameters(this.methodParameters);
            query.rescoreContext(this.rescoreContext);
            query.modelIdToVectorSupplierMap(this.modelIdToVectorSupplierMap);
            query.queryTokensMapSupplier(this.queryTokensMapSupplier);
            query.modelIdToQueryTokensSupplierMap(this.modelIdToQueryTokensSupplierMap);
            query.searchAnalyzer(this.searchAnalyzer);
            // Never leave the two-phase info unset on the built query.
            NeuralSparseQueryTwoPhaseInfo twoPhaseInfo = this.neuralSparseQueryTwoPhaseInfo;
            if (twoPhaseInfo == null) {
                twoPhaseInfo = new NeuralSparseQueryTwoPhaseInfo();
            }
            query.neuralSparseQueryTwoPhaseInfo(twoPhaseInfo);
            query.modelIdToTwoPhaseSharedQueryToken(this.modelIdToTwoPhaseSharedQueryToken);
            query.modelIdToTwoPhaseSharedQueryTokenSupplier(this.modelIdToTwoPhaseSharedQueryTokenSupplier);
            query.boost(this.boost);
            query.queryName(this.queryName);
            NeuralQueryBuilder.validateNeuralQueryBuilder(query, this.buildStage, this.isSemanticField, this.embeddingFieldType);
            return query;
        }
    }
}

