/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.model.googleai;

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.http.client.HttpClientBuilder;
import dev.langchain4j.internal.RetryUtils;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.model.embedding.DimensionAwareEmbeddingModel;
import dev.langchain4j.model.googleai.GeminiContent;
import dev.langchain4j.model.googleai.GeminiEmbeddingRequestResponse;
import dev.langchain4j.model.googleai.GeminiService;
import dev.langchain4j.model.output.Response;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import org.slf4j.Logger;

public class GoogleAiEmbeddingModel
extends DimensionAwareEmbeddingModel {
    private static final int MAX_NUMBER_OF_SEGMENTS_PER_BATCH = 100;
    private final GeminiService geminiService;
    private final String modelName;
    private final Integer maxRetries;
    private final TaskType taskType;
    private final String titleMetadataKey;
    private final Integer outputDimensionality;

    public GoogleAiEmbeddingModel(GoogleAiEmbeddingModelBuilder builder) {
        ValidationUtils.ensureNotBlank((String)builder.apiKey, (String)"apiKey");
        this.geminiService = new GeminiService(builder.httpClientBuilder, builder.apiKey, builder.baseUrl, (Boolean)Utils.getOrDefault((Object)builder.logRequestsAndResponses, (Object)false), (Boolean)Utils.getOrDefault((Object)builder.logRequests, (Object)false), (Boolean)Utils.getOrDefault((Object)builder.logResponses, (Object)false), builder.logger, builder.timeout);
        this.modelName = ValidationUtils.ensureNotBlank((String)builder.modelName, (String)"modelName");
        this.maxRetries = (Integer)Utils.getOrDefault((Object)builder.maxRetries, (Object)2);
        this.taskType = builder.taskType;
        this.titleMetadataKey = (String)Utils.getOrDefault((Object)builder.titleMetadataKey, (Object)"title");
        this.outputDimensionality = builder.outputDimensionality;
    }

    public static GoogleAiEmbeddingModelBuilder builder() {
        return new GoogleAiEmbeddingModelBuilder();
    }

    public Response<Embedding> embed(TextSegment textSegment) {
        GeminiEmbeddingRequestResponse.GeminiEmbeddingRequest embeddingRequest = this.getGoogleAiEmbeddingRequest(textSegment);
        GeminiEmbeddingRequestResponse.GeminiEmbeddingResponse geminiResponse = (GeminiEmbeddingRequestResponse.GeminiEmbeddingResponse)RetryUtils.withRetryMappingExceptions(() -> this.geminiService.embed(this.modelName, embeddingRequest), (int)this.maxRetries);
        return Response.from((Object)Embedding.from(geminiResponse.embedding().values()));
    }

    public Response<Embedding> embed(String text) {
        return this.embed(TextSegment.from((String)text));
    }

    public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
        List embeddingRequests = textSegments.stream().map(this::getGoogleAiEmbeddingRequest).collect(Collectors.toList());
        ArrayList<Embedding> allEmbeddings = new ArrayList<Embedding>();
        int numberOfEmbeddings = embeddingRequests.size();
        int numberOfBatches = 1 + numberOfEmbeddings / 100;
        for (int i = 0; i < numberOfBatches; ++i) {
            int startIndex = 100 * i;
            int lastIndex = Math.min(startIndex + 100, numberOfEmbeddings);
            if (startIndex >= numberOfEmbeddings) break;
            GeminiEmbeddingRequestResponse.GeminiBatchEmbeddingRequest batchEmbeddingRequest = new GeminiEmbeddingRequestResponse.GeminiBatchEmbeddingRequest(embeddingRequests.subList(startIndex, lastIndex));
            GeminiEmbeddingRequestResponse.GeminiBatchEmbeddingResponse geminiResponse = (GeminiEmbeddingRequestResponse.GeminiBatchEmbeddingResponse)RetryUtils.withRetryMappingExceptions(() -> this.geminiService.batchEmbed(this.modelName, batchEmbeddingRequest));
            allEmbeddings.addAll(geminiResponse.embeddings().stream().map(values -> Embedding.from(values.values())).toList());
        }
        return Response.from(allEmbeddings);
    }

    public String modelName() {
        return this.modelName;
    }

    private GeminiEmbeddingRequestResponse.GeminiEmbeddingRequest getGoogleAiEmbeddingRequest(TextSegment textSegment) {
        GeminiContent.GeminiPart geminiPart = GeminiContent.GeminiPart.builder().text(textSegment.text()).build();
        GeminiContent content = new GeminiContent(Collections.singletonList(geminiPart), null);
        String title = null;
        if (TaskType.RETRIEVAL_DOCUMENT.equals((Object)this.taskType) && textSegment.metadata() != null && textSegment.metadata().getString(this.titleMetadataKey) != null) {
            title = textSegment.metadata().getString(this.titleMetadataKey);
        }
        return new GeminiEmbeddingRequestResponse.GeminiEmbeddingRequest("models/" + this.modelName, content, this.taskType, title, this.outputDimensionality);
    }

    public Integer knownDimension() {
        return this.outputDimensionality;
    }

    public static class GoogleAiEmbeddingModelBuilder
    extends BaseGoogleAiEmbeddingModelBuilder<GoogleAiEmbeddingModelBuilder> {
        public GoogleAiEmbeddingModel build() {
            return new GoogleAiEmbeddingModel(this);
        }
    }

    public static enum TaskType {
        RETRIEVAL_QUERY,
        RETRIEVAL_DOCUMENT,
        SEMANTIC_SIMILARITY,
        CLASSIFICATION,
        CLUSTERING,
        QUESTION_ANSWERING,
        FACT_VERIFICATION;

    }

    static abstract class BaseGoogleAiEmbeddingModelBuilder<B extends BaseGoogleAiEmbeddingModelBuilder<B>> {
        HttpClientBuilder httpClientBuilder;
        String modelName;
        String apiKey;
        String baseUrl;
        Integer maxRetries;
        TaskType taskType;
        String titleMetadataKey;
        Integer outputDimensionality;
        Duration timeout;
        Boolean logRequestsAndResponses;
        Boolean logRequests;
        Boolean logResponses;
        Logger logger;

        BaseGoogleAiEmbeddingModelBuilder() {
        }

        public B httpClientBuilder(HttpClientBuilder httpClientBuilder) {
            this.httpClientBuilder = httpClientBuilder;
            return this.builder();
        }

        protected B builder() {
            return (B)this;
        }

        public B modelName(String modelName) {
            this.modelName = modelName;
            return this.builder();
        }

        public B apiKey(String apiKey) {
            this.apiKey = apiKey;
            return this.builder();
        }

        public B baseUrl(String baseUrl) {
            this.baseUrl = baseUrl;
            return this.builder();
        }

        public B maxRetries(Integer maxRetries) {
            this.maxRetries = maxRetries;
            return this.builder();
        }

        public B taskType(TaskType taskType) {
            this.taskType = taskType;
            return this.builder();
        }

        public B titleMetadataKey(String titleMetadataKey) {
            this.titleMetadataKey = titleMetadataKey;
            return this.builder();
        }

        public B outputDimensionality(Integer outputDimensionality) {
            this.outputDimensionality = outputDimensionality;
            return this.builder();
        }

        public B timeout(Duration timeout) {
            this.timeout = timeout;
            return this.builder();
        }

        public B logRequestsAndResponses(Boolean logRequestsAndResponses) {
            this.logRequestsAndResponses = logRequestsAndResponses;
            return this.builder();
        }

        public B logRequests(Boolean logRequests) {
            this.logRequests = logRequests;
            return this.builder();
        }

        public B logResponses(Boolean logResponses) {
            this.logResponses = logResponses;
            return this.builder();
        }

        public B logger(Logger logger) {
            this.logger = logger;
            return this.builder();
        }
    }
}

