Enhance LlmController and AiChat component for improved progress updates and message handling

- Introduced a new method in LlmController to generate descriptive messages for tool execution results, improving clarity in progress updates.
- Updated AiChat component to display progress messages in chat history, enhancing user experience during tool execution.
- Refactored progress indicator styling for better visual feedback and readability.
- Adjusted backtest query handling in LlmController to optimize iteration counts based on query type, improving performance and user interaction.
- Enhanced documentation for backtest tools in BacktestMcpTools to clarify usage and parameters, ensuring better understanding for developers.
This commit is contained in:
2026-01-06 23:25:14 +07:00
parent b7b4f1d12f
commit 1b08655dfa
4 changed files with 250 additions and 57 deletions

View File

@@ -348,10 +348,11 @@ public class LlmController : BaseController
logger.LogInformation("Successfully executed tool {ToolName} in iteration {Iteration} for user {UserId}", logger.LogInformation("Successfully executed tool {ToolName} in iteration {Iteration} for user {UserId}",
toolCall.Name, iteration, user.Id); toolCall.Name, iteration, user.Id);
var resultMessage = GenerateToolResultMessage(toolCall.Name, result);
await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate
{ {
Type = "tool_result", Type = "tool_result",
Message = $"Tool {toolCall.Name} completed successfully", Message = resultMessage,
Iteration = iteration, Iteration = iteration,
MaxIterations = maxIterations, MaxIterations = maxIterations,
ToolName = toolCall.Name ToolName = toolCall.Name
@@ -678,13 +679,17 @@ public class LlmController : BaseController
if (lastMessage.Contains("bundle") || lastMessage.Contains("compare") || lastMessage.Contains("all backtests")) if (lastMessage.Contains("bundle") || lastMessage.Contains("compare") || lastMessage.Contains("all backtests"))
return 5; return 5;
// Backtest detail requests need 4 iterations (list → get_by_id → analyze → format) // Backtest detail requests with "analyze" or "detail" need more iterations for deep analysis
if (lastMessage.Contains("backtest") && if (lastMessage.Contains("backtest") &&
(lastMessage.Contains("detail") || lastMessage.Contains("analyze") || lastMessage.Contains("show") || (lastMessage.Contains("detail") || lastMessage.Contains("analyze") || lastMessage.Contains("position")))
lastMessage.Contains("this") || lastMessage.Contains("that") || lastMessage.Contains("best") ||
lastMessage.Contains("top") || lastMessage.Contains("recent")))
return 4; return 4;
// Simple backtest queries ("best", "top", "show") only need 2 iterations (fetch + respond)
if (lastMessage.Contains("backtest") &&
(lastMessage.Contains("best") || lastMessage.Contains("top") || lastMessage.Contains("show") ||
lastMessage.Contains("recent") || lastMessage.Contains("latest")))
return 2;
// General analysis queries // General analysis queries
if (lastMessage.Contains("analyze")) if (lastMessage.Contains("analyze"))
return 4; return 4;
@@ -713,10 +718,12 @@ public class LlmController : BaseController
TOOL USAGE: TOOL USAGE:
- Use tools ONLY for system operations: backtesting, retrieving user data, or real-time market data - Use tools ONLY for system operations: backtesting, retrieving user data, or real-time market data
- When users ask about their data, use tools proactively with smart defaults: - When users ask about their data, use tools proactively with smart defaults:
* "Best backtest" get_backtests_paginated(sortBy='Score', sortOrder='desc', pageSize=10) * "Best backtest" get_backtests_paginated(sortBy='Score', sortOrder='desc', pageSize=1)
* "Top 5 backtests" get_backtests_paginated(sortBy='Score', sortOrder='desc', pageSize=5)
* "My indicators" list_indicators() * "My indicators" list_indicators()
* "Recent backtests" get_backtests_paginated(sortOrder='desc', pageSize=20) * "Recent backtests" get_backtests_paginated(sortBy='StartDate', sortOrder='desc', pageSize=10)
* "Bundle backtest analysis" analyze_bundle_backtest(bundleRequestId='X') * "Bundle backtest analysis" analyze_bundle_backtest(bundleRequestId='X')
- IMPORTANT: get_backtests_paginated returns summary data. Only call get_backtest_by_id if user explicitly asks for position details or deeper analysis.
ERROR HANDLING: ERROR HANDLING:
- If a tool returns a database connection error, wait a moment and retry once (these are often transient) - If a tool returns a database connection error, wait a moment and retry once (these are often transient)
@@ -739,12 +746,14 @@ public class LlmController : BaseController
- If multiple backtests were listed, use the most recently mentioned one or the top-ranked one - If multiple backtests were listed, use the most recently mentioned one or the top-ranked one
- NEVER ask user for IDs/names that were already provided in conversation - NEVER ask user for IDs/names that were already provided in conversation
3. BACKTEST DETAIL WORKFLOW: 3. BACKTEST DETAIL WORKFLOW (TOKEN-OPTIMIZED):
When user requests backtest details/analysis: When user requests backtest information:
a) If backtest ID is in conversation IMMEDIATELY call get_backtest_by_id(id) a) If backtest ID is in conversation AND user asks for positions/details call get_backtest_by_id(id)
b) If no ID but refers to "best/top" call get_backtests_paginated(sortBy='Score', sortOrder='desc', pageSize=1) THEN get_backtest_by_id() b) If no ID but refers to "best/top N" call get_backtests_paginated(sortBy='Score', sortOrder='desc', pageSize=N)
c) If no ID but refers to "recent/latest" call get_backtests_paginated(sortOrder='desc', pageSize=1) THEN get_backtest_by_id() c) If no ID but refers to "recent/latest N" call get_backtests_paginated(sortBy='StartDate', sortOrder='desc', pageSize=N)
d) If completely ambiguous ask ONCE for clarification, then proceed d) For simple queries like "show my best backtest" get_backtests_paginated is sufficient (includes key metrics)
e) Only call get_backtest_by_id for DETAILED analysis when user explicitly needs position-level data
f) If completely ambiguous ask ONCE for clarification, then proceed
4. INDICATOR WORKFLOW: 4. INDICATOR WORKFLOW:
When user asks about indicators: When user asks about indicators:
@@ -1058,6 +1067,154 @@ public class LlmController : BaseController
return null; return null;
} }
/// <summary>
/// Generates a descriptive message for tool execution results
/// </summary>
private static string GenerateToolResultMessage(string toolName, object result)
{
try
{
// Try to parse result as JSON to extract meaningful information
var jsonResult = result as JsonElement? ?? JsonSerializer.Deserialize<JsonElement>(JsonSerializer.Serialize(result));
switch (toolName.ToLowerInvariant())
{
case "get_backtests_paginated":
if (jsonResult.TryGetProperty("items", out var backtestItems) && backtestItems.ValueKind == JsonValueKind.Array)
{
var count = backtestItems.GetArrayLength();
var totalCount = jsonResult.TryGetProperty("totalCount", out var total) ? total.GetInt32() : count;
return $"Retrieved {count} backtest(s) out of {totalCount} total";
}
break;
case "get_backtest_by_id":
if (jsonResult.TryGetProperty("name", out var backtestName))
{
var name = backtestName.GetString() ?? "Unknown";
var score = jsonResult.TryGetProperty("score", out var scoreVal) ? $" (Score: {scoreVal.GetDouble():F1})" : "";
return $"Retrieved backtest '{name}'{score}";
}
return "Retrieved backtest details";
case "get_bundle_backtests_paginated":
if (jsonResult.TryGetProperty("items", out var bundleItems) && bundleItems.ValueKind == JsonValueKind.Array)
{
var count = bundleItems.GetArrayLength();
var totalCount = jsonResult.TryGetProperty("totalCount", out var total) ? total.GetInt32() : count;
return $"Retrieved {count} bundle backtest(s) out of {totalCount} total";
}
break;
case "get_bundle_backtest_by_id":
if (jsonResult.TryGetProperty("botName", out var bundleName))
{
var name = bundleName.GetString() ?? "Unknown";
return $"Retrieved bundle backtest '{name}'";
}
return "Retrieved bundle backtest details";
case "analyze_bundle_backtest":
if (jsonResult.TryGetProperty("totalBacktests", out var totalBacktests))
{
var count = totalBacktests.GetInt32();
var avgScore = jsonResult.TryGetProperty("averageScore", out var score) ? $", Avg Score: {score.GetDouble():F1}" : "";
return $"Analyzed {count} backtest(s) in bundle{avgScore}";
}
return "Completed bundle backtest analysis";
case "list_indicators":
if (jsonResult.ValueKind == JsonValueKind.Array)
{
var count = jsonResult.GetArrayLength();
return $"Retrieved {count} indicator(s)";
}
break;
case "get_indicator_info":
if (jsonResult.TryGetProperty("type", out var indicatorType))
{
var type = indicatorType.GetString() ?? "Unknown";
return $"Retrieved info for indicator '{type}'";
}
return "Retrieved indicator information";
case "get_tickers":
if (jsonResult.ValueKind == JsonValueKind.Array)
{
var count = jsonResult.GetArrayLength();
return $"Retrieved {count} ticker(s)";
}
break;
case "get_candles":
if (jsonResult.ValueKind == JsonValueKind.Array)
{
var count = jsonResult.GetArrayLength();
return $"Retrieved {count} candle(s)";
}
break;
case "get_agents_paginated":
if (jsonResult.TryGetProperty("items", out var agentItems) && agentItems.ValueKind == JsonValueKind.Array)
{
var count = agentItems.GetArrayLength();
var totalCount = jsonResult.TryGetProperty("totalCount", out var total) ? total.GetInt32() : count;
return $"Retrieved {count} agent(s) out of {totalCount} total";
}
break;
case "get_online_agents":
if (jsonResult.ValueKind == JsonValueKind.Array)
{
var count = jsonResult.GetArrayLength();
return $"Found {count} online agent(s)";
}
break;
case "run_backtest":
if (jsonResult.TryGetProperty("backtestId", out var btId))
{
return $"Started backtest (ID: {btId.GetString()})";
}
return "Started backtest execution";
case "run_bundle_backtest":
if (jsonResult.TryGetProperty("bundleRequestId", out var bundleId))
{
return $"Started bundle backtest (ID: {bundleId.GetString()})";
}
return "Started bundle backtest execution";
case "delete_backtest":
case "delete_bundle_backtest":
return $"Successfully deleted {toolName.Replace("delete_", "").Replace("_", " ")}";
case "delete_backtests_by_ids":
if (jsonResult.TryGetProperty("deletedCount", out var deletedCount))
{
return $"Deleted {deletedCount.GetInt32()} backtest(s)";
}
return "Deleted backtests";
case "delete_backtests_by_filters":
if (jsonResult.TryGetProperty("deletedCount", out var filteredDeletedCount))
{
return $"Deleted {filteredDeletedCount.GetInt32()} backtest(s) matching filters";
}
return "Deleted backtests matching filters";
}
// Default message if no specific handler
return $"Tool {toolName} completed successfully";
}
catch (Exception)
{
// If parsing fails, return generic message
return $"Tool {toolName} completed successfully";
}
}
/// <summary> /// <summary>
/// Helper method to send progress update via SignalR /// Helper method to send progress update via SignalR
/// </summary> /// </summary>

View File

@@ -24,7 +24,7 @@ public class BacktestMcpTools : BaseMcpTool
new McpToolDefinition new McpToolDefinition
{ {
Name = "get_backtests_paginated", Name = "get_backtests_paginated",
Description = "Retrieves paginated backtests with filtering and sorting capabilities. Supports filters for score, winrate, drawdown, tickers, indicators, duration, and trading type.", Description = "Retrieves paginated backtests with filtering and sorting capabilities. This is the PRIMARY tool for finding backtests. Use this when user asks for 'best backtest' (pageSize=1), 'top 5 backtests' (pageSize=5), or browsing backtests (default pageSize=10). Only use get_backtest_by_id when you already have a specific backtest ID. Supports filters for score, winrate, drawdown, tickers, indicators, duration, and trading type.",
Parameters = new Dictionary<string, McpParameterDefinition> Parameters = new Dictionary<string, McpParameterDefinition>
{ {
["page"] = new McpParameterDefinition ["page"] = new McpParameterDefinition
@@ -37,9 +37,9 @@ public class BacktestMcpTools : BaseMcpTool
["pageSize"] = new McpParameterDefinition ["pageSize"] = new McpParameterDefinition
{ {
Type = "integer", Type = "integer",
Description = "Number of items per page (defaults to 50, max 100)", Description = "Number of items per page (defaults to 10, max 100). Use 1 for 'best/worst', 5 for 'top 5', 10 for general browsing.",
Required = false, Required = false,
DefaultValue = 50 DefaultValue = 10
}, },
["sortBy"] = new McpParameterDefinition ["sortBy"] = new McpParameterDefinition
{ {
@@ -126,7 +126,7 @@ public class BacktestMcpTools : BaseMcpTool
new McpToolDefinition new McpToolDefinition
{ {
Name = "get_backtest_by_id", Name = "get_backtest_by_id",
Description = "Retrieves a specific backtest by its ID for the authenticated user.", Description = "Retrieves a specific backtest by its ID with full position details. ONLY use this when you have a specific backtest ID (from previous conversation or from get_backtests_paginated results). DO NOT use this for queries like 'best backtest' or 'top backtests' - use get_backtests_paginated instead.",
Parameters = new Dictionary<string, McpParameterDefinition> Parameters = new Dictionary<string, McpParameterDefinition>
{ {
["id"] = new McpParameterDefinition ["id"] = new McpParameterDefinition
@@ -206,7 +206,7 @@ public class BacktestMcpTools : BaseMcpTool
new McpToolDefinition new McpToolDefinition
{ {
Name = "get_bundle_backtests_paginated", Name = "get_bundle_backtests_paginated",
Description = "Retrieves paginated bundle backtest requests with filtering and sorting capabilities.", Description = "Retrieves paginated bundle backtest requests with filtering and sorting capabilities. Use appropriate page size for token efficiency: 1 for latest bundle, 5 for top 5, 10 for general browsing.",
Parameters = new Dictionary<string, McpParameterDefinition> Parameters = new Dictionary<string, McpParameterDefinition>
{ {
["page"] = new McpParameterDefinition ["page"] = new McpParameterDefinition
@@ -219,9 +219,9 @@ public class BacktestMcpTools : BaseMcpTool
["pageSize"] = new McpParameterDefinition ["pageSize"] = new McpParameterDefinition
{ {
Type = "integer", Type = "integer",
Description = "Number of items per page (defaults to 50, max 100)", Description = "Number of items per page (defaults to 10, max 100). Use 1 for latest, 5 for top 5, 10 for general browsing.",
Required = false, Required = false,
DefaultValue = 50 DefaultValue = 10
}, },
["sortBy"] = new McpParameterDefinition ["sortBy"] = new McpParameterDefinition
{ {
@@ -455,7 +455,7 @@ public class BacktestMcpTools : BaseMcpTool
public async Task<object> ExecuteGetBacktestsPaginated(User user, Dictionary<string, object>? parameters) public async Task<object> ExecuteGetBacktestsPaginated(User user, Dictionary<string, object>? parameters)
{ {
var page = GetParameterValue<int>(parameters, "page", 1); var page = GetParameterValue<int>(parameters, "page", 1);
var pageSize = GetParameterValue<int>(parameters, "pageSize", 50); var pageSize = GetParameterValue<int>(parameters, "pageSize", 10);
var sortByString = GetParameterValue<string>(parameters, "sortBy", "Score"); var sortByString = GetParameterValue<string>(parameters, "sortBy", "Score");
var sortOrder = GetParameterValue<string>(parameters, "sortOrder", "desc"); var sortOrder = GetParameterValue<string>(parameters, "sortOrder", "desc");
var scoreMin = GetParameterValue<double?>(parameters, "scoreMin", null); var scoreMin = GetParameterValue<double?>(parameters, "scoreMin", null);
@@ -551,7 +551,7 @@ public class BacktestMcpTools : BaseMcpTool
public async Task<object> ExecuteGetBundleBacktestsPaginated(User user, Dictionary<string, object>? parameters) public async Task<object> ExecuteGetBundleBacktestsPaginated(User user, Dictionary<string, object>? parameters)
{ {
var page = GetParameterValue<int>(parameters, "page", 1); var page = GetParameterValue<int>(parameters, "page", 1);
var pageSize = GetParameterValue<int>(parameters, "pageSize", 50); var pageSize = GetParameterValue<int>(parameters, "pageSize", 10);
var sortByString = GetParameterValue<string>(parameters, "sortBy", "CreatedAt"); var sortByString = GetParameterValue<string>(parameters, "sortBy", "CreatedAt");
var sortOrder = GetParameterValue<string>(parameters, "sortOrder", "desc"); var sortOrder = GetParameterValue<string>(parameters, "sortOrder", "desc");
var statusString = GetParameterValue<string?>(parameters, "status", null); var statusString = GetParameterValue<string?>(parameters, "status", null);

View File

@@ -1,7 +1,7 @@
import { useState, useRef, useEffect } from 'react' import { useState, useRef, useEffect } from 'react'
import { LlmClient } from '../../generated/ManagingApi' import { LlmClient } from '../../generated/ManagingApi'
import { LlmMessage, LlmChatResponse, LlmProgressUpdate } from '../../generated/ManagingApiTypes' import { LlmMessage, LlmChatResponse } from '../../generated/ManagingApiTypes'
import { AiChatService } from '../../services/aiChatService' import { AiChatService, LlmProgressUpdate } from '../../services/aiChatService'
import useApiUrlStore from '../../app/store/apiStore' import useApiUrlStore from '../../app/store/apiStore'
interface Message { interface Message {
@@ -35,7 +35,7 @@ function AiChat({ onClose }: AiChatProps): JSX.Element {
const [historyIndex, setHistoryIndex] = useState<number>(-1) const [historyIndex, setHistoryIndex] = useState<number>(-1)
const [tempInput, setTempInput] = useState<string>('') const [tempInput, setTempInput] = useState<string>('')
const messagesEndRef = useRef<HTMLDivElement>(null) const messagesEndRef = useRef<HTMLDivElement>(null)
const { apiUrl, userToken } = useApiUrlStore() const { apiUrl } = useApiUrlStore()
useEffect(() => { useEffect(() => {
scrollToBottom() scrollToBottom()
@@ -112,6 +112,20 @@ function AiChat({ onClose }: AiChatProps): JSX.Element {
lastUpdate = update lastUpdate = update
setCurrentProgress(update) setCurrentProgress(update)
// Add progress messages to chat history (except final_response)
if (update.type !== 'final_response') {
const progressMessage: Message = {
role: 'progress',
content: update.message || '',
timestamp: new Date(),
progressType: update.type,
iteration: update.iteration,
maxIterations: update.maxIterations,
toolName: update.toolName
}
setMessages(prev => [...prev, progressMessage])
}
// Handle different update types // Handle different update types
if (update.type === 'error') { if (update.type === 'error') {
const errorMessage: Message = { const errorMessage: Message = {
@@ -270,8 +284,27 @@ function AiChat({ onClose }: AiChatProps): JSX.Element {
{messages.filter(m => m.role !== 'system').map((message, index) => ( {messages.filter(m => m.role !== 'system').map((message, index) => (
<div <div
key={index} key={index}
className={`flex ${message.role === 'user' ? 'justify-end' : 'justify-start'}`} className={`flex ${
message.role === 'user'
? 'justify-end'
: message.role === 'progress'
? 'justify-start pl-4'
: 'justify-start'
}`}
> >
{message.role === 'progress' ? (
<div className="max-w-[80%] p-2 rounded-lg bg-base-200/40 border-l-2 border-info/40">
<ProgressIndicator
progress={{
type: message.progressType || 'info',
message: message.content,
iteration: message.iteration,
maxIterations: message.maxIterations,
toolName: message.toolName
} as LlmProgressUpdate}
/>
</div>
) : (
<div <div
className={`max-w-[80%] p-3 rounded-lg ${ className={`max-w-[80%] p-3 rounded-lg ${
message.role === 'user' message.role === 'user'
@@ -284,6 +317,7 @@ function AiChat({ onClose }: AiChatProps): JSX.Element {
{message.timestamp.toLocaleTimeString()} {message.timestamp.toLocaleTimeString()}
</p> </p>
</div> </div>
)}
</div> </div>
))} ))}
@@ -385,31 +419,20 @@ function ProgressIndicator({ progress }: { progress: LlmProgressUpdate }): JSX.E
} }
return ( return (
<div className="space-y-2 opacity-80"> <div className="space-y-2">
<div className="flex items-center gap-2"> <div className="flex items-center gap-2">
<span className="text-lg opacity-70">{getIcon()}</span> <span className="text-base">{getIcon()}</span>
<span className={`text-sm font-normal ${getColor()} opacity-75`}> <span className={`text-xs font-normal ${getColor()}`}>
{progress.message} {progress.message}
</span> </span>
</div> </div>
{progress.iteration && progress.maxIterations && (
<div className="flex items-center gap-2 text-xs text-base-content/40">
<progress
className="progress progress-primary w-32 h-1.5 opacity-60"
value={progress.iteration}
max={progress.maxIterations}
/>
<span className="opacity-60">Iteration {progress.iteration}/{progress.maxIterations}</span>
</div>
)}
{progress.toolName && ( {progress.toolName && (
<div className="text-xs text-base-content/40 mt-1"> <div className="text-xs text-base-content/60 mt-1">
<span className="font-mono bg-base-300/50 px-2 py-1 rounded opacity-70"> <span className="font-mono bg-base-300/50 px-1.5 py-0.5 rounded text-xs">
{progress.toolName} {progress.toolName}
{progress.toolArguments && Object.keys(progress.toolArguments).length > 0 && ( {progress.toolArguments && Object.keys(progress.toolArguments).length > 0 && (
<span className="ml-1 opacity-50"> <span className="ml-1 opacity-70">
({Object.keys(progress.toolArguments).length} args) ({Object.keys(progress.toolArguments).length} args)
</span> </span>
)} )}
@@ -418,7 +441,7 @@ function ProgressIndicator({ progress }: { progress: LlmProgressUpdate }): JSX.E
)} )}
{progress.error && ( {progress.error && (
<div className="text-xs text-error mt-1 opacity-80"> <div className="text-xs text-error mt-1">
{progress.error} {progress.error}
</div> </div>
)} )}

View File

@@ -1,8 +1,21 @@
import { HubConnection, HubConnectionBuilder } from '@microsoft/signalr' import { HubConnection, HubConnectionBuilder } from '@microsoft/signalr'
import { LlmClient } from '../generated/ManagingApi' import { LlmClient } from '../generated/ManagingApi'
import { LlmChatRequest, LlmChatResponse, LlmMessage, LlmProgressUpdate } from '../generated/ManagingApiTypes' import { LlmChatRequest, LlmChatResponse, LlmMessage } from '../generated/ManagingApiTypes'
import { Cookies } from 'react-cookie' import { Cookies } from 'react-cookie'
export interface LlmProgressUpdate {
type: string
message: string
iteration?: number
maxIterations?: number
toolName?: string
toolArguments?: Record<string, any>
content?: string
response?: LlmChatResponse
error?: string
timestamp?: Date
}
export class AiChatService { export class AiChatService {
private llmClient: LlmClient private llmClient: LlmClient
private baseUrl: string private baseUrl: string