/*
 * Decompiled with CFR 0.152.
 */
package org.joget.ai.agent.mcp;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.Response;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import okhttp3.sse.EventSources;
import org.joget.ai.agent.mcp.JRPCRequestDto;
import org.joget.ai.agent.mcp.JsonRpcMethod;
import org.joget.ai.agent.mcp.McpClient;
import org.joget.ai.agent.mcp.McpException;
import org.joget.ai.agent.mcp.McpSchemaCache;
import org.joget.ai.agent.mcp.ToolSpecification;
import org.joget.commons.util.LogUtil;

public class CustomMcpClient
implements McpClient {
    private static final Duration DEFAULT_READ_TIMEOUT = Duration.ofSeconds(30L);
    private static final Duration SESSION_ID_WAIT_TIMEOUT = Duration.ofSeconds(30L);
    private static final Duration SSE_RESPONSE_TIMEOUT = Duration.ofSeconds(30L);
    private final String baseUrl;
    private final Duration timeout;
    private final ObjectMapper objectMapper;
    private volatile boolean initialized = false;
    private volatile String sessionId = null;
    private final OkHttpClient okHttpClient;
    private final HttpClient httpClient;
    private static final Pattern SESSION_ID_PATTERN = Pattern.compile("(?:sessionId|session_id)=([^\\s&]+)");
    private final TransportMode mode;
    private EventSource eventSource;
    private final Map<String, CompletableFuture<JsonNode>> pendingRequests = new ConcurrentHashMap<String, CompletableFuture<JsonNode>>();

    public String getClassName() {
        return this.getClass().getName();
    }

    public CustomMcpClient(String sseUrl, Duration timeout) {
        this.baseUrl = sseUrl;
        this.timeout = timeout;
        this.objectMapper = new ObjectMapper();
        this.okHttpClient = this.createOkHttpClient(timeout);
        this.httpClient = HttpClient.newBuilder().connectTimeout(timeout).build();
        this.mode = sseUrl.endsWith("/mcp") ? TransportMode.STREAM_HTTP : TransportMode.LEGACY_SSE;
    }

    private OkHttpClient createOkHttpClient(Duration timeout) {
        OkHttpClient.Builder builder = new OkHttpClient.Builder().connectTimeout(timeout).readTimeout(DEFAULT_READ_TIMEOUT);
        return builder.build();
    }

    @Override
    public void initialize() throws McpException {
        if (this.initialized && this.sessionId != null) {
            return;
        }
        try {
            if (this.mode == TransportMode.STREAM_HTTP) {
                this.initializeHttpSession();
            } else {
                this.initializeSseSession();
            }
            if (this.sessionId == null || this.sessionId.trim().isEmpty()) {
                LogUtil.warn((String)this.getClassName(), (String)"Failed to obtain session ID from MCP server");
            }
            this.initialized = true;
        }
        catch (Exception e) {
            LogUtil.error((String)this.getClassName(), (Throwable)e, (String)"Failed to initialize MCP client");
        }
    }

    private void ensureInitialized() throws McpException {
        if (!this.initialized) {
            this.initialize();
        }
    }

    private void initializeSseSession() throws McpException {
        ConnectionState connectionState = new ConnectionState();
        Request sseRequest = this.createSseRequest();
        EventSourceListener listener = this.createEventSourceListener(connectionState);
        EventSource.Factory factory = EventSources.createFactory((OkHttpClient)this.okHttpClient);
        this.eventSource = factory.newEventSource(sseRequest, listener);
        this.waitForSessionId(connectionState);
    }

    private Request createSseRequest() {
        return new Request.Builder().url(this.baseUrl).header("Accept", "text/event-stream").header("Cache-Control", "no-cache").build();
    }

    private EventSourceListener createEventSourceListener(final ConnectionState connectionState) {
        return new EventSourceListener(){

            public void onEvent(EventSource eventSource, String id, String type, String data) {
                CustomMcpClient.this.handleSseEvent(connectionState, id, type, data);
            }

            public void onFailure(EventSource eventSource, Throwable t, Response response) {
                CustomMcpClient.this.handleSseFailure(connectionState, t);
            }

            public void onClosed(EventSource eventSource) {
                CustomMcpClient.this.handleSseClosure();
            }
        };
    }

    private void handleSseEvent(ConnectionState connectionState, String id, String type, String data) {
        String extractedSessionId;
        String eventInfo = "ID:" + id + " Type:" + type + " Data:" + data;
        connectionState.addEvent(eventInfo);
        if ("ping".equals(data)) {
            return;
        }
        if (connectionState.sessionId == null && (extractedSessionId = this.extractSessionId(data)) != null) {
            connectionState.setSessionId(extractedSessionId);
            this.sessionId = extractedSessionId;
            return;
        }
        if (connectionState.sessionId != null) {
            this.handlePotentialJsonRpcResponse(data);
        }
        if (!data.trim().isEmpty()) {
            String result = "SSE connected and received data, but no sessionId found in: " + String.join((CharSequence)", ", connectionState.receivedEvents);
            connectionState.updateResult(result);
        }
    }

    private void handleSseFailure(ConnectionState connectionState, Throwable t) {
        if (connectionState.sessionId == null) {
            McpException error = new McpException("SSE connection failed during initialization: " + t.getMessage(), t);
            connectionState.setError(error);
        }
        this.clearPendingRequestsWithError("SSE connection lost");
    }

    private void handleSseClosure() {
        this.clearPendingRequestsWithError("SSE connection closed");
    }

    private void clearPendingRequestsWithError(String errorMessage) {
        this.pendingRequests.values().forEach(future -> future.completeExceptionally(new McpException(errorMessage)));
        this.pendingRequests.clear();
    }

    private void waitForSessionId(ConnectionState connectionState) throws McpException {
        try {
            String receivedSessionId = connectionState.sessionIdFuture.get(SESSION_ID_WAIT_TIMEOUT.toSeconds(), TimeUnit.SECONDS);
            if (receivedSessionId == null || receivedSessionId.trim().isEmpty()) {
                throw new McpException("No sessionId found in SSE stream");
            }
        }
        catch (TimeoutException e) {
            throw new McpException("Timeout: No sessionId found");
        }
        catch (ExecutionException e) {
            Throwable cause = e.getCause();
            if (cause instanceof McpException) {
                throw (McpException)cause;
            }
            throw new McpException("Failed to obtain session ID", cause);
        }
        catch (Exception e) {
            if (connectionState.sessionId == null && this.eventSource != null) {
                this.eventSource.cancel();
            }
            throw new McpException("Unexpected error during SSE connection", e);
        }
    }

    private void handlePotentialJsonRpcResponse(String data) {
        try {
            String responseId;
            CompletableFuture<JsonNode> future;
            JsonNode jsonResponse = this.objectMapper.readTree(data);
            if (jsonResponse.has("jsonrpc") && jsonResponse.has("id") && (future = this.pendingRequests.remove(responseId = jsonResponse.get("id").asText())) != null) {
                future.complete(jsonResponse);
            }
        }
        catch (Exception e) {
            LogUtil.error((String)this.getClassName(), (Throwable)e, (String)"Error processing JSON-RPC response");
        }
    }

    private String extractSessionId(String data) {
        if (data == null || data.trim().isEmpty()) {
            return null;
        }
        Matcher matcher = SESSION_ID_PATTERN.matcher(data);
        if (matcher.find()) {
            try {
                return matcher.group(1).trim();
            }
            catch (Exception e) {
                LogUtil.error((String)this.getClassName(), (Throwable)e, (String)"Error extracting sessionId");
            }
        }
        try {
            String sessionId;
            JsonNode jsonData = this.objectMapper.readTree(data);
            if (jsonData.has("sessionId") && !(sessionId = jsonData.get("sessionId").asText().trim()).isEmpty()) {
                return sessionId;
            }
        }
        catch (Exception e) {
            LogUtil.error((String)this.getClassName(), (Throwable)e, (String)"Error extracting sessionId");
        }
        return null;
    }

    @Override
    public List<ToolSpecification> listTools() throws McpException {
        this.ensureInitialized();
        try {
            String requestId = "list-" + UUID.randomUUID();
            JRPCRequestDto requestDto = new JRPCRequestDto(requestId, JsonRpcMethod.TOOLS_LIST, null);
            JsonNode response = this.sendRpcRequest(requestDto.toJson(this.objectMapper));
            if (response == null || response.has("error")) {
                JsonNode error = response.get("error");
                String errorMsg = error.has("message") ? error.get("message").asText() : "Unknown error";
                throw new McpException("Failed to list tools: " + errorMsg);
            }
            ArrayList<ToolSpecification> tools = new ArrayList<ToolSpecification>();
            JsonNode result = response.get("result");
            if (result != null && result.has("tools")) {
                for (JsonNode toolNode : result.get("tools")) {
                    String name = toolNode.get("name").asText();
                    String description = toolNode.has("description") ? toolNode.get("description").asText() : "";
                    HashMap<String, ToolSpecification.ParameterSpec> parameters = new HashMap<String, ToolSpecification.ParameterSpec>();
                    ArrayList<String> required = new ArrayList<String>();
                    JsonNode inputSchema = toolNode.get("inputSchema");
                    if (inputSchema != null) {
                        JsonNode requiredArray;
                        JsonNode properties = inputSchema.get("properties");
                        if (properties != null) {
                            properties.fieldNames().forEachRemaining(fieldName -> {
                                JsonNode prop = properties.get(fieldName);
                                String type = prop.has("type") ? prop.get("type").asText() : "string";
                                String desc = prop.has("description") ? prop.get("description").asText() : "";
                                ArrayList<String> enumValues = new ArrayList<String>();
                                if (prop.has("enum") && prop.get("enum").isArray()) {
                                    for (JsonNode v : prop.get("enum")) {
                                        enumValues.add(v.asText());
                                    }
                                }
                                parameters.put((String)fieldName, new ToolSpecification.ParameterSpec(type, desc, enumValues));
                            });
                        }
                        if ((requiredArray = inputSchema.get("required")) != null && requiredArray.isArray()) {
                            requiredArray.forEach(req -> required.add(req.asText()));
                        }
                    }
                    tools.add(new ToolSpecification(name, description, parameters, required));
                }
            }
            McpSchemaCache.cacheTools(tools);
            return tools;
        }
        catch (Exception e) {
            if (e instanceof McpException) {
                throw (McpException)e;
            }
            throw new McpException("Failed to list tools", e);
        }
    }

    @Override
    public String executeTool(String toolName, Map<String, Object> arguments) throws McpException {
        this.ensureInitialized();
        try {
            JsonNode content;
            ObjectNode argsNode = this.objectMapper.createObjectNode();
            for (Map.Entry<String, Object> entry : arguments.entrySet()) {
                String stringValue;
                String paramName = entry.getKey();
                Object value = entry.getValue();
                List<String> enumValues = McpSchemaCache.getEnumValues(toolName, paramName);
                if (!enumValues.isEmpty() && !enumValues.contains(stringValue = String.valueOf(value))) {
                    LogUtil.warn((String)this.getClassName(), (String)("Invalid value '" + stringValue + "' for parameter '" + paramName));
                    continue;
                }
                argsNode.set(paramName, this.objectMapper.valueToTree(value));
            }
            String requestId = "exec-" + UUID.randomUUID();
            ObjectNode params = this.objectMapper.createObjectNode();
            params.put("name", toolName);
            params.set("arguments", (JsonNode)argsNode);
            JRPCRequestDto requestDto = new JRPCRequestDto(requestId, JsonRpcMethod.TOOLS_CALL, params);
            JsonNode response = this.sendRpcRequest(requestDto.toJson(this.objectMapper));
            if (response.has("error")) {
                throw new McpException("Tool execution failed: " + response.get("error"));
            }
            JsonNode result = response.get("result");
            if (result != null && result.has("content") && (content = result.get("content")) != null && !content.isEmpty() && content.isArray() && content.size() > 0 && content.get(0).has("text")) {
                return content.get(0).get("text").asText();
            }
            return result != null ? result.toString() : "No result";
        }
        catch (Exception e) {
            throw new McpException("Failed to execute tool: " + toolName, e);
        }
    }

    private JsonNode sendRpcRequest(ObjectNode request) throws Exception {
        if (this.mode == TransportMode.STREAM_HTTP) {
            return this.sendStreamHttpRpc(request);
        }
        return this.sendLegacyRpc(request);
    }

    private JsonNode sendLegacyRpc(ObjectNode request) throws Exception {
        String requestId = request.get("id").asText();
        try {
            String rpcUrl = this.buildRpcUrl();
            CompletableFuture<JsonNode> future = new CompletableFuture<JsonNode>();
            this.pendingRequests.put(requestId, future);
            HttpRequest httpRequest = HttpRequest.newBuilder().uri(URI.create(rpcUrl)).header("Content-Type", "application/json").timeout(this.timeout).POST(HttpRequest.BodyPublishers.ofString(request.toString())).build();
            HttpResponse<String> response = this.httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
            return this.handleHttpResponse(response, requestId, future);
        }
        catch (Exception e) {
            this.pendingRequests.remove(requestId);
            if (e instanceof McpException) {
                throw e;
            }
            throw new McpException("Failed to send RPC request for ID: " + requestId, e);
        }
    }

    private String buildRpcUrl() {
        String baseUrl = this.baseUrl;
        int trimmed = baseUrl.endsWith("/sse") ? baseUrl.length() - 4 : (baseUrl.endsWith("/mcp") ? baseUrl.length() - 4 : baseUrl.lastIndexOf("/"));
        baseUrl = baseUrl.substring(0, trimmed);
        return baseUrl + "/sse/message?sessionId=" + this.sessionId;
    }

    private JsonNode handleHttpResponse(HttpResponse<String> response, String requestId, CompletableFuture<JsonNode> future) throws McpException {
        if (response.statusCode() == 200) {
            this.pendingRequests.remove(requestId);
            try {
                return this.objectMapper.readTree(response.body());
            }
            catch (Exception e) {
                throw new McpException("Failed to parse response for request ID: " + requestId, e);
            }
        }
        if (response.statusCode() == 202) {
            return this.waitForSseResponse(requestId, future);
        }
        this.pendingRequests.remove(requestId);
        throw new McpException("HTTP error: " + response.statusCode() + " - " + response.body());
    }

    private JsonNode waitForSseResponse(String requestId, CompletableFuture<JsonNode> future) throws McpException {
        try {
            JsonNode sseResponse = future.get(SSE_RESPONSE_TIMEOUT.toSeconds(), TimeUnit.SECONDS);
            this.pendingRequests.remove(requestId);
            return sseResponse;
        }
        catch (TimeoutException e) {
            this.pendingRequests.remove(requestId);
            throw new McpException("Timeout waiting for response to request ID: " + requestId);
        }
        catch (ExecutionException e) {
            this.pendingRequests.remove(requestId);
            Throwable cause = e.getCause();
            throw new McpException("Failed to get response for request ID: " + requestId + ". SSE failed with: " + cause.getMessage());
        }
        catch (InterruptedException e) {
            this.pendingRequests.remove(requestId);
            Thread.currentThread().interrupt();
            throw new McpException("Interrupted waiting for response to request ID: " + requestId, e);
        }
    }

    public void initializeHttpSession() throws McpException {
        ObjectMapper objectMapper = new ObjectMapper();
        ObjectNode params = objectMapper.createObjectNode();
        params.put("protocolVersion", "2024-11-05");
        params.set("capabilities", (JsonNode)objectMapper.createObjectNode());
        ObjectNode clientInfo = objectMapper.createObjectNode();
        clientInfo.put("name", "java-mcp-client");
        clientInfo.put("version", "1.0");
        params.set("clientInfo", (JsonNode)clientInfo);
        String requestId = "init-" + UUID.randomUUID();
        JRPCRequestDto initRequest = new JRPCRequestDto(requestId, JsonRpcMethod.TOOLS_INIT, params);
        String requestJson = null;
        requestJson = initRequest.toJson(objectMapper).toString();
        HttpRequest httpRequest = HttpRequest.newBuilder().uri(URI.create(this.baseUrl)).header("Content-Type", "application/json").header("Accept", "application/json, text/event-stream").timeout(this.timeout).POST(HttpRequest.BodyPublishers.ofString(requestJson)).build();
        try {
            String body;
            HttpResponse<InputStream> response = this.httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofInputStream());
            if (response.statusCode() != 200) {
                InputStream err = response.body();
                try {
                    String errBody = new String(err.readAllBytes(), StandardCharsets.UTF_8);
                    throw new McpException("Init failed: HTTP " + response.statusCode() + " - " + errBody);
                }
                catch (Throwable errBody) {
                    if (err != null) {
                        try {
                            err.close();
                        }
                        catch (Throwable throwable) {
                            errBody.addSuppressed(throwable);
                        }
                    }
                    throw errBody;
                }
            }
            try (InputStream is = response.body();){
                body = new String(is.readAllBytes(), StandardCharsets.UTF_8);
            }
            Optional<String> sessionIdOpt = response.headers().firstValue("mcp-session-id").or(() -> response.headers().firstValue("Mcp-Session-Id"));
            if (sessionIdOpt.isPresent()) {
                this.sessionId = sessionIdOpt.get();
                LogUtil.debug((String)this.getClassName(), (String)("MCP Session ID obtained from headers: " + this.sessionId));
                return;
            }
            JsonNode json = objectMapper.readTree(body);
            if (!json.has("result") || !json.get("result").has("sessionId")) {
                throw new McpException("No MCP Session ID found in headers or body");
            }
            this.sessionId = json.get("result").get("sessionId").asText();
            LogUtil.debug((String)this.getClassName(), (String)("MCP Session ID obtained from body: " + this.sessionId));
        }
        catch (IOException | InterruptedException ex) {
            throw new McpException("Failed to initialize MCP session", ex);
        }
    }

    /*
     * Enabled aggressive block sorting
     * Enabled unnecessary exception pruning
     * Enabled aggressive exception aggregation
     */
    private JsonNode sendStreamHttpRpc(ObjectNode request) throws Exception {
        HttpRequest httpRequest;
        HttpResponse<InputStream> response;
        String requestJson = this.objectMapper.writeValueAsString((Object)request);
        HttpRequest.Builder builder = HttpRequest.newBuilder().uri(URI.create(this.baseUrl)).header("Accept", "application/json, text/event-stream").header("Content-Type", "application/json, text/event-stream").timeout(this.timeout).POST(HttpRequest.BodyPublishers.ofString(requestJson));
        if (this.sessionId != null && !this.sessionId.trim().isEmpty()) {
            builder.header("Mcp-Session-Id", this.sessionId);
        }
        if ((response = this.httpClient.send(httpRequest = builder.build(), HttpResponse.BodyHandlers.ofInputStream())).statusCode() != 200) {
            throw new McpException("HTTP " + response.statusCode());
        }
        try (BufferedReader reader = new BufferedReader(new InputStreamReader(response.body()));){
            String line;
            while ((line = reader.readLine()) != null) {
                String jsonPayload;
                if (line.trim().isEmpty()) continue;
                if (line.startsWith("data:")) {
                    jsonPayload = line.substring("data:".length()).trim();
                } else {
                    if (!line.startsWith("{") && !line.startsWith("[")) continue;
                    jsonPayload = line.trim();
                }
                try {
                    JsonNode json = this.objectMapper.readTree(jsonPayload);
                    if (!json.has("id") || !json.get("id").asText().equals(request.get("id").asText()) || !json.has("result") && !json.has("error")) continue;
                    JsonNode jsonNode = json;
                    return jsonNode;
                }
                catch (Exception e) {
                    LogUtil.debug((String)this.getClassName(), (String)("Skipping non-JSON line: " + line));
                }
            }
            throw new McpException("No response received for request " + request.get("id").asText());
        }
    }

    @Override
    public void close() throws McpException {
        if (!this.initialized) {
            return;
        }
        this.clearPendingRequestsWithError("Client shutting down");
        if (this.eventSource != null) {
            this.eventSource.cancel();
        }
        this.initialized = false;
        this.sessionId = null;
        if (this.okHttpClient != null) {
            this.okHttpClient.dispatcher().executorService().shutdown();
            this.okHttpClient.connectionPool().evictAll();
        }
    }

    private static class ConnectionState {
        volatile String sessionId = null;
        volatile String result = "SSE connected but no sessionId or data received";
        volatile Exception error = null;
        final List<String> receivedEvents = Collections.synchronizedList(new ArrayList());
        final CompletableFuture<String> sessionIdFuture = new CompletableFuture();

        private ConnectionState() {
        }

        void addEvent(String event) {
            this.receivedEvents.add(event);
        }

        void setSessionId(String sessionId) {
            this.sessionId = sessionId;
            this.sessionIdFuture.complete(sessionId);
        }

        void setError(Exception error) {
            this.error = error;
            this.sessionIdFuture.completeExceptionally(error);
        }

        void updateResult(String result) {
            this.result = result;
        }
    }

    private static enum TransportMode {
        LEGACY_SSE,
        STREAM_HTTP;

    }
}

