diff --git a/src/Managing.Api/Controllers/LlmController.cs b/src/Managing.Api/Controllers/LlmController.cs index 33ab0b18..4011acfc 100644 --- a/src/Managing.Api/Controllers/LlmController.cs +++ b/src/Managing.Api/Controllers/LlmController.cs @@ -7,6 +7,7 @@ using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.SignalR; using Microsoft.Extensions.Caching.Memory; +using Microsoft.Extensions.DependencyInjection; namespace Managing.Api.Controllers; @@ -25,6 +26,7 @@ public class LlmController : BaseController private readonly ILogger _logger; private readonly IMemoryCache _cache; private readonly IHubContext _hubContext; + private readonly IServiceScopeFactory _serviceScopeFactory; public LlmController( ILlmService llmService, @@ -32,13 +34,15 @@ public class LlmController : BaseController IUserService userService, ILogger logger, IMemoryCache cache, - IHubContext hubContext) : base(userService) + IHubContext hubContext, + IServiceScopeFactory serviceScopeFactory) : base(userService) { _llmService = llmService; _mcpService = mcpService; _logger = logger; _cache = cache; _hubContext = hubContext; + _serviceScopeFactory = serviceScopeFactory; } /// @@ -92,28 +96,67 @@ public class LlmController : BaseController } // Process in background to avoid blocking the HTTP response + // Create a scope for the background task to access scoped services (DbContext, etc.) + var userId = user.Id; _ = Task.Run(async () => { + using var scope = _serviceScopeFactory.CreateScope(); try { - await ChatStreamInternal(request, user, request.ConnectionId); + // Resolve scoped services from the new scope + var llmService = scope.ServiceProvider.GetRequiredService(); + var mcpService = scope.ServiceProvider.GetRequiredService(); + var userService = scope.ServiceProvider.GetRequiredService(); + var cache = scope.ServiceProvider.GetRequiredService(); + var hubContext = scope.ServiceProvider.GetRequiredService>(); + var logger = scope.ServiceProvider.GetRequiredService>(); + + // Reload user from the scoped service to ensure we have a valid user object + var scopedUser = await userService.GetUserByIdAsync(userId); + if (scopedUser == null) + { + await hubContext.Clients.Client(request.ConnectionId).SendAsync("ProgressUpdate", new LlmProgressUpdate + { + Type = "error", + Message = "User not found", + Error = "Unable to authenticate user" + }); + return; + } + + await ChatStreamInternal(request, scopedUser, request.ConnectionId, llmService, mcpService, cache, hubContext, logger); } catch (Exception ex) { _logger.LogError(ex, "Error processing chat stream for connection {ConnectionId}", request.ConnectionId); - await _hubContext.Clients.Client(request.ConnectionId).SendAsync("ProgressUpdate", new LlmProgressUpdate + try { - Type = "error", - Message = $"Error processing chat: {ex.Message}", - Error = ex.Message - }); + await _hubContext.Clients.Client(request.ConnectionId).SendAsync("ProgressUpdate", new LlmProgressUpdate + { + Type = "error", + Message = $"Error processing chat: {ex.Message}", + Error = ex.Message + }); + } + catch (Exception hubEx) + { + _logger.LogError(hubEx, "Error sending error message to SignalR client"); + } } }); return Ok(new { Message = "Chat stream started", ConnectionId = request.ConnectionId }); } - private async Task ChatStreamInternal(LlmChatStreamRequest request, User user, string connectionId) + private async Task ChatStreamInternal( + LlmChatStreamRequest request, + User user, + string connectionId, + ILlmService llmService, + IMcpService mcpService, + IMemoryCache cache, + IHubContext hubContext, + ILogger logger) { // Convert to LlmChatRequest for service calls var chatRequest = new LlmChatRequest @@ -127,21 +170,21 @@ public class LlmController : BaseController Tools = request.Tools }; - await SendProgressUpdate(connectionId, new LlmProgressUpdate + await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate { Type = "thinking", Message = "Initializing conversation and loading available tools..." }); // Get available MCP tools (with caching for 5 minutes) - var availableTools = await _cache.GetOrCreateAsync("mcp_tools", async entry => + var availableTools = await cache.GetOrCreateAsync("mcp_tools", async entry => { entry.AbsoluteExpirationRelativeToNow = TimeSpan.FromMinutes(5); - return (await _mcpService.GetAvailableToolsAsync()).ToList(); + return (await mcpService.GetAvailableToolsAsync()).ToList(); }); chatRequest.Tools = availableTools; - await SendProgressUpdate(connectionId, new LlmProgressUpdate + await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate { Type = "thinking", Message = $"Loaded {availableTools.Count} available tools. Preparing system context..." @@ -163,7 +206,7 @@ public class LlmController : BaseController chatRequest.Messages.Insert(0, systemMessage); // Proactively inject backtest details fetching if user is asking for analysis - await InjectBacktestDetailsFetchingIfNeeded(chatRequest, user); + await InjectBacktestDetailsFetchingIfNeeded(chatRequest, user, mcpService, logger); // Add helpful context extraction message if backtest ID was found AddBacktestContextGuidance(chatRequest); @@ -174,7 +217,7 @@ public class LlmController : BaseController LlmChatResponse? finalResponse = null; const int DelayBetweenIterationsMs = 500; - await SendProgressUpdate(connectionId, new LlmProgressUpdate + await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate { Type = "thinking", Message = $"Starting analysis (up to {maxIterations} iterations may be needed)..." @@ -184,7 +227,7 @@ public class LlmController : BaseController { iteration++; - await SendProgressUpdate(connectionId, new LlmProgressUpdate + await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate { Type = "iteration_start", Message = "Analyzing your request and determining next steps...", @@ -192,13 +235,13 @@ public class LlmController : BaseController MaxIterations = maxIterations }); - _logger.LogInformation("LLM chat iteration {Iteration}/{MaxIterations} for user {UserId}", + logger.LogInformation("LLM chat iteration {Iteration}/{MaxIterations} for user {UserId}", iteration, maxIterations, user.Id); // Add delay between iterations to avoid rapid bursts and rate limiting if (iteration > 1) { - await SendProgressUpdate(connectionId, new LlmProgressUpdate + await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate { Type = "thinking", Message = "Waiting briefly to respect rate limits...", @@ -212,7 +255,7 @@ public class LlmController : BaseController TrimConversationContext(chatRequest); // Send chat request to LLM - await SendProgressUpdate(connectionId, new LlmProgressUpdate + await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate { Type = "thinking", Message = "Sending request to LLM...", @@ -220,16 +263,16 @@ public class LlmController : BaseController MaxIterations = maxIterations }); - var response = await _llmService.ChatAsync(user, chatRequest); + var response = await llmService.ChatAsync(user, chatRequest); // If LLM doesn't want to call tools, we have our final answer if (!response.RequiresToolExecution || response.ToolCalls == null || !response.ToolCalls.Any()) { finalResponse = response; - _logger.LogInformation("LLM provided final answer after {Iteration} iteration(s) for user {UserId}", + logger.LogInformation("LLM provided final answer after {Iteration} iteration(s) for user {UserId}", iteration, user.Id); - await SendProgressUpdate(connectionId, new LlmProgressUpdate + await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate { Type = "thinking", Message = "Received final response. Preparing answer...", @@ -241,10 +284,10 @@ public class LlmController : BaseController } // LLM wants to call tools - execute them - _logger.LogInformation("LLM requested {Count} tool calls in iteration {Iteration} for user {UserId}", + logger.LogInformation("LLM requested {Count} tool calls in iteration {Iteration} for user {UserId}", response.ToolCalls.Count, iteration, user.Id); - await SendProgressUpdate(connectionId, new LlmProgressUpdate + await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate { Type = "thinking", Message = $"LLM requested {response.ToolCalls.Count} tool call(s). Executing tools...", @@ -256,7 +299,7 @@ public class LlmController : BaseController var toolResults = new List(); foreach (var toolCall in response.ToolCalls) { - await SendProgressUpdate(connectionId, new LlmProgressUpdate + await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate { Type = "tool_call", Message = $"Calling tool: {toolCall.Name}", @@ -266,14 +309,14 @@ public class LlmController : BaseController ToolArguments = toolCall.Arguments }); - var (success, result, error) = await ExecuteToolSafely(user, toolCall.Name, toolCall.Arguments, toolCall.Id, iteration, maxIterations); + var (success, result, error) = await ExecuteToolSafely(user, toolCall.Name, toolCall.Arguments, toolCall.Id, iteration, maxIterations, mcpService, logger); if (success && result != null) { - _logger.LogInformation("Successfully executed tool {ToolName} in iteration {Iteration} for user {UserId}", + logger.LogInformation("Successfully executed tool {ToolName} in iteration {Iteration} for user {UserId}", toolCall.Name, iteration, user.Id); - await SendProgressUpdate(connectionId, new LlmProgressUpdate + await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate { Type = "tool_result", Message = $"Tool {toolCall.Name} completed successfully", @@ -291,10 +334,10 @@ public class LlmController : BaseController } else { - _logger.LogError("Error executing tool {ToolName} in iteration {Iteration} for user {UserId}: {Error}", + logger.LogError("Error executing tool {ToolName} in iteration {Iteration} for user {UserId}: {Error}", toolCall.Name, iteration, user.Id, error); - await SendProgressUpdate(connectionId, new LlmProgressUpdate + await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate { Type = "tool_result", Message = $"Tool {toolCall.Name} encountered an error: {error}", @@ -313,7 +356,7 @@ public class LlmController : BaseController } } - await SendProgressUpdate(connectionId, new LlmProgressUpdate + await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate { Type = "thinking", Message = "All tools completed. Analyzing results...", @@ -338,10 +381,10 @@ public class LlmController : BaseController // If we hit max iterations, return the last response (even if it has tool calls) if (finalResponse == null) { - _logger.LogWarning("Reached max iterations ({MaxIterations}) for user {UserId}. Returning last response.", + logger.LogWarning("Reached max iterations ({MaxIterations}) for user {UserId}. Returning last response.", maxIterations, user.Id); - await SendProgressUpdate(connectionId, new LlmProgressUpdate + await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate { Type = "thinking", Message = "Reached maximum iterations. Getting final response...", @@ -349,11 +392,11 @@ public class LlmController : BaseController MaxIterations = maxIterations }); - finalResponse = await _llmService.ChatAsync(user, chatRequest); + finalResponse = await llmService.ChatAsync(user, chatRequest); } // Send final response - await SendProgressUpdate(connectionId, new LlmProgressUpdate + await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate { Type = "final_response", Message = "Analysis complete!", @@ -413,7 +456,7 @@ public class LlmController : BaseController request.Messages.Insert(0, systemMessage); // Proactively inject backtest details fetching if user is asking for analysis - await InjectBacktestDetailsFetchingIfNeeded(request, user); + await InjectBacktestDetailsFetchingIfNeeded(request, user, _mcpService, _logger); // Add helpful context extraction message if backtest ID was found AddBacktestContextGuidance(request); @@ -613,6 +656,12 @@ public class LlmController : BaseController * "Recent backtests" → get_backtests_paginated(sortOrder='desc', pageSize=20) * "Bundle backtest analysis" → analyze_bundle_backtest(bundleRequestId='X') + ERROR HANDLING: + - If a tool returns a database connection error, wait a moment and retry once (these are often transient) + - If retry fails, explain the issue clearly to the user and suggest they try again later + - Never give up after a single error - always try at least once more for connection-related issues + - Distinguish between connection errors (retry) and data errors (no retry needed) + CRITICAL ANALYSIS WORKFLOW (APPLIES TO ALL DATA): 1. RETRIEVE COMPLETE DATA: - When asked to analyze ANY entity, ALWAYS fetch FULL details first (never rely on summary/paginated data alone) @@ -671,16 +720,18 @@ public class LlmController : BaseController Dictionary arguments, string toolCallId, int iteration, - int maxIterations) + int maxIterations, + IMcpService mcpService, + ILogger logger) { try { - var result = await _mcpService.ExecuteToolAsync(user, toolName, arguments); + var result = await mcpService.ExecuteToolAsync(user, toolName, arguments); return (true, result, null); } catch (Exception ex) { - _logger.LogError(ex, "Error executing tool {ToolName} in iteration {Iteration} for user {UserId}", + logger.LogError(ex, "Error executing tool {ToolName} in iteration {Iteration} for user {UserId}", toolName, iteration, user.Id); return (false, null, ex.Message); } @@ -722,7 +773,7 @@ public class LlmController : BaseController /// Proactively injects backtest details fetching when user asks for backtest analysis. /// Extracts backtest IDs from message content and automatically calls get_backtest_by_id. /// - private async Task InjectBacktestDetailsFetchingIfNeeded(LlmChatRequest request, User user) + private async Task InjectBacktestDetailsFetchingIfNeeded(LlmChatRequest request, User user, IMcpService mcpService, ILogger logger) { var lastUserMessage = request.Messages.LastOrDefault(m => m.Role == "user"); if (lastUserMessage == null || string.IsNullOrWhiteSpace(lastUserMessage.Content)) @@ -745,22 +796,22 @@ public class LlmController : BaseController var backtestId = ExtractBacktestIdFromConversation(request.Messages); if (string.IsNullOrEmpty(backtestId)) { - _logger.LogInformation("User requested backtest analysis but no backtest ID found in conversation context"); + logger.LogInformation("User requested backtest analysis but no backtest ID found in conversation context"); return; } - _logger.LogInformation("Proactively fetching backtest details for ID: {BacktestId}", backtestId); + logger.LogInformation("Proactively fetching backtest details for ID: {BacktestId}", backtestId); try { // Execute get_backtest_by_id tool to fetch complete details - var backtestDetails = await _mcpService.ExecuteToolAsync( + var backtestDetails = await mcpService.ExecuteToolAsync( user, "get_backtest_by_id", new Dictionary { ["id"] = backtestId } ); - _logger.LogInformation("Successfully fetched backtest details for ID: {BacktestId}. Result type: {ResultType}", + logger.LogInformation("Successfully fetched backtest details for ID: {BacktestId}. Result type: {ResultType}", backtestId, backtestDetails?.GetType().Name ?? "null"); // Inject the backtest details as a tool result in the conversation @@ -784,7 +835,7 @@ public class LlmController : BaseController // Add tool result message var serializedResult = JsonSerializer.Serialize(backtestDetails); - _logger.LogInformation("Serialized backtest details length: {Length} characters", serializedResult.Length); + logger.LogInformation("Serialized backtest details length: {Length} characters", serializedResult.Length); request.Messages.Add(new LlmMessage { @@ -793,11 +844,11 @@ public class LlmController : BaseController ToolCallId = toolCallId }); - _logger.LogInformation("Successfully injected backtest details into conversation for ID: {BacktestId}", backtestId); + logger.LogInformation("Successfully injected backtest details into conversation for ID: {BacktestId}", backtestId); } catch (Exception ex) { - _logger.LogError(ex, "Error fetching backtest details for ID: {BacktestId}", backtestId); + logger.LogError(ex, "Error fetching backtest details for ID: {BacktestId}", backtestId); // Inject an error message so LLM knows what happened var toolCallId = Guid.NewGuid().ToString(); @@ -862,7 +913,8 @@ public class LlmController : BaseController /// private string? ExtractBacktestIdFromConversation(List messages) { - _logger.LogDebug("Extracting backtest ID from {Count} messages", messages.Count); + // Note: This method doesn't use logger to avoid passing it through multiple call sites + // Logging is optional here as it's called from various contexts // Look through messages in reverse order (most recent first) for (int i = messages.Count - 1; i >= 0; i--) @@ -872,8 +924,7 @@ public class LlmController : BaseController continue; var content = message.Content; - _logger.LogDebug("Checking message {Index} (Role: {Role}): {Preview}", - i, message.Role, content.Length > 100 ? content.Substring(0, 100) + "..." : content); + // Debug logging removed to avoid logger dependency in this helper method // Try to extract from JSON in tool results (most reliable) if (message.Role == "tool") @@ -927,7 +978,6 @@ public class LlmController : BaseController if (match.Success) { var extractedId = match.Groups[1].Value; - _logger.LogInformation("Extracted backtest ID from text pattern: {BacktestId}", extractedId); return extractedId; } @@ -939,27 +989,25 @@ public class LlmController : BaseController if (guidMatch.Success && i >= messages.Count - 5) // Only use standalone GUIDs from recent messages { var extractedId = guidMatch.Groups[1].Value; - _logger.LogInformation("Extracted backtest ID from standalone GUID: {BacktestId}", extractedId); return extractedId; } } - _logger.LogWarning("No backtest ID found in conversation messages"); return null; } /// /// Helper method to send progress update via SignalR /// - private async Task SendProgressUpdate(string connectionId, LlmProgressUpdate update) + private async Task SendProgressUpdate(string connectionId, IHubContext hubContext, ILogger logger, LlmProgressUpdate update) { try { - await _hubContext.Clients.Client(connectionId).SendAsync("ProgressUpdate", update); + await hubContext.Clients.Client(connectionId).SendAsync("ProgressUpdate", update); } catch (Exception ex) { - _logger.LogError(ex, "Error sending progress update to connection {ConnectionId}", connectionId); + logger.LogError(ex, "Error sending progress update to connection {ConnectionId}", connectionId); } } } diff --git a/src/Managing.Mcp/Tools/BacktestTools.cs b/src/Managing.Mcp/Tools/BacktestTools.cs index 506dab61..5e0a6bd3 100644 --- a/src/Managing.Mcp/Tools/BacktestTools.cs +++ b/src/Managing.Mcp/Tools/BacktestTools.cs @@ -140,7 +140,8 @@ public class BacktestTools catch (Exception ex) { _logger.LogError(ex, "Error getting paginated backtests for user {UserId}", user.Id); - throw new InvalidOperationException($"Failed to retrieve backtests: {ex.Message}", ex); + var errorMessage = GetUserFriendlyErrorMessage(ex, "retrieve backtests"); + throw new InvalidOperationException(errorMessage, ex); } } @@ -182,7 +183,8 @@ public class BacktestTools catch (Exception ex) { _logger.LogError(ex, "Error getting backtest {BacktestId} for user {UserId}", id, user.Id); - throw new InvalidOperationException($"Failed to retrieve backtest: {ex.Message}", ex); + var errorMessage = GetUserFriendlyErrorMessage(ex, "retrieve backtest"); + throw new InvalidOperationException(errorMessage, ex); } } @@ -325,7 +327,8 @@ public class BacktestTools catch (Exception ex) { _logger.LogError(ex, "Error getting paginated bundle backtests for user {UserId}", user.Id); - throw new InvalidOperationException($"Failed to retrieve bundle backtests: {ex.Message}", ex); + var errorMessage = GetUserFriendlyErrorMessage(ex, "retrieve bundle backtests"); + throw new InvalidOperationException(errorMessage, ex); } } @@ -820,7 +823,80 @@ public class BacktestTools catch (Exception ex) { _logger.LogError(ex, "Error analyzing bundle backtest {BundleId} for user {UserId}", bundleRequestId, user.Id); - throw new InvalidOperationException($"Failed to analyze bundle backtest: {ex.Message}", ex); + var errorMessage = GetUserFriendlyErrorMessage(ex, "analyze bundle backtest"); + throw new InvalidOperationException(errorMessage, ex); } } + + /// + /// Gets a user-friendly error message based on the exception type + /// Uses reflection to check for database-specific exceptions without adding dependencies + /// + private static string GetUserFriendlyErrorMessage(Exception ex, string operation) + { + var exType = ex.GetType(); + var exTypeName = exType.Name; + var exMessage = ex.Message ?? string.Empty; + + // Check for database connection errors using type name matching (avoids dependency on Npgsql) + if (exTypeName.Contains("NpgsqlException", StringComparison.OrdinalIgnoreCase)) + { + // Try to get SqlState property via reflection + var sqlStateProp = exType.GetProperty("SqlState"); + var sqlState = sqlStateProp?.GetValue(ex)?.ToString(); + + // Check for connection-related errors + if (sqlState == "08000" || // Connection exception + sqlState == "08003" || // Connection does not exist + sqlState == "08006" || // Connection failure + exMessage.Contains("connection", StringComparison.OrdinalIgnoreCase) || + exMessage.Contains("timeout", StringComparison.OrdinalIgnoreCase) || + exMessage.Contains("network", StringComparison.OrdinalIgnoreCase)) + { + return $"Unable to connect to the database to {operation}. Please try again in a moment. If the problem persists, contact support."; + } + + // Check for authentication errors + if (sqlState == "28P01") // Invalid password + { + return $"Database authentication failed while trying to {operation}. Please contact support."; + } + + // Generic database error + return $"A database error occurred while trying to {operation}. Please try again. If the problem persists, contact support."; + } + + // Check for Entity Framework database update exceptions + if (exTypeName.Contains("DbUpdateException", StringComparison.OrdinalIgnoreCase) && ex.InnerException != null) + { + return GetUserFriendlyErrorMessage(ex.InnerException, operation); + } + + // Check for timeout exceptions + if (ex is TimeoutException || exMessage.Contains("timeout", StringComparison.OrdinalIgnoreCase)) + { + return $"The request timed out while trying to {operation}. Please try again."; + } + + // Check for connection-related messages in the exception chain + var innerEx = ex.InnerException; + while (innerEx != null) + { + var innerTypeName = innerEx.GetType().Name; + if (innerTypeName.Contains("NpgsqlException", StringComparison.OrdinalIgnoreCase)) + { + return GetUserFriendlyErrorMessage(innerEx, operation); + } + var innerMessage = innerEx.Message ?? string.Empty; + if (innerMessage.Contains("connection", StringComparison.OrdinalIgnoreCase) || + innerMessage.Contains("timeout", StringComparison.OrdinalIgnoreCase)) + { + return $"Unable to connect to the database to {operation}. Please try again in a moment."; + } + innerEx = innerEx.InnerException; + } + + // Generic error message + return $"Failed to {operation}. {exMessage}"; + } } \ No newline at end of file