Refactor LlmController for improved service scope management and error handling
- Introduced IServiceScopeFactory to create a scope for background tasks, allowing access to scoped services like ILlmService and IMcpService. - Enhanced error handling during chat stream processing, providing user-friendly error messages for database connection issues. - Refactored SendProgressUpdate method to accept hubContext and logger as parameters, improving logging consistency. - Updated InjectBacktestDetailsFetchingIfNeeded method to utilize scoped services, ensuring accurate backtest detail fetching. - Improved overall error messaging and logging throughout the LlmController for better user feedback during chat interactions.
This commit is contained in:
@@ -7,6 +7,7 @@ using Microsoft.AspNetCore.Authorization;
|
|||||||
using Microsoft.AspNetCore.Mvc;
|
using Microsoft.AspNetCore.Mvc;
|
||||||
using Microsoft.AspNetCore.SignalR;
|
using Microsoft.AspNetCore.SignalR;
|
||||||
using Microsoft.Extensions.Caching.Memory;
|
using Microsoft.Extensions.Caching.Memory;
|
||||||
|
using Microsoft.Extensions.DependencyInjection;
|
||||||
|
|
||||||
namespace Managing.Api.Controllers;
|
namespace Managing.Api.Controllers;
|
||||||
|
|
||||||
@@ -25,6 +26,7 @@ public class LlmController : BaseController
|
|||||||
private readonly ILogger<LlmController> _logger;
|
private readonly ILogger<LlmController> _logger;
|
||||||
private readonly IMemoryCache _cache;
|
private readonly IMemoryCache _cache;
|
||||||
private readonly IHubContext<LlmHub> _hubContext;
|
private readonly IHubContext<LlmHub> _hubContext;
|
||||||
|
private readonly IServiceScopeFactory _serviceScopeFactory;
|
||||||
|
|
||||||
public LlmController(
|
public LlmController(
|
||||||
ILlmService llmService,
|
ILlmService llmService,
|
||||||
@@ -32,13 +34,15 @@ public class LlmController : BaseController
|
|||||||
IUserService userService,
|
IUserService userService,
|
||||||
ILogger<LlmController> logger,
|
ILogger<LlmController> logger,
|
||||||
IMemoryCache cache,
|
IMemoryCache cache,
|
||||||
IHubContext<LlmHub> hubContext) : base(userService)
|
IHubContext<LlmHub> hubContext,
|
||||||
|
IServiceScopeFactory serviceScopeFactory) : base(userService)
|
||||||
{
|
{
|
||||||
_llmService = llmService;
|
_llmService = llmService;
|
||||||
_mcpService = mcpService;
|
_mcpService = mcpService;
|
||||||
_logger = logger;
|
_logger = logger;
|
||||||
_cache = cache;
|
_cache = cache;
|
||||||
_hubContext = hubContext;
|
_hubContext = hubContext;
|
||||||
|
_serviceScopeFactory = serviceScopeFactory;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
@@ -92,15 +96,41 @@ public class LlmController : BaseController
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Process in background to avoid blocking the HTTP response
|
// Process in background to avoid blocking the HTTP response
|
||||||
|
// Create a scope for the background task to access scoped services (DbContext, etc.)
|
||||||
|
var userId = user.Id;
|
||||||
_ = Task.Run(async () =>
|
_ = Task.Run(async () =>
|
||||||
{
|
{
|
||||||
|
using var scope = _serviceScopeFactory.CreateScope();
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
await ChatStreamInternal(request, user, request.ConnectionId);
|
// Resolve scoped services from the new scope
|
||||||
|
var llmService = scope.ServiceProvider.GetRequiredService<ILlmService>();
|
||||||
|
var mcpService = scope.ServiceProvider.GetRequiredService<IMcpService>();
|
||||||
|
var userService = scope.ServiceProvider.GetRequiredService<IUserService>();
|
||||||
|
var cache = scope.ServiceProvider.GetRequiredService<IMemoryCache>();
|
||||||
|
var hubContext = scope.ServiceProvider.GetRequiredService<IHubContext<LlmHub>>();
|
||||||
|
var logger = scope.ServiceProvider.GetRequiredService<ILogger<LlmController>>();
|
||||||
|
|
||||||
|
// Reload user from the scoped service to ensure we have a valid user object
|
||||||
|
var scopedUser = await userService.GetUserByIdAsync(userId);
|
||||||
|
if (scopedUser == null)
|
||||||
|
{
|
||||||
|
await hubContext.Clients.Client(request.ConnectionId).SendAsync("ProgressUpdate", new LlmProgressUpdate
|
||||||
|
{
|
||||||
|
Type = "error",
|
||||||
|
Message = "User not found",
|
||||||
|
Error = "Unable to authenticate user"
|
||||||
|
});
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
await ChatStreamInternal(request, scopedUser, request.ConnectionId, llmService, mcpService, cache, hubContext, logger);
|
||||||
}
|
}
|
||||||
catch (Exception ex)
|
catch (Exception ex)
|
||||||
{
|
{
|
||||||
_logger.LogError(ex, "Error processing chat stream for connection {ConnectionId}", request.ConnectionId);
|
_logger.LogError(ex, "Error processing chat stream for connection {ConnectionId}", request.ConnectionId);
|
||||||
|
try
|
||||||
|
{
|
||||||
await _hubContext.Clients.Client(request.ConnectionId).SendAsync("ProgressUpdate", new LlmProgressUpdate
|
await _hubContext.Clients.Client(request.ConnectionId).SendAsync("ProgressUpdate", new LlmProgressUpdate
|
||||||
{
|
{
|
||||||
Type = "error",
|
Type = "error",
|
||||||
@@ -108,12 +138,25 @@ public class LlmController : BaseController
|
|||||||
Error = ex.Message
|
Error = ex.Message
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
catch (Exception hubEx)
|
||||||
|
{
|
||||||
|
_logger.LogError(hubEx, "Error sending error message to SignalR client");
|
||||||
|
}
|
||||||
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
return Ok(new { Message = "Chat stream started", ConnectionId = request.ConnectionId });
|
return Ok(new { Message = "Chat stream started", ConnectionId = request.ConnectionId });
|
||||||
}
|
}
|
||||||
|
|
||||||
private async Task ChatStreamInternal(LlmChatStreamRequest request, User user, string connectionId)
|
private async Task ChatStreamInternal(
|
||||||
|
LlmChatStreamRequest request,
|
||||||
|
User user,
|
||||||
|
string connectionId,
|
||||||
|
ILlmService llmService,
|
||||||
|
IMcpService mcpService,
|
||||||
|
IMemoryCache cache,
|
||||||
|
IHubContext<LlmHub> hubContext,
|
||||||
|
ILogger<LlmController> logger)
|
||||||
{
|
{
|
||||||
// Convert to LlmChatRequest for service calls
|
// Convert to LlmChatRequest for service calls
|
||||||
var chatRequest = new LlmChatRequest
|
var chatRequest = new LlmChatRequest
|
||||||
@@ -127,21 +170,21 @@ public class LlmController : BaseController
|
|||||||
Tools = request.Tools
|
Tools = request.Tools
|
||||||
};
|
};
|
||||||
|
|
||||||
await SendProgressUpdate(connectionId, new LlmProgressUpdate
|
await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate
|
||||||
{
|
{
|
||||||
Type = "thinking",
|
Type = "thinking",
|
||||||
Message = "Initializing conversation and loading available tools..."
|
Message = "Initializing conversation and loading available tools..."
|
||||||
});
|
});
|
||||||
|
|
||||||
// Get available MCP tools (with caching for 5 minutes)
|
// Get available MCP tools (with caching for 5 minutes)
|
||||||
var availableTools = await _cache.GetOrCreateAsync("mcp_tools", async entry =>
|
var availableTools = await cache.GetOrCreateAsync("mcp_tools", async entry =>
|
||||||
{
|
{
|
||||||
entry.AbsoluteExpirationRelativeToNow = TimeSpan.FromMinutes(5);
|
entry.AbsoluteExpirationRelativeToNow = TimeSpan.FromMinutes(5);
|
||||||
return (await _mcpService.GetAvailableToolsAsync()).ToList();
|
return (await mcpService.GetAvailableToolsAsync()).ToList();
|
||||||
});
|
});
|
||||||
chatRequest.Tools = availableTools;
|
chatRequest.Tools = availableTools;
|
||||||
|
|
||||||
await SendProgressUpdate(connectionId, new LlmProgressUpdate
|
await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate
|
||||||
{
|
{
|
||||||
Type = "thinking",
|
Type = "thinking",
|
||||||
Message = $"Loaded {availableTools.Count} available tools. Preparing system context..."
|
Message = $"Loaded {availableTools.Count} available tools. Preparing system context..."
|
||||||
@@ -163,7 +206,7 @@ public class LlmController : BaseController
|
|||||||
chatRequest.Messages.Insert(0, systemMessage);
|
chatRequest.Messages.Insert(0, systemMessage);
|
||||||
|
|
||||||
// Proactively inject backtest details fetching if user is asking for analysis
|
// Proactively inject backtest details fetching if user is asking for analysis
|
||||||
await InjectBacktestDetailsFetchingIfNeeded(chatRequest, user);
|
await InjectBacktestDetailsFetchingIfNeeded(chatRequest, user, mcpService, logger);
|
||||||
|
|
||||||
// Add helpful context extraction message if backtest ID was found
|
// Add helpful context extraction message if backtest ID was found
|
||||||
AddBacktestContextGuidance(chatRequest);
|
AddBacktestContextGuidance(chatRequest);
|
||||||
@@ -174,7 +217,7 @@ public class LlmController : BaseController
|
|||||||
LlmChatResponse? finalResponse = null;
|
LlmChatResponse? finalResponse = null;
|
||||||
const int DelayBetweenIterationsMs = 500;
|
const int DelayBetweenIterationsMs = 500;
|
||||||
|
|
||||||
await SendProgressUpdate(connectionId, new LlmProgressUpdate
|
await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate
|
||||||
{
|
{
|
||||||
Type = "thinking",
|
Type = "thinking",
|
||||||
Message = $"Starting analysis (up to {maxIterations} iterations may be needed)..."
|
Message = $"Starting analysis (up to {maxIterations} iterations may be needed)..."
|
||||||
@@ -184,7 +227,7 @@ public class LlmController : BaseController
|
|||||||
{
|
{
|
||||||
iteration++;
|
iteration++;
|
||||||
|
|
||||||
await SendProgressUpdate(connectionId, new LlmProgressUpdate
|
await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate
|
||||||
{
|
{
|
||||||
Type = "iteration_start",
|
Type = "iteration_start",
|
||||||
Message = "Analyzing your request and determining next steps...",
|
Message = "Analyzing your request and determining next steps...",
|
||||||
@@ -192,13 +235,13 @@ public class LlmController : BaseController
|
|||||||
MaxIterations = maxIterations
|
MaxIterations = maxIterations
|
||||||
});
|
});
|
||||||
|
|
||||||
_logger.LogInformation("LLM chat iteration {Iteration}/{MaxIterations} for user {UserId}",
|
logger.LogInformation("LLM chat iteration {Iteration}/{MaxIterations} for user {UserId}",
|
||||||
iteration, maxIterations, user.Id);
|
iteration, maxIterations, user.Id);
|
||||||
|
|
||||||
// Add delay between iterations to avoid rapid bursts and rate limiting
|
// Add delay between iterations to avoid rapid bursts and rate limiting
|
||||||
if (iteration > 1)
|
if (iteration > 1)
|
||||||
{
|
{
|
||||||
await SendProgressUpdate(connectionId, new LlmProgressUpdate
|
await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate
|
||||||
{
|
{
|
||||||
Type = "thinking",
|
Type = "thinking",
|
||||||
Message = "Waiting briefly to respect rate limits...",
|
Message = "Waiting briefly to respect rate limits...",
|
||||||
@@ -212,7 +255,7 @@ public class LlmController : BaseController
|
|||||||
TrimConversationContext(chatRequest);
|
TrimConversationContext(chatRequest);
|
||||||
|
|
||||||
// Send chat request to LLM
|
// Send chat request to LLM
|
||||||
await SendProgressUpdate(connectionId, new LlmProgressUpdate
|
await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate
|
||||||
{
|
{
|
||||||
Type = "thinking",
|
Type = "thinking",
|
||||||
Message = "Sending request to LLM...",
|
Message = "Sending request to LLM...",
|
||||||
@@ -220,16 +263,16 @@ public class LlmController : BaseController
|
|||||||
MaxIterations = maxIterations
|
MaxIterations = maxIterations
|
||||||
});
|
});
|
||||||
|
|
||||||
var response = await _llmService.ChatAsync(user, chatRequest);
|
var response = await llmService.ChatAsync(user, chatRequest);
|
||||||
|
|
||||||
// If LLM doesn't want to call tools, we have our final answer
|
// If LLM doesn't want to call tools, we have our final answer
|
||||||
if (!response.RequiresToolExecution || response.ToolCalls == null || !response.ToolCalls.Any())
|
if (!response.RequiresToolExecution || response.ToolCalls == null || !response.ToolCalls.Any())
|
||||||
{
|
{
|
||||||
finalResponse = response;
|
finalResponse = response;
|
||||||
_logger.LogInformation("LLM provided final answer after {Iteration} iteration(s) for user {UserId}",
|
logger.LogInformation("LLM provided final answer after {Iteration} iteration(s) for user {UserId}",
|
||||||
iteration, user.Id);
|
iteration, user.Id);
|
||||||
|
|
||||||
await SendProgressUpdate(connectionId, new LlmProgressUpdate
|
await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate
|
||||||
{
|
{
|
||||||
Type = "thinking",
|
Type = "thinking",
|
||||||
Message = "Received final response. Preparing answer...",
|
Message = "Received final response. Preparing answer...",
|
||||||
@@ -241,10 +284,10 @@ public class LlmController : BaseController
|
|||||||
}
|
}
|
||||||
|
|
||||||
// LLM wants to call tools - execute them
|
// LLM wants to call tools - execute them
|
||||||
_logger.LogInformation("LLM requested {Count} tool calls in iteration {Iteration} for user {UserId}",
|
logger.LogInformation("LLM requested {Count} tool calls in iteration {Iteration} for user {UserId}",
|
||||||
response.ToolCalls.Count, iteration, user.Id);
|
response.ToolCalls.Count, iteration, user.Id);
|
||||||
|
|
||||||
await SendProgressUpdate(connectionId, new LlmProgressUpdate
|
await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate
|
||||||
{
|
{
|
||||||
Type = "thinking",
|
Type = "thinking",
|
||||||
Message = $"LLM requested {response.ToolCalls.Count} tool call(s). Executing tools...",
|
Message = $"LLM requested {response.ToolCalls.Count} tool call(s). Executing tools...",
|
||||||
@@ -256,7 +299,7 @@ public class LlmController : BaseController
|
|||||||
var toolResults = new List<LlmMessage>();
|
var toolResults = new List<LlmMessage>();
|
||||||
foreach (var toolCall in response.ToolCalls)
|
foreach (var toolCall in response.ToolCalls)
|
||||||
{
|
{
|
||||||
await SendProgressUpdate(connectionId, new LlmProgressUpdate
|
await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate
|
||||||
{
|
{
|
||||||
Type = "tool_call",
|
Type = "tool_call",
|
||||||
Message = $"Calling tool: {toolCall.Name}",
|
Message = $"Calling tool: {toolCall.Name}",
|
||||||
@@ -266,14 +309,14 @@ public class LlmController : BaseController
|
|||||||
ToolArguments = toolCall.Arguments
|
ToolArguments = toolCall.Arguments
|
||||||
});
|
});
|
||||||
|
|
||||||
var (success, result, error) = await ExecuteToolSafely(user, toolCall.Name, toolCall.Arguments, toolCall.Id, iteration, maxIterations);
|
var (success, result, error) = await ExecuteToolSafely(user, toolCall.Name, toolCall.Arguments, toolCall.Id, iteration, maxIterations, mcpService, logger);
|
||||||
|
|
||||||
if (success && result != null)
|
if (success && result != null)
|
||||||
{
|
{
|
||||||
_logger.LogInformation("Successfully executed tool {ToolName} in iteration {Iteration} for user {UserId}",
|
logger.LogInformation("Successfully executed tool {ToolName} in iteration {Iteration} for user {UserId}",
|
||||||
toolCall.Name, iteration, user.Id);
|
toolCall.Name, iteration, user.Id);
|
||||||
|
|
||||||
await SendProgressUpdate(connectionId, new LlmProgressUpdate
|
await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate
|
||||||
{
|
{
|
||||||
Type = "tool_result",
|
Type = "tool_result",
|
||||||
Message = $"Tool {toolCall.Name} completed successfully",
|
Message = $"Tool {toolCall.Name} completed successfully",
|
||||||
@@ -291,10 +334,10 @@ public class LlmController : BaseController
|
|||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
_logger.LogError("Error executing tool {ToolName} in iteration {Iteration} for user {UserId}: {Error}",
|
logger.LogError("Error executing tool {ToolName} in iteration {Iteration} for user {UserId}: {Error}",
|
||||||
toolCall.Name, iteration, user.Id, error);
|
toolCall.Name, iteration, user.Id, error);
|
||||||
|
|
||||||
await SendProgressUpdate(connectionId, new LlmProgressUpdate
|
await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate
|
||||||
{
|
{
|
||||||
Type = "tool_result",
|
Type = "tool_result",
|
||||||
Message = $"Tool {toolCall.Name} encountered an error: {error}",
|
Message = $"Tool {toolCall.Name} encountered an error: {error}",
|
||||||
@@ -313,7 +356,7 @@ public class LlmController : BaseController
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
await SendProgressUpdate(connectionId, new LlmProgressUpdate
|
await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate
|
||||||
{
|
{
|
||||||
Type = "thinking",
|
Type = "thinking",
|
||||||
Message = "All tools completed. Analyzing results...",
|
Message = "All tools completed. Analyzing results...",
|
||||||
@@ -338,10 +381,10 @@ public class LlmController : BaseController
|
|||||||
// If we hit max iterations, return the last response (even if it has tool calls)
|
// If we hit max iterations, return the last response (even if it has tool calls)
|
||||||
if (finalResponse == null)
|
if (finalResponse == null)
|
||||||
{
|
{
|
||||||
_logger.LogWarning("Reached max iterations ({MaxIterations}) for user {UserId}. Returning last response.",
|
logger.LogWarning("Reached max iterations ({MaxIterations}) for user {UserId}. Returning last response.",
|
||||||
maxIterations, user.Id);
|
maxIterations, user.Id);
|
||||||
|
|
||||||
await SendProgressUpdate(connectionId, new LlmProgressUpdate
|
await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate
|
||||||
{
|
{
|
||||||
Type = "thinking",
|
Type = "thinking",
|
||||||
Message = "Reached maximum iterations. Getting final response...",
|
Message = "Reached maximum iterations. Getting final response...",
|
||||||
@@ -349,11 +392,11 @@ public class LlmController : BaseController
|
|||||||
MaxIterations = maxIterations
|
MaxIterations = maxIterations
|
||||||
});
|
});
|
||||||
|
|
||||||
finalResponse = await _llmService.ChatAsync(user, chatRequest);
|
finalResponse = await llmService.ChatAsync(user, chatRequest);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send final response
|
// Send final response
|
||||||
await SendProgressUpdate(connectionId, new LlmProgressUpdate
|
await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate
|
||||||
{
|
{
|
||||||
Type = "final_response",
|
Type = "final_response",
|
||||||
Message = "Analysis complete!",
|
Message = "Analysis complete!",
|
||||||
@@ -413,7 +456,7 @@ public class LlmController : BaseController
|
|||||||
request.Messages.Insert(0, systemMessage);
|
request.Messages.Insert(0, systemMessage);
|
||||||
|
|
||||||
// Proactively inject backtest details fetching if user is asking for analysis
|
// Proactively inject backtest details fetching if user is asking for analysis
|
||||||
await InjectBacktestDetailsFetchingIfNeeded(request, user);
|
await InjectBacktestDetailsFetchingIfNeeded(request, user, _mcpService, _logger);
|
||||||
|
|
||||||
// Add helpful context extraction message if backtest ID was found
|
// Add helpful context extraction message if backtest ID was found
|
||||||
AddBacktestContextGuidance(request);
|
AddBacktestContextGuidance(request);
|
||||||
@@ -613,6 +656,12 @@ public class LlmController : BaseController
|
|||||||
* "Recent backtests" → get_backtests_paginated(sortOrder='desc', pageSize=20)
|
* "Recent backtests" → get_backtests_paginated(sortOrder='desc', pageSize=20)
|
||||||
* "Bundle backtest analysis" → analyze_bundle_backtest(bundleRequestId='X')
|
* "Bundle backtest analysis" → analyze_bundle_backtest(bundleRequestId='X')
|
||||||
|
|
||||||
|
ERROR HANDLING:
|
||||||
|
- If a tool returns a database connection error, wait a moment and retry once (these are often transient)
|
||||||
|
- If retry fails, explain the issue clearly to the user and suggest they try again later
|
||||||
|
- Never give up after a single error - always try at least once more for connection-related issues
|
||||||
|
- Distinguish between connection errors (retry) and data errors (no retry needed)
|
||||||
|
|
||||||
CRITICAL ANALYSIS WORKFLOW (APPLIES TO ALL DATA):
|
CRITICAL ANALYSIS WORKFLOW (APPLIES TO ALL DATA):
|
||||||
1. RETRIEVE COMPLETE DATA:
|
1. RETRIEVE COMPLETE DATA:
|
||||||
- When asked to analyze ANY entity, ALWAYS fetch FULL details first (never rely on summary/paginated data alone)
|
- When asked to analyze ANY entity, ALWAYS fetch FULL details first (never rely on summary/paginated data alone)
|
||||||
@@ -671,16 +720,18 @@ public class LlmController : BaseController
|
|||||||
Dictionary<string, object> arguments,
|
Dictionary<string, object> arguments,
|
||||||
string toolCallId,
|
string toolCallId,
|
||||||
int iteration,
|
int iteration,
|
||||||
int maxIterations)
|
int maxIterations,
|
||||||
|
IMcpService mcpService,
|
||||||
|
ILogger<LlmController> logger)
|
||||||
{
|
{
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
var result = await _mcpService.ExecuteToolAsync(user, toolName, arguments);
|
var result = await mcpService.ExecuteToolAsync(user, toolName, arguments);
|
||||||
return (true, result, null);
|
return (true, result, null);
|
||||||
}
|
}
|
||||||
catch (Exception ex)
|
catch (Exception ex)
|
||||||
{
|
{
|
||||||
_logger.LogError(ex, "Error executing tool {ToolName} in iteration {Iteration} for user {UserId}",
|
logger.LogError(ex, "Error executing tool {ToolName} in iteration {Iteration} for user {UserId}",
|
||||||
toolName, iteration, user.Id);
|
toolName, iteration, user.Id);
|
||||||
return (false, null, ex.Message);
|
return (false, null, ex.Message);
|
||||||
}
|
}
|
||||||
@@ -722,7 +773,7 @@ public class LlmController : BaseController
|
|||||||
/// Proactively injects backtest details fetching when user asks for backtest analysis.
|
/// Proactively injects backtest details fetching when user asks for backtest analysis.
|
||||||
/// Extracts backtest IDs from message content and automatically calls get_backtest_by_id.
|
/// Extracts backtest IDs from message content and automatically calls get_backtest_by_id.
|
||||||
/// </summary>
|
/// </summary>
|
||||||
private async Task InjectBacktestDetailsFetchingIfNeeded(LlmChatRequest request, User user)
|
private async Task InjectBacktestDetailsFetchingIfNeeded(LlmChatRequest request, User user, IMcpService mcpService, ILogger<LlmController> logger)
|
||||||
{
|
{
|
||||||
var lastUserMessage = request.Messages.LastOrDefault(m => m.Role == "user");
|
var lastUserMessage = request.Messages.LastOrDefault(m => m.Role == "user");
|
||||||
if (lastUserMessage == null || string.IsNullOrWhiteSpace(lastUserMessage.Content))
|
if (lastUserMessage == null || string.IsNullOrWhiteSpace(lastUserMessage.Content))
|
||||||
@@ -745,22 +796,22 @@ public class LlmController : BaseController
|
|||||||
var backtestId = ExtractBacktestIdFromConversation(request.Messages);
|
var backtestId = ExtractBacktestIdFromConversation(request.Messages);
|
||||||
if (string.IsNullOrEmpty(backtestId))
|
if (string.IsNullOrEmpty(backtestId))
|
||||||
{
|
{
|
||||||
_logger.LogInformation("User requested backtest analysis but no backtest ID found in conversation context");
|
logger.LogInformation("User requested backtest analysis but no backtest ID found in conversation context");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
_logger.LogInformation("Proactively fetching backtest details for ID: {BacktestId}", backtestId);
|
logger.LogInformation("Proactively fetching backtest details for ID: {BacktestId}", backtestId);
|
||||||
|
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
// Execute get_backtest_by_id tool to fetch complete details
|
// Execute get_backtest_by_id tool to fetch complete details
|
||||||
var backtestDetails = await _mcpService.ExecuteToolAsync(
|
var backtestDetails = await mcpService.ExecuteToolAsync(
|
||||||
user,
|
user,
|
||||||
"get_backtest_by_id",
|
"get_backtest_by_id",
|
||||||
new Dictionary<string, object> { ["id"] = backtestId }
|
new Dictionary<string, object> { ["id"] = backtestId }
|
||||||
);
|
);
|
||||||
|
|
||||||
_logger.LogInformation("Successfully fetched backtest details for ID: {BacktestId}. Result type: {ResultType}",
|
logger.LogInformation("Successfully fetched backtest details for ID: {BacktestId}. Result type: {ResultType}",
|
||||||
backtestId, backtestDetails?.GetType().Name ?? "null");
|
backtestId, backtestDetails?.GetType().Name ?? "null");
|
||||||
|
|
||||||
// Inject the backtest details as a tool result in the conversation
|
// Inject the backtest details as a tool result in the conversation
|
||||||
@@ -784,7 +835,7 @@ public class LlmController : BaseController
|
|||||||
|
|
||||||
// Add tool result message
|
// Add tool result message
|
||||||
var serializedResult = JsonSerializer.Serialize(backtestDetails);
|
var serializedResult = JsonSerializer.Serialize(backtestDetails);
|
||||||
_logger.LogInformation("Serialized backtest details length: {Length} characters", serializedResult.Length);
|
logger.LogInformation("Serialized backtest details length: {Length} characters", serializedResult.Length);
|
||||||
|
|
||||||
request.Messages.Add(new LlmMessage
|
request.Messages.Add(new LlmMessage
|
||||||
{
|
{
|
||||||
@@ -793,11 +844,11 @@ public class LlmController : BaseController
|
|||||||
ToolCallId = toolCallId
|
ToolCallId = toolCallId
|
||||||
});
|
});
|
||||||
|
|
||||||
_logger.LogInformation("Successfully injected backtest details into conversation for ID: {BacktestId}", backtestId);
|
logger.LogInformation("Successfully injected backtest details into conversation for ID: {BacktestId}", backtestId);
|
||||||
}
|
}
|
||||||
catch (Exception ex)
|
catch (Exception ex)
|
||||||
{
|
{
|
||||||
_logger.LogError(ex, "Error fetching backtest details for ID: {BacktestId}", backtestId);
|
logger.LogError(ex, "Error fetching backtest details for ID: {BacktestId}", backtestId);
|
||||||
|
|
||||||
// Inject an error message so LLM knows what happened
|
// Inject an error message so LLM knows what happened
|
||||||
var toolCallId = Guid.NewGuid().ToString();
|
var toolCallId = Guid.NewGuid().ToString();
|
||||||
@@ -862,7 +913,8 @@ public class LlmController : BaseController
|
|||||||
/// </summary>
|
/// </summary>
|
||||||
private string? ExtractBacktestIdFromConversation(List<LlmMessage> messages)
|
private string? ExtractBacktestIdFromConversation(List<LlmMessage> messages)
|
||||||
{
|
{
|
||||||
_logger.LogDebug("Extracting backtest ID from {Count} messages", messages.Count);
|
// Note: This method doesn't use logger to avoid passing it through multiple call sites
|
||||||
|
// Logging is optional here as it's called from various contexts
|
||||||
|
|
||||||
// Look through messages in reverse order (most recent first)
|
// Look through messages in reverse order (most recent first)
|
||||||
for (int i = messages.Count - 1; i >= 0; i--)
|
for (int i = messages.Count - 1; i >= 0; i--)
|
||||||
@@ -872,8 +924,7 @@ public class LlmController : BaseController
|
|||||||
continue;
|
continue;
|
||||||
|
|
||||||
var content = message.Content;
|
var content = message.Content;
|
||||||
_logger.LogDebug("Checking message {Index} (Role: {Role}): {Preview}",
|
// Debug logging removed to avoid logger dependency in this helper method
|
||||||
i, message.Role, content.Length > 100 ? content.Substring(0, 100) + "..." : content);
|
|
||||||
|
|
||||||
// Try to extract from JSON in tool results (most reliable)
|
// Try to extract from JSON in tool results (most reliable)
|
||||||
if (message.Role == "tool")
|
if (message.Role == "tool")
|
||||||
@@ -927,7 +978,6 @@ public class LlmController : BaseController
|
|||||||
if (match.Success)
|
if (match.Success)
|
||||||
{
|
{
|
||||||
var extractedId = match.Groups[1].Value;
|
var extractedId = match.Groups[1].Value;
|
||||||
_logger.LogInformation("Extracted backtest ID from text pattern: {BacktestId}", extractedId);
|
|
||||||
return extractedId;
|
return extractedId;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -939,27 +989,25 @@ public class LlmController : BaseController
|
|||||||
if (guidMatch.Success && i >= messages.Count - 5) // Only use standalone GUIDs from recent messages
|
if (guidMatch.Success && i >= messages.Count - 5) // Only use standalone GUIDs from recent messages
|
||||||
{
|
{
|
||||||
var extractedId = guidMatch.Groups[1].Value;
|
var extractedId = guidMatch.Groups[1].Value;
|
||||||
_logger.LogInformation("Extracted backtest ID from standalone GUID: {BacktestId}", extractedId);
|
|
||||||
return extractedId;
|
return extractedId;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
_logger.LogWarning("No backtest ID found in conversation messages");
|
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Helper method to send progress update via SignalR
|
/// Helper method to send progress update via SignalR
|
||||||
/// </summary>
|
/// </summary>
|
||||||
private async Task SendProgressUpdate(string connectionId, LlmProgressUpdate update)
|
private async Task SendProgressUpdate(string connectionId, IHubContext<LlmHub> hubContext, ILogger<LlmController> logger, LlmProgressUpdate update)
|
||||||
{
|
{
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
await _hubContext.Clients.Client(connectionId).SendAsync("ProgressUpdate", update);
|
await hubContext.Clients.Client(connectionId).SendAsync("ProgressUpdate", update);
|
||||||
}
|
}
|
||||||
catch (Exception ex)
|
catch (Exception ex)
|
||||||
{
|
{
|
||||||
_logger.LogError(ex, "Error sending progress update to connection {ConnectionId}", connectionId);
|
logger.LogError(ex, "Error sending progress update to connection {ConnectionId}", connectionId);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -140,7 +140,8 @@ public class BacktestTools
|
|||||||
catch (Exception ex)
|
catch (Exception ex)
|
||||||
{
|
{
|
||||||
_logger.LogError(ex, "Error getting paginated backtests for user {UserId}", user.Id);
|
_logger.LogError(ex, "Error getting paginated backtests for user {UserId}", user.Id);
|
||||||
throw new InvalidOperationException($"Failed to retrieve backtests: {ex.Message}", ex);
|
var errorMessage = GetUserFriendlyErrorMessage(ex, "retrieve backtests");
|
||||||
|
throw new InvalidOperationException(errorMessage, ex);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -182,7 +183,8 @@ public class BacktestTools
|
|||||||
catch (Exception ex)
|
catch (Exception ex)
|
||||||
{
|
{
|
||||||
_logger.LogError(ex, "Error getting backtest {BacktestId} for user {UserId}", id, user.Id);
|
_logger.LogError(ex, "Error getting backtest {BacktestId} for user {UserId}", id, user.Id);
|
||||||
throw new InvalidOperationException($"Failed to retrieve backtest: {ex.Message}", ex);
|
var errorMessage = GetUserFriendlyErrorMessage(ex, "retrieve backtest");
|
||||||
|
throw new InvalidOperationException(errorMessage, ex);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -325,7 +327,8 @@ public class BacktestTools
|
|||||||
catch (Exception ex)
|
catch (Exception ex)
|
||||||
{
|
{
|
||||||
_logger.LogError(ex, "Error getting paginated bundle backtests for user {UserId}", user.Id);
|
_logger.LogError(ex, "Error getting paginated bundle backtests for user {UserId}", user.Id);
|
||||||
throw new InvalidOperationException($"Failed to retrieve bundle backtests: {ex.Message}", ex);
|
var errorMessage = GetUserFriendlyErrorMessage(ex, "retrieve bundle backtests");
|
||||||
|
throw new InvalidOperationException(errorMessage, ex);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -820,7 +823,80 @@ public class BacktestTools
|
|||||||
catch (Exception ex)
|
catch (Exception ex)
|
||||||
{
|
{
|
||||||
_logger.LogError(ex, "Error analyzing bundle backtest {BundleId} for user {UserId}", bundleRequestId, user.Id);
|
_logger.LogError(ex, "Error analyzing bundle backtest {BundleId} for user {UserId}", bundleRequestId, user.Id);
|
||||||
throw new InvalidOperationException($"Failed to analyze bundle backtest: {ex.Message}", ex);
|
var errorMessage = GetUserFriendlyErrorMessage(ex, "analyze bundle backtest");
|
||||||
|
throw new InvalidOperationException(errorMessage, ex);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Gets a user-friendly error message based on the exception type
|
||||||
|
/// Uses reflection to check for database-specific exceptions without adding dependencies
|
||||||
|
/// </summary>
|
||||||
|
private static string GetUserFriendlyErrorMessage(Exception ex, string operation)
|
||||||
|
{
|
||||||
|
var exType = ex.GetType();
|
||||||
|
var exTypeName = exType.Name;
|
||||||
|
var exMessage = ex.Message ?? string.Empty;
|
||||||
|
|
||||||
|
// Check for database connection errors using type name matching (avoids dependency on Npgsql)
|
||||||
|
if (exTypeName.Contains("NpgsqlException", StringComparison.OrdinalIgnoreCase))
|
||||||
|
{
|
||||||
|
// Try to get SqlState property via reflection
|
||||||
|
var sqlStateProp = exType.GetProperty("SqlState");
|
||||||
|
var sqlState = sqlStateProp?.GetValue(ex)?.ToString();
|
||||||
|
|
||||||
|
// Check for connection-related errors
|
||||||
|
if (sqlState == "08000" || // Connection exception
|
||||||
|
sqlState == "08003" || // Connection does not exist
|
||||||
|
sqlState == "08006" || // Connection failure
|
||||||
|
exMessage.Contains("connection", StringComparison.OrdinalIgnoreCase) ||
|
||||||
|
exMessage.Contains("timeout", StringComparison.OrdinalIgnoreCase) ||
|
||||||
|
exMessage.Contains("network", StringComparison.OrdinalIgnoreCase))
|
||||||
|
{
|
||||||
|
return $"Unable to connect to the database to {operation}. Please try again in a moment. If the problem persists, contact support.";
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for authentication errors
|
||||||
|
if (sqlState == "28P01") // Invalid password
|
||||||
|
{
|
||||||
|
return $"Database authentication failed while trying to {operation}. Please contact support.";
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generic database error
|
||||||
|
return $"A database error occurred while trying to {operation}. Please try again. If the problem persists, contact support.";
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for Entity Framework database update exceptions
|
||||||
|
if (exTypeName.Contains("DbUpdateException", StringComparison.OrdinalIgnoreCase) && ex.InnerException != null)
|
||||||
|
{
|
||||||
|
return GetUserFriendlyErrorMessage(ex.InnerException, operation);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for timeout exceptions
|
||||||
|
if (ex is TimeoutException || exMessage.Contains("timeout", StringComparison.OrdinalIgnoreCase))
|
||||||
|
{
|
||||||
|
return $"The request timed out while trying to {operation}. Please try again.";
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for connection-related messages in the exception chain
|
||||||
|
var innerEx = ex.InnerException;
|
||||||
|
while (innerEx != null)
|
||||||
|
{
|
||||||
|
var innerTypeName = innerEx.GetType().Name;
|
||||||
|
if (innerTypeName.Contains("NpgsqlException", StringComparison.OrdinalIgnoreCase))
|
||||||
|
{
|
||||||
|
return GetUserFriendlyErrorMessage(innerEx, operation);
|
||||||
|
}
|
||||||
|
var innerMessage = innerEx.Message ?? string.Empty;
|
||||||
|
if (innerMessage.Contains("connection", StringComparison.OrdinalIgnoreCase) ||
|
||||||
|
innerMessage.Contains("timeout", StringComparison.OrdinalIgnoreCase))
|
||||||
|
{
|
||||||
|
return $"Unable to connect to the database to {operation}. Please try again in a moment.";
|
||||||
|
}
|
||||||
|
innerEx = innerEx.InnerException;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generic error message
|
||||||
|
return $"Failed to {operation}. {exMessage}";
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user