Enhance LlmController with caching and adaptive iteration logic

- Introduced IMemoryCache to cache available MCP tools for improved performance and reduced service calls.
- Updated system message construction to provide clearer guidance on LLM's domain expertise and tool usage.
- Implemented adaptive max iteration logic based on query complexity, allowing for more efficient processing of user requests.
- Enhanced logging to include detailed iteration information and improved context trimming to manage conversation length effectively.
This commit is contained in:
2026-01-04 23:49:50 +07:00
parent 073111ddea
commit c78aedfee5
2 changed files with 110 additions and 28 deletions

View File

@@ -1,6 +1,7 @@
using Managing.Application.Abstractions.Services; using Managing.Application.Abstractions.Services;
using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc;
using Microsoft.Extensions.Caching.Memory;
namespace Managing.Api.Controllers; namespace Managing.Api.Controllers;
@@ -17,16 +18,19 @@ public class LlmController : BaseController
private readonly ILlmService _llmService; private readonly ILlmService _llmService;
private readonly IMcpService _mcpService; private readonly IMcpService _mcpService;
private readonly ILogger<LlmController> _logger; private readonly ILogger<LlmController> _logger;
private readonly IMemoryCache _cache;
public LlmController( public LlmController(
ILlmService llmService, ILlmService llmService,
IMcpService mcpService, IMcpService mcpService,
IUserService userService, IUserService userService,
ILogger<LlmController> logger) : base(userService) ILogger<LlmController> logger,
IMemoryCache cache) : base(userService)
{ {
_llmService = llmService; _llmService = llmService;
_mcpService = mcpService; _mcpService = mcpService;
_logger = logger; _logger = logger;
_cache = cache;
} }
/// <summary> /// <summary>
@@ -54,9 +58,13 @@ public class LlmController : BaseController
{ {
var user = await GetUser(); var user = await GetUser();
// Get available MCP tools // Get available MCP tools (with caching for 5 minutes)
var availableTools = await _mcpService.GetAvailableToolsAsync(); var availableTools = await _cache.GetOrCreateAsync("mcp_tools", async entry =>
request.Tools = availableTools.ToList(); {
entry.AbsoluteExpirationRelativeToNow = TimeSpan.FromMinutes(5);
return (await _mcpService.GetAvailableToolsAsync()).ToList();
});
request.Tools = availableTools;
// Add or prepend system message to ensure LLM knows it can respond directly // Add or prepend system message to ensure LLM knows it can respond directly
// Remove any existing system messages first to ensure our directive is clear // Remove any existing system messages first to ensure our directive is clear
@@ -66,33 +74,28 @@ public class LlmController : BaseController
request.Messages.Remove(msg); request.Messages.Remove(msg);
} }
// Add explicit system message at the beginning with proactive tool usage guidance // Add explicit system message with domain expertise and tool guidance
var systemMessage = new LlmMessage var systemMessage = new LlmMessage
{ {
Role = "system", Role = "system",
Content = "You are an expert AI assistant specializing in quantitative finance, algorithmic trading, and financial mathematics. " + Content = BuildSystemMessage()
"You have full knowledge and can answer ANY question directly using your training data and expertise. " +
"IMPORTANT: You MUST answer general questions, explanations, calculations, and discussions directly without using tools. " +
"Tools are ONLY for specific system operations like backtesting, agent management, or retrieving real-time market data. " +
"For questions about financial concepts, mathematical formulas (like Black-Scholes), trading strategies, or any theoretical knowledge, " +
"you MUST provide a direct answer using your knowledge. Do NOT refuse to answer or claim you can only use tools. " +
"When users ask questions that can be answered using tools (e.g., 'What is the best backtest?', 'Show me my backtests', 'What are my indicators?'), " +
"you MUST proactively use the tools with reasonable defaults rather than asking the user for parameters. " +
"For example, if asked 'What is the best backtest?', use get_backtests_paginated with sortBy='Score' and sortOrder='desc' to find the best performing backtests. " +
"Only ask for clarification if the user's intent is genuinely unclear or if you need specific information that cannot be inferred. " +
"Continue iterating with tools as needed until you can provide a complete, helpful answer to the user's question."
}; };
request.Messages.Insert(0, systemMessage); request.Messages.Insert(0, systemMessage);
// Iterative tool calling: keep looping until we get a final answer without tool calls // Iterative tool calling: keep looping until we get a final answer without tool calls
const int maxIterations = 3; // Prevent infinite loops // Use adaptive max iterations based on query complexity
int maxIterations = DetermineMaxIterations(request);
int iteration = 0; int iteration = 0;
LlmChatResponse? finalResponse = null; LlmChatResponse? finalResponse = null;
while (iteration < maxIterations) while (iteration < maxIterations)
{ {
iteration++; iteration++;
_logger.LogInformation("LLM chat iteration {Iteration} for user {UserId}", iteration, user.Id); _logger.LogInformation("LLM chat iteration {Iteration}/{MaxIterations} for user {UserId}",
iteration, maxIterations, user.Id);
// Trim context if conversation is getting too long
TrimConversationContext(request);
// Send chat request to LLM // Send chat request to LLM
var response = await _llmService.ChatAsync(user, request); var response = await _llmService.ChatAsync(user, request);
@@ -110,34 +113,36 @@ public class LlmController : BaseController
_logger.LogInformation("LLM requested {Count} tool calls in iteration {Iteration} for user {UserId}", _logger.LogInformation("LLM requested {Count} tool calls in iteration {Iteration} for user {UserId}",
response.ToolCalls.Count, iteration, user.Id); response.ToolCalls.Count, iteration, user.Id);
// Execute all tool calls // Execute all tool calls in parallel for better performance
var toolResults = new List<LlmMessage>(); var toolResults = new List<LlmMessage>();
foreach (var toolCall in response.ToolCalls) var toolTasks = response.ToolCalls.Select(async toolCall =>
{ {
try try
{ {
var toolResult = await _mcpService.ExecuteToolAsync(user, toolCall.Name, toolCall.Arguments); var toolResult = await _mcpService.ExecuteToolAsync(user, toolCall.Name, toolCall.Arguments);
toolResults.Add(new LlmMessage _logger.LogInformation("Successfully executed tool {ToolName} in iteration {Iteration} for user {UserId}",
toolCall.Name, iteration, user.Id);
return new LlmMessage
{ {
Role = "tool", Role = "tool",
Content = System.Text.Json.JsonSerializer.Serialize(toolResult), Content = System.Text.Json.JsonSerializer.Serialize(toolResult),
ToolCallId = toolCall.Id ToolCallId = toolCall.Id
}); };
_logger.LogInformation("Successfully executed tool {ToolName} in iteration {Iteration} for user {UserId}",
toolCall.Name, iteration, user.Id);
} }
catch (Exception ex) catch (Exception ex)
{ {
_logger.LogError(ex, "Error executing tool {ToolName} in iteration {Iteration} for user {UserId}", _logger.LogError(ex, "Error executing tool {ToolName} in iteration {Iteration} for user {UserId}",
toolCall.Name, iteration, user.Id); toolCall.Name, iteration, user.Id);
toolResults.Add(new LlmMessage return new LlmMessage
{ {
Role = "tool", Role = "tool",
Content = $"Error executing tool: {ex.Message}", Content = $"Error executing tool: {ex.Message}",
ToolCallId = toolCall.Id ToolCallId = toolCall.Id
}); };
}
} }
}).ToList();
toolResults.AddRange(await Task.WhenAll(toolTasks));
// Add assistant message with tool calls to conversation history // Add assistant message with tool calls to conversation history
request.Messages.Add(new LlmMessage request.Messages.Add(new LlmMessage
@@ -210,4 +215,81 @@ public class LlmController : BaseController
return StatusCode(500, $"Error getting available tools: {ex.Message}"); return StatusCode(500, $"Error getting available tools: {ex.Message}");
} }
} }
/// <summary>
/// Determines the maximum iterations based on query complexity
/// </summary>
private static int DetermineMaxIterations(LlmChatRequest request)
{
var lastMessage = request.Messages.LastOrDefault(m => m.Role == "user")?.Content?.ToLowerInvariant() ?? "";
// Complex operations need more iterations
if (lastMessage.Contains("bundle") || lastMessage.Contains("analyze") || lastMessage.Contains("compare"))
return 5;
// Simple queries need fewer iterations
if (lastMessage.Contains("explain") || lastMessage.Contains("what is") || lastMessage.Contains("how does"))
return 2;
// Default for most queries
return 3;
}
/// <summary>
/// Builds an optimized system message for the LLM with domain expertise and tool guidance
/// </summary>
private static string BuildSystemMessage()
{
return """
You are an expert AI assistant specializing in quantitative finance, algorithmic trading, and financial mathematics.
DOMAIN KNOWLEDGE:
- Answer questions about financial concepts, formulas (Black-Scholes, Sharpe Ratio, etc.), and trading strategies directly
- Provide calculations, explanations, and theoretical knowledge from your training data
- Never refuse to answer general finance questions
TOOL USAGE:
- Use tools ONLY for system operations: backtesting, retrieving user data, or real-time market data
- When users ask about their data ("best backtest", "my indicators", "my backtests"), use tools proactively with smart defaults:
* "Best backtest" get_backtests_paginated(sortBy='Score', sortOrder='desc', pageSize=10)
* "My indicators" list_indicators()
* "Recent backtests" get_backtests_paginated(sortOrder='desc', pageSize=20)
- Execute multiple tool iterations to provide complete answers
- Only ask for clarification when truly necessary
Be concise, accurate, and proactive.
""";
}
/// <summary>
/// Trims conversation context to prevent token overflow while preserving important context
/// </summary>
private static void TrimConversationContext(LlmChatRequest request, int maxMessagesBeforeTrimming = 20)
{
if (request.Messages.Count <= maxMessagesBeforeTrimming)
return;
// Keep system message, first user message, and last N messages
var systemMessages = request.Messages.Where(m => m.Role == "system").ToList();
var firstUserMessage = request.Messages.FirstOrDefault(m => m.Role == "user");
var recentMessages = request.Messages.TakeLast(15).ToList();
var trimmedMessages = new List<LlmMessage>();
trimmedMessages.AddRange(systemMessages);
if (firstUserMessage != null && !trimmedMessages.Contains(firstUserMessage))
{
trimmedMessages.Add(firstUserMessage);
}
foreach (var msg in recentMessages)
{
if (!trimmedMessages.Contains(msg))
{
trimmedMessages.Add(msg);
}
}
request.Messages = trimmedMessages;
}
} }

View File

@@ -114,7 +114,7 @@ namespace Managing.Common
public const double AutoSwapAmount = 3; public const double AutoSwapAmount = 3;
// Fee Configuration // Fee Configuration
public const decimal UiFeeRate = 0.0005m; // 0.05% UI fee rate public const decimal UiFeeRate = 0.001m; // 0.1% UI fee rate
public const decimal GasFeePerTransaction = 0.15m; // $0.15 gas fee per transaction public const decimal GasFeePerTransaction = 0.15m; // $0.15 gas fee per transaction
} }