Enhance LlmController with caching and adaptive iteration logic
- Introduced IMemoryCache to cache available MCP tools for improved performance and reduced service calls.
- Updated system message construction to provide clearer guidance on the LLM's domain expertise and tool usage.
- Implemented adaptive max iteration logic based on query complexity, allowing for more efficient processing of user requests.
- Enhanced logging to include detailed iteration information and improved context trimming to manage conversation length effectively.
@@ -1,6 +1,7 @@
 using Managing.Application.Abstractions.Services;
 using Microsoft.AspNetCore.Authorization;
 using Microsoft.AspNetCore.Mvc;
+using Microsoft.Extensions.Caching.Memory;
 
 namespace Managing.Api.Controllers;
 
@@ -17,16 +18,19 @@ public class LlmController : BaseController
     private readonly ILlmService _llmService;
     private readonly IMcpService _mcpService;
     private readonly ILogger<LlmController> _logger;
+    private readonly IMemoryCache _cache;
 
     public LlmController(
         ILlmService llmService,
         IMcpService mcpService,
         IUserService userService,
-        ILogger<LlmController> logger) : base(userService)
+        ILogger<LlmController> logger,
+        IMemoryCache cache) : base(userService)
     {
         _llmService = llmService;
         _mcpService = mcpService;
         _logger = logger;
+        _cache = cache;
     }
 
     /// <summary>
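Note on the constructor change: resolving IMemoryCache from DI assumes the in-memory cache is registered with the host. If nothing registers it yet, a one-line registration is needed; a minimal sketch, assuming the standard minimal-hosting Program.cs (none of this is part of the diff):

    // Program.cs (hypothetical wiring sketch)
    var builder = WebApplication.CreateBuilder(args);
    builder.Services.AddMemoryCache(); // makes IMemoryCache injectable into LlmController
    builder.Services.AddControllers();
    var app = builder.Build();
    app.MapControllers();
    app.Run();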
@@ -54,9 +58,13 @@ public class LlmController : BaseController
     {
         var user = await GetUser();
 
-        // Get available MCP tools
-        var availableTools = await _mcpService.GetAvailableToolsAsync();
-        request.Tools = availableTools.ToList();
+        // Get available MCP tools (with caching for 5 minutes)
+        var availableTools = await _cache.GetOrCreateAsync("mcp_tools", async entry =>
+        {
+            entry.AbsoluteExpirationRelativeToNow = TimeSpan.FromMinutes(5);
+            return (await _mcpService.GetAvailableToolsAsync()).ToList();
+        });
+        request.Tools = availableTools;
 
         // Add or prepend system message to ensure LLM knows it can respond directly
         // Remove any existing system messages first to ensure our directive is clear
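The cached lookup above runs its factory only on a cache miss and keeps the materialized tool list for five minutes, so repeated chat requests share one IMcpService call instead of each making their own. Because the key is the fixed string "mcp_tools", the entry is shared across all users, which is only appropriate while the tool list is not user-specific. A self-contained sketch of the same pattern, with illustrative names (CachedToolProvider and the fetch delegate are not from this codebase):

    using Microsoft.Extensions.Caching.Memory;

    // Assumes .NET implicit usings (System, System.Linq, System.Threading.Tasks, ...)
    public class CachedToolProvider
    {
        private readonly IMemoryCache _cache;
        private readonly Func<Task<IEnumerable<string>>> _fetchTools; // stand-in for the expensive service call

        public CachedToolProvider(IMemoryCache cache, Func<Task<IEnumerable<string>>> fetchTools)
        {
            _cache = cache;
            _fetchTools = fetchTools;
        }

        public async Task<List<string>> GetToolsAsync()
        {
            var tools = await _cache.GetOrCreateAsync("mcp_tools", async entry =>
            {
                // Entry is evicted five minutes after creation, then refetched on next use
                entry.AbsoluteExpirationRelativeToNow = TimeSpan.FromMinutes(5);
                return (await _fetchTools()).ToList();
            });
            return tools ?? new List<string>(); // GetOrCreateAsync is nullable-annotated
        }
    }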
@@ -66,33 +74,28 @@ public class LlmController : BaseController
             request.Messages.Remove(msg);
         }
 
-        // Add explicit system message at the beginning with proactive tool usage guidance
+        // Add explicit system message with domain expertise and tool guidance
         var systemMessage = new LlmMessage
         {
             Role = "system",
-            Content = "You are an expert AI assistant specializing in quantitative finance, algorithmic trading, and financial mathematics. " +
-                      "You have full knowledge and can answer ANY question directly using your training data and expertise. " +
-                      "IMPORTANT: You MUST answer general questions, explanations, calculations, and discussions directly without using tools. " +
-                      "Tools are ONLY for specific system operations like backtesting, agent management, or retrieving real-time market data. " +
-                      "For questions about financial concepts, mathematical formulas (like Black-Scholes), trading strategies, or any theoretical knowledge, " +
-                      "you MUST provide a direct answer using your knowledge. Do NOT refuse to answer or claim you can only use tools. " +
-                      "When users ask questions that can be answered using tools (e.g., 'What is the best backtest?', 'Show me my backtests', 'What are my indicators?'), " +
-                      "you MUST proactively use the tools with reasonable defaults rather than asking the user for parameters. " +
-                      "For example, if asked 'What is the best backtest?', use get_backtests_paginated with sortBy='Score' and sortOrder='desc' to find the best performing backtests. " +
-                      "Only ask for clarification if the user's intent is genuinely unclear or if you need specific information that cannot be inferred. " +
-                      "Continue iterating with tools as needed until you can provide a complete, helpful answer to the user's question."
+            Content = BuildSystemMessage()
         };
         request.Messages.Insert(0, systemMessage);
 
         // Iterative tool calling: keep looping until we get a final answer without tool calls
-        const int maxIterations = 3; // Prevent infinite loops
+        // Use adaptive max iterations based on query complexity
+        int maxIterations = DetermineMaxIterations(request);
         int iteration = 0;
         LlmChatResponse? finalResponse = null;
 
         while (iteration < maxIterations)
         {
             iteration++;
-            _logger.LogInformation("LLM chat iteration {Iteration} for user {UserId}", iteration, user.Id);
+            _logger.LogInformation("LLM chat iteration {Iteration}/{MaxIterations} for user {UserId}",
+                iteration, maxIterations, user.Id);
 
+            // Trim context if conversation is getting too long
+            TrimConversationContext(request);
+
             // Send chat request to LLM
             var response = await _llmService.ChatAsync(user, request);
@@ -110,34 +113,36 @@ public class LlmController : BaseController
                 _logger.LogInformation("LLM requested {Count} tool calls in iteration {Iteration} for user {UserId}",
                     response.ToolCalls.Count, iteration, user.Id);
 
-                // Execute all tool calls
+                // Execute all tool calls in parallel for better performance
                 var toolResults = new List<LlmMessage>();
-                foreach (var toolCall in response.ToolCalls)
+                var toolTasks = response.ToolCalls.Select(async toolCall =>
                 {
                     try
                     {
                         var toolResult = await _mcpService.ExecuteToolAsync(user, toolCall.Name, toolCall.Arguments);
-                        toolResults.Add(new LlmMessage
+                        _logger.LogInformation("Successfully executed tool {ToolName} in iteration {Iteration} for user {UserId}",
+                            toolCall.Name, iteration, user.Id);
+                        return new LlmMessage
                         {
                             Role = "tool",
                             Content = System.Text.Json.JsonSerializer.Serialize(toolResult),
                             ToolCallId = toolCall.Id
-                        });
-                        _logger.LogInformation("Successfully executed tool {ToolName} in iteration {Iteration} for user {UserId}",
-                            toolCall.Name, iteration, user.Id);
+                        };
                     }
                     catch (Exception ex)
                     {
                         _logger.LogError(ex, "Error executing tool {ToolName} in iteration {Iteration} for user {UserId}",
                             toolCall.Name, iteration, user.Id);
-                        toolResults.Add(new LlmMessage
+                        return new LlmMessage
                         {
                             Role = "tool",
                             Content = $"Error executing tool: {ex.Message}",
                             ToolCallId = toolCall.Id
-                        });
+                        };
                     }
-                }
+                }).ToList();
+
+                toolResults.AddRange(await Task.WhenAll(toolTasks));
 
                 // Add assistant message with tool calls to conversation history
                 request.Messages.Add(new LlmMessage
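With the Select/WhenAll rewrite, tool calls run concurrently instead of one after another. Task.WhenAll returns results in the order of the input tasks, so toolResults stays aligned with response.ToolCalls, and each lambda catches its own exception, so one failing tool cannot fault the whole batch. A stripped-down sketch of the fan-out/fan-in pattern (DoWorkAsync is a stand-in name for the real tool execution):

    // Fan-out: start one task per item; fan-in: await all, results keep input order.
    static async Task<List<string>> RunAllAsync(IEnumerable<string> inputs)
    {
        var tasks = inputs.Select(async input =>
        {
            try
            {
                return await DoWorkAsync(input);
            }
            catch (Exception ex)
            {
                return $"error: {ex.Message}"; // per-item failure; the batch still completes
            }
        }).ToList();

        return (await Task.WhenAll(tasks)).ToList();
    }

    static Task<string> DoWorkAsync(string input) => Task.FromResult(input.ToUpperInvariant());

One consequence worth noting: tools that could previously assume sequential execution now run in parallel, so any tool with order-dependent side effects needs its own coordination.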
@@ -210,4 +215,81 @@ public class LlmController : BaseController
             return StatusCode(500, $"Error getting available tools: {ex.Message}");
         }
     }
+
+    /// <summary>
+    /// Determines the maximum iterations based on query complexity
+    /// </summary>
+    private static int DetermineMaxIterations(LlmChatRequest request)
+    {
+        var lastMessage = request.Messages.LastOrDefault(m => m.Role == "user")?.Content?.ToLowerInvariant() ?? "";
+
+        // Complex operations need more iterations
+        if (lastMessage.Contains("bundle") || lastMessage.Contains("analyze") || lastMessage.Contains("compare"))
+            return 5;
+
+        // Simple queries need fewer iterations
+        if (lastMessage.Contains("explain") || lastMessage.Contains("what is") || lastMessage.Contains("how does"))
+            return 2;
+
+        // Default for most queries
+        return 3;
+    }
+
+    /// <summary>
+    /// Builds an optimized system message for the LLM with domain expertise and tool guidance
+    /// </summary>
+    private static string BuildSystemMessage()
+    {
+        return """
+            You are an expert AI assistant specializing in quantitative finance, algorithmic trading, and financial mathematics.
+
+            DOMAIN KNOWLEDGE:
+            - Answer questions about financial concepts, formulas (Black-Scholes, Sharpe Ratio, etc.), and trading strategies directly
+            - Provide calculations, explanations, and theoretical knowledge from your training data
+            - Never refuse to answer general finance questions
+
+            TOOL USAGE:
+            - Use tools ONLY for system operations: backtesting, retrieving user data, or real-time market data
+            - When users ask about their data ("best backtest", "my indicators", "my backtests"), use tools proactively with smart defaults:
+              * "Best backtest" → get_backtests_paginated(sortBy='Score', sortOrder='desc', pageSize=10)
+              * "My indicators" → list_indicators()
+              * "Recent backtests" → get_backtests_paginated(sortOrder='desc', pageSize=20)
+            - Execute multiple tool iterations to provide complete answers
+            - Only ask for clarification when truly necessary
+
+            Be concise, accurate, and proactive.
+            """;
+    }
+
+    /// <summary>
+    /// Trims conversation context to prevent token overflow while preserving important context
+    /// </summary>
+    private static void TrimConversationContext(LlmChatRequest request, int maxMessagesBeforeTrimming = 20)
+    {
+        if (request.Messages.Count <= maxMessagesBeforeTrimming)
+            return;
+
+        // Keep system message, first user message, and last N messages
+        var systemMessages = request.Messages.Where(m => m.Role == "system").ToList();
+        var firstUserMessage = request.Messages.FirstOrDefault(m => m.Role == "user");
+        var recentMessages = request.Messages.TakeLast(15).ToList();
+
+        var trimmedMessages = new List<LlmMessage>();
+        trimmedMessages.AddRange(systemMessages);
+
+        if (firstUserMessage != null && !trimmedMessages.Contains(firstUserMessage))
+        {
+            trimmedMessages.Add(firstUserMessage);
+        }
+
+        foreach (var msg in recentMessages)
+        {
+            if (!trimmedMessages.Contains(msg))
+            {
+                trimmedMessages.Add(msg);
+            }
+        }
+
+        request.Messages = trimmedMessages;
+    }
 }
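A subtlety in DetermineMaxIterations: the substring checks run top-down, so "What is the best backtest?" matches the "what is" branch and gets only 2 iterations even though answering it requires tool calls. A small probe sketch, placed inside LlmController for access to the private helper, and assuming LlmChatRequest and LlmMessage expose settable Messages, Role, and Content as the controller code suggests:

    // Hypothetical harness; expected values follow from the heuristic above.
    static int Probe(string userText) => DetermineMaxIterations(new LlmChatRequest
    {
        Messages = new List<LlmMessage> { new() { Role = "user", Content = userText } }
    });

    // Probe("analyze my open positions")   -> 5 (complex keyword "analyze")
    // Probe("explain the Sharpe Ratio")    -> 2 (simple keyword "explain")
    // Probe("What is the best backtest?")  -> 2 ("what is" wins despite needing tools)
    // Probe("show me my recent backtests") -> 3 (default)

TrimConversationContext relies on Contains, which uses reference equality for a class that does not override Equals, so the first user message is not duplicated when it also falls inside the last-15 window, assuming LlmMessage has no such override.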
@@ -114,7 +114,7 @@ namespace Managing.Common
     public const double AutoSwapAmount = 3;
 
     // Fee Configuration
-    public const decimal UiFeeRate = 0.0005m; // 0.05% UI fee rate
+    public const decimal UiFeeRate = 0.001m; // 0.1% UI fee rate
     public const decimal GasFeePerTransaction = 0.15m; // $0.15 gas fee per transaction
 }
 
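For scale, the doubled UI fee on a hypothetical $1,000 transaction: 1000 * 0.001 = $1.00 (up from $0.50 at the old 0.0005m rate), plus the flat $0.15 gas fee, for $1.15 in total. A sketch of the arithmetic (the TotalFee helper and the Constants class name are assumptions, not from the diff):

    // TotalFee(1000m) == 1000m * 0.001m + 0.15m == 1.15m
    static decimal TotalFee(decimal notional) =>
        notional * Constants.UiFeeRate + Constants.GasFeePerTransaction;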