/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.action.memorycontainer.memory;

import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import lombok.Generated;
import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchException;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.action.memorycontainer.memory.FactSearchResult;
import org.opensearch.ml.action.memorycontainer.memory.MemoryInfo;
import org.opensearch.ml.action.memorycontainer.memory.MemoryOperationsService;
import org.opensearch.ml.action.memorycontainer.memory.MemoryProcessingService;
import org.opensearch.ml.action.memorycontainer.memory.MemorySearchService;
import org.opensearch.ml.common.memorycontainer.MLMemory;
import org.opensearch.ml.common.memorycontainer.MLMemoryContainer;
import org.opensearch.ml.common.memorycontainer.MemoryDecision;
import org.opensearch.ml.common.memorycontainer.MemoryStorageConfig;
import org.opensearch.ml.common.memorycontainer.MemoryType;
import org.opensearch.ml.common.settings.MLCommonsSettings;
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.common.transport.memorycontainer.memory.MLAddMemoriesInput;
import org.opensearch.ml.common.transport.memorycontainer.memory.MLAddMemoriesRequest;
import org.opensearch.ml.common.transport.memorycontainer.memory.MLAddMemoriesResponse;
import org.opensearch.ml.common.transport.memorycontainer.memory.MemoryResult;
import org.opensearch.ml.common.transport.memorycontainer.memory.MessageInput;
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
import org.opensearch.ml.helper.MemoryContainerHelper;
import org.opensearch.ml.helper.MemoryEmbeddingHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.remote.metadata.client.SdkClient;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;
import org.opensearch.transport.client.Client;

/**
 * Transport action that adds memories to a memory container's index.
 *
 * <p>After validating the feature flag and the request, it loads the target container and checks
 * the caller's access. It then either stores the raw messages directly ({@code infer=false}) or
 * extracts facts with the container's configured LLM model and applies memory decisions
 * ({@code infer=true}). When {@code infer} is not specified, inference is used iff the
 * container's storage config has an LLM model id.
 */
public class TransportAddMemoriesAction
extends HandledTransportAction<MLAddMemoriesRequest, MLAddMemoriesResponse> {
    @Generated
    private static final Logger log = LogManager.getLogger(TransportAddMemoriesAction.class);

    /** Transport action name this handler is registered under. */
    private static final String ACTION_NAME = "cluster:admin/opensearch/ml/memory_containers/memories/add";
    /** Source field that receives a memory's embedding vector. */
    private static final String MEMORY_EMBEDDING_FIELD = "memory_embedding";

    private final Client client;
    private final NamedXContentRegistry xContentRegistry;
    private final MLFeatureEnabledSetting mlFeatureEnabledSetting;
    private final MemoryContainerHelper memoryContainerHelper;
    private final MemoryEmbeddingHelper memoryEmbeddingHelper;
    private final MemoryProcessingService memoryProcessingService;
    private final MemorySearchService memorySearchService;
    private final MemoryOperationsService memoryOperationsService;

    /**
     * Injected constructor; builds the processing, search and operations services from the client.
     * Several injected parameters (sdkClient, clusterService, connectorAccessControlHelper,
     * mlModelManager) are accepted for injection compatibility but not used by this class.
     */
    @Inject
    public TransportAddMemoriesAction(
        TransportService transportService,
        ActionFilters actionFilters,
        Client client,
        SdkClient sdkClient,
        NamedXContentRegistry xContentRegistry,
        ClusterService clusterService,
        ConnectorAccessControlHelper connectorAccessControlHelper,
        MLFeatureEnabledSetting mlFeatureEnabledSetting,
        MLModelManager mlModelManager,
        MemoryContainerHelper memoryContainerHelper,
        MemoryEmbeddingHelper memoryEmbeddingHelper
    ) {
        this(
            transportService,
            actionFilters,
            client,
            xContentRegistry,
            mlFeatureEnabledSetting,
            memoryContainerHelper,
            memoryEmbeddingHelper,
            new MemoryProcessingService(client, xContentRegistry),
            new MemorySearchService(client),
            new MemoryOperationsService(client, memoryEmbeddingHelper)
        );
    }

    /** Package-private constructor that accepts pre-built services (e.g. for tests). */
    TransportAddMemoriesAction(
        TransportService transportService,
        ActionFilters actionFilters,
        Client client,
        NamedXContentRegistry xContentRegistry,
        MLFeatureEnabledSetting mlFeatureEnabledSetting,
        MemoryContainerHelper memoryContainerHelper,
        MemoryEmbeddingHelper memoryEmbeddingHelper,
        MemoryProcessingService memoryProcessingService,
        MemorySearchService memorySearchService,
        MemoryOperationsService memoryOperationsService
    ) {
        super(ACTION_NAME, transportService, actionFilters, MLAddMemoriesRequest::new);
        this.client = client;
        this.xContentRegistry = xContentRegistry;
        this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
        this.memoryContainerHelper = memoryContainerHelper;
        this.memoryEmbeddingHelper = memoryEmbeddingHelper;
        this.memoryProcessingService = memoryProcessingService;
        this.memorySearchService = memorySearchService;
        this.memoryOperationsService = memoryOperationsService;
    }

    /**
     * Validates the request, loads the target container, checks access and resolves its memory
     * index, then hands off to {@link #processAndIndexMemory}.
     *
     * <p>Failures are reported to {@code actionListener}. FORBIDDEN is used when agentic memory is
     * disabled or access is denied. {@link IllegalArgumentException} is used for missing input or a
     * missing container id. {@link IllegalStateException} is used when the container has no index.
     */
    @Override
    protected void doExecute(Task task, MLAddMemoriesRequest request, ActionListener<MLAddMemoriesResponse> actionListener) {
        if (!mlFeatureEnabledSetting.isAgenticMemoryEnabled()) {
            actionListener.onFailure(new OpenSearchStatusException(MLCommonsSettings.ML_COMMONS_AGENTIC_MEMORY_DISABLED_MESSAGE, RestStatus.FORBIDDEN));
            return;
        }
        User user = RestActionUtils.getUserContext(client);
        MLAddMemoriesInput input = request.getMlAddMemoryInput();
        if (input == null) {
            actionListener.onFailure(new IllegalArgumentException("Memory input is required"));
            return;
        }
        String memoryContainerId = input.getMemoryContainerId();
        if (StringUtils.isBlank(memoryContainerId)) {
            actionListener.onFailure(new IllegalArgumentException("Memory container ID is required"));
            return;
        }
        memoryContainerHelper.getMemoryContainer(memoryContainerId, ActionListener.wrap(container -> {
            if (!memoryContainerHelper.checkMemoryContainerAccess(user, container)) {
                actionListener.onFailure(new OpenSearchStatusException("User doesn't have permissions to add memory to this container", RestStatus.FORBIDDEN));
                return;
            }
            String indexName = memoryContainerHelper.getMemoryIndexName(container);
            if (indexName == null) {
                actionListener.onFailure(new IllegalStateException("Memory index not created for this container"));
                return;
            }
            processAndIndexMemory(input, container, indexName, user, actionListener);
        }, actionListener::onFailure));
    }

    /**
     * Resolves the session id and the effective {@code infer} flag, validates messages, and
     * dispatches to the LLM or non-LLM path. Any exception thrown synchronously is logged and
     * reported to {@code actionListener}.
     */
    private void processAndIndexMemory(MLAddMemoriesInput input, MLMemoryContainer container, String indexName, User user, ActionListener<MLAddMemoriesResponse> actionListener) {
        try {
            List<MessageInput> messages = input.getMessages();
            if (messages == null) {
                // Previously surfaced as a bare NullPointerException from the loops below.
                actionListener.onFailure(new IllegalArgumentException("Messages are required"));
                return;
            }

            boolean userProvidedSessionId = input.getSessionId() != null && !input.getSessionId().isEmpty();
            String sessionId;
            if (userProvidedSessionId) {
                sessionId = input.getSessionId();
                log.debug("User provided session ID: {}", sessionId);
            } else {
                sessionId = "sess_" + UUID.randomUUID();
                log.debug("Auto-generated session ID: {}", sessionId);
            }

            MemoryStorageConfig storageConfig = container.getMemoryStorageConfig();
            boolean hasLlmModel = storageConfig != null && storageConfig.getLlmModelId() != null;
            Boolean requestedInfer = input.getInfer();
            if (Boolean.TRUE.equals(requestedInfer) && !hasLlmModel) {
                actionListener.onFailure(new IllegalArgumentException("infer=true requires llm_model_id to be configured in memory storage"));
                return;
            }
            // An unspecified infer flag defaults to "use the LLM if one is configured".
            boolean infer = requestedInfer != null ? requestedInfer : hasLlmModel;

            if (infer) {
                processMessagesWithLLM(input, container, indexName, sessionId, userProvidedSessionId, user, storageConfig, actionListener);
                return;
            }
            // Without inference, roles are stored as given, so every message must carry one.
            for (MessageInput message : messages) {
                if (message.getRole() == null) {
                    actionListener.onFailure(new IllegalArgumentException("Role is required for all messages when infer=false"));
                    return;
                }
            }
            processMessagesWithoutLLM(input, container, indexName, sessionId, user, storageConfig, actionListener);
        } catch (Exception e) {
            log.error("Failed to add memory", e);
            actionListener.onFailure(e);
        }
    }

    /**
     * Extracts facts from the input messages with the configured LLM, then stores messages and
     * facts. Extraction failures are wrapped in an {@link OpenSearchException}.
     */
    private void processMessagesWithLLM(MLAddMemoriesInput input, MLMemoryContainer container, String indexName, String sessionId, boolean userProvidedSessionId, User user, MemoryStorageConfig storageConfig, ActionListener<MLAddMemoriesResponse> actionListener) {
        List<MessageInput> messages = input.getMessages();
        log.debug("Processing {} messages for fact extraction", messages.size());
        memoryProcessingService.extractFactsFromConversation(messages, storageConfig, ActionListener.wrap(
            facts -> storeMessagesAndFacts(input, container, indexName, messages, sessionId, userProvidedSessionId, user, facts, storageConfig, actionListener),
            e -> {
                log.error("Failed to extract facts with LLM", e);
                actionListener.onFailure(new OpenSearchException("Failed to extract facts: " + e.getMessage(), e));
            }
        ));
    }

    /**
     * Builds raw-message memories and then stores the extracted facts in one of two ways.
     *
     * <p>When there are facts and an LLM model is configured, similar existing facts in the session
     * are searched and memory decisions are applied. Otherwise the facts are added as new memories
     * and everything is bulk indexed, with embeddings if semantic storage is enabled.
     */
    private void storeMessagesAndFacts(MLAddMemoriesInput input, MLMemoryContainer container, String indexName, List<MessageInput> messages, String sessionId, boolean userProvidedSessionId, User user, List<String> facts, MemoryStorageConfig storageConfig, ActionListener<MLAddMemoriesResponse> actionListener) {
        Instant now = Instant.now();
        List<IndexRequest> indexRequests = new ArrayList<>();
        List<MemoryInfo> memoryInfos = new ArrayList<>();
        for (MessageInput message : messages) {
            // Role defaults to "user" on this path because roles are not validated when infer=true.
            String role = message.getRole() != null ? message.getRole() : "user";
            MLMemory rawMemory = buildRawMessageMemory(message, role, input, sessionId, user, now);
            indexRequests.add(new IndexRequest(indexName).source(rawMemory.toIndexMap()));
            // NOTE(review): flag is false here but true in processMessagesWithoutLLM; presumably it
            // controls whether the entry appears in the response - confirm against MemoryInfo.
            memoryInfos.add(new MemoryInfo(null, rawMemory.getMemory(), rawMemory.getMemoryType(), false));
        }

        if (!facts.isEmpty() && storageConfig != null && storageConfig.getLlmModelId() != null) {
            // NOTE(review): indexRequests/memoryInfos built above are not used in this branch, so the
            // raw messages are not indexed when memory decisions run - confirm this is intended.
            log.debug("Searching for similar facts in session to make memory decisions");
            memorySearchService.searchSimilarFactsForSession(facts, sessionId, indexName, storageConfig, ActionListener.wrap(
                allSearchResults -> {
                    log.debug("Found {} total similar facts across all {} new facts", allSearchResults.size(), facts.size());
                    applyMemoryDecisions(facts, allSearchResults, indexName, sessionId, user, input, storageConfig, actionListener);
                },
                e -> {
                    log.error("Failed to search similar facts", e);
                    actionListener.onFailure(new OpenSearchException("Failed to search similar facts: " + e.getMessage(), e));
                }
            ));
        } else {
            memoryOperationsService.createFactMemoriesFromList(facts, input, indexName, sessionId, user, now, indexRequests, memoryInfos);
            processEmbeddingsAndIndex(messages, facts, storageConfig, indexRequests, memoryInfos, sessionId, indexName, actionListener);
        }
    }

    /**
     * Asks the LLM for memory decisions on {@code facts} and the similar-fact search results, then
     * executes them. The operation results are returned as the response. Decision failures are
     * wrapped in an {@link OpenSearchException}. Operation failures are forwarded unchanged.
     */
    private void applyMemoryDecisions(List<String> facts, List<FactSearchResult> searchResults, String indexName, String sessionId, User user, MLAddMemoriesInput input, MemoryStorageConfig storageConfig, ActionListener<MLAddMemoriesResponse> actionListener) {
        memoryProcessingService.makeMemoryDecisions(facts, searchResults, storageConfig, ActionListener.wrap(
            decisions -> memoryOperationsService.executeMemoryOperations(decisions, indexName, sessionId, user, input, storageConfig, ActionListener.wrap(
                operationResults -> {
                    List<MemoryResult> allResults = new ArrayList<>(operationResults);
                    MLAddMemoriesResponse response = MLAddMemoriesResponse.builder().results(allResults).sessionId(sessionId).build();
                    actionListener.onResponse(response);
                },
                actionListener::onFailure
            )),
            e -> {
                log.error("Failed to make memory decisions", e);
                actionListener.onFailure(new OpenSearchException("Failed to make memory decisions: " + e.getMessage(), e));
            }
        ));
    }

    /**
     * Bulk indexes the prepared requests, first attaching embeddings when semantic storage is
     * enabled. Texts are embedded in request order: message contents followed by facts. This
     * assumes the fact requests were appended after the message requests in the same order.
     * Embedding failures fail the whole request on this path.
     */
    private void processEmbeddingsAndIndex(List<MessageInput> messages, List<String> facts, MemoryStorageConfig storageConfig, List<IndexRequest> indexRequests, List<MemoryInfo> memoryInfos, String sessionId, String indexName, ActionListener<MLAddMemoriesResponse> actionListener) {
        boolean needsEmbedding = storageConfig != null && storageConfig.isSemanticStorageEnabled();
        if (!needsEmbedding) {
            memoryOperationsService.bulkIndexMemoriesWithResults(indexRequests, memoryInfos, sessionId, indexName, actionListener);
            return;
        }
        List<String> textsToEmbed = new ArrayList<>(messages.size() + facts.size());
        for (MessageInput message : messages) {
            textsToEmbed.add(message.getContent());
        }
        textsToEmbed.addAll(facts);
        memoryEmbeddingHelper.generateEmbeddingsForMultipleTexts(textsToEmbed, storageConfig, ActionListener.wrap(
            embeddings -> {
                attachEmbeddings(indexRequests, embeddings);
                memoryOperationsService.bulkIndexMemoriesWithResults(indexRequests, memoryInfos, sessionId, indexName, actionListener);
            },
            e -> {
                log.error("Failed to generate embeddings for memories", e);
                actionListener.onFailure(new OpenSearchException("Failed to generate embeddings for memories: " + e.getMessage(), e));
            }
        ));
    }

    /**
     * Stores each message as a RAW_MESSAGE memory, using the message role as given. Embeddings are
     * added when semantic storage is enabled. This path is best-effort: if embedding generation
     * fails, the memories are still indexed without embeddings.
     */
    private void processMessagesWithoutLLM(MLAddMemoriesInput input, MLMemoryContainer container, String indexName, String sessionId, User user, MemoryStorageConfig storageConfig, ActionListener<MLAddMemoriesResponse> actionListener) {
        List<MessageInput> messages = input.getMessages();
        Instant now = Instant.now();
        List<IndexRequest> indexRequests = new ArrayList<>(messages.size());
        List<MemoryInfo> memoryInfos = new ArrayList<>(messages.size());
        for (MessageInput message : messages) {
            MLMemory memory = buildRawMessageMemory(message, message.getRole(), input, sessionId, user, now);
            indexRequests.add(new IndexRequest(indexName).source(memory.toIndexMap()));
            memoryInfos.add(new MemoryInfo(null, memory.getMemory(), memory.getMemoryType(), true));
        }
        if (storageConfig == null || !storageConfig.isSemanticStorageEnabled()) {
            memoryOperationsService.bulkIndexMemoriesWithResults(indexRequests, memoryInfos, sessionId, indexName, actionListener);
            return;
        }
        List<String> texts = new ArrayList<>(messages.size());
        for (MessageInput message : messages) {
            texts.add(message.getContent());
        }
        memoryEmbeddingHelper.generateEmbeddingsForMultipleTexts(texts, storageConfig, ActionListener.wrap(
            embeddings -> {
                attachEmbeddings(indexRequests, embeddings);
                memoryOperationsService.bulkIndexMemoriesWithResults(indexRequests, memoryInfos, sessionId, indexName, actionListener);
            },
            e -> {
                log.error("Failed to generate embeddings, storing without", e);
                memoryOperationsService.bulkIndexMemoriesWithResults(indexRequests, memoryInfos, sessionId, indexName, actionListener);
            }
        ));
    }

    /**
     * Builds a RAW_MESSAGE memory document for one message, stamped with {@code now} as both
     * created and last-updated time. The role is stored exactly as passed.
     */
    private static MLMemory buildRawMessageMemory(MessageInput message, String role, MLAddMemoriesInput input, String sessionId, User user, Instant now) {
        return MLMemory.builder()
            .sessionId(sessionId)
            .memory(message.getContent())
            .memoryType(MemoryType.RAW_MESSAGE)
            .userId(user != null ? user.getName() : null)
            .agentId(input.getAgentId())
            .role(role)
            .tags(input.getTags())
            .createdTime(now)
            .lastUpdatedTime(now)
            .build();
    }

    /**
     * Copies {@code embeddings.get(i)} into the source of {@code indexRequests.get(i)} under
     * {@link #MEMORY_EMBEDDING_FIELD}. Embeddings are attached only when the counts match.
     * Otherwise the requests are left untouched and a warning is logged, so memories are indexed
     * without embeddings.
     */
    private static void attachEmbeddings(List<IndexRequest> indexRequests, List<Object> embeddings) {
        if (embeddings == null || embeddings.size() != indexRequests.size()) {
            log.warn(
                "Embedding count mismatch ({} embeddings for {} memories); indexing without embeddings",
                embeddings == null ? "null" : embeddings.size(),
                indexRequests.size()
            );
            return;
        }
        for (int i = 0; i < indexRequests.size(); i++) {
            IndexRequest request = indexRequests.get(i);
            Map<String, Object> source = request.sourceAsMap();
            source.put(MEMORY_EMBEDDING_FIELD, embeddings.get(i));
            request.source(source);
        }
    }
}

