Refactor LlmController and AiChatService for SSE integration and Redis support

- Updated LlmController to implement a new SSE endpoint for streaming LLM progress updates, utilizing Redis pub/sub for real-time communication.
- Removed SignalR dependencies from AiChatService, replacing them with SSE logic for message streaming.
- Enhanced error handling and logging for Redis interactions, ensuring robust feedback during streaming operations.
- Adjusted request models and methods to accommodate the new streaming architecture, improving clarity and maintainability.
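For orientation, a minimal sketch of the client flow this change enables, based on the routes and message shapes visible in the diff below (the /Llm/stream/{streamId} SSE endpoint, the /Llm/ChatStream POST endpoint, access_token query auth, and the llm-stream:{streamId} Redis channel). Function and variable names in the sketch are illustrative, not part of the commit:

// Sketch only: open the SSE stream first, then kick off processing with the same streamId.
async function chatWithProgress(baseUrl: string, token: string, messages: unknown[]) {
  const streamId = `stream-${Date.now()}-${Math.random().toString(36).slice(2, 11)}`

  // 1. Subscribe to progress updates; the backend relays Redis messages from llm-stream:{streamId}.
  const source = new EventSource(`${baseUrl}/Llm/stream/${streamId}?access_token=${encodeURIComponent(token)}`)
  source.onmessage = (event) => {
    const update = JSON.parse(event.data)
    const type = update.Type ?? update.type // backend may serialize PascalCase; the connected event is camelCase
    if (type === 'connected') return // connection confirmation only
    console.log(type, update.Message ?? update.message)
    if (type === 'final_response' || type === 'error') source.close()
  }

  // 2. Start the chat; progress is published to Redis and streamed back over the SSE connection.
  await fetch(`${baseUrl}/Llm/ChatStream`, {
    method: 'POST',
    headers: { 'Content-Type': 'application/json', Authorization: `Bearer ${token}` },
    body: JSON.stringify({ messages, provider: 'auto', streamId })
  })
}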
2026-01-07 18:13:18 +07:00
parent 35928d5528
commit 48fedb1247
5 changed files with 287 additions and 208 deletions

View File

@@ -178,3 +178,4 @@ Task Environment (offset ports)

View File

@@ -2,13 +2,12 @@ using System.Net.Http;
using System.Text.Json;
using System.Text.RegularExpressions;
using Managing.Application.Abstractions.Services;
-using Managing.Application.Hubs;
using Managing.Domain.Users;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Mvc;
-using Microsoft.AspNetCore.SignalR;
using Microsoft.Extensions.Caching.Memory;
using Microsoft.Extensions.DependencyInjection;
+using StackExchange.Redis;
namespace Managing.Api.Controllers;
@@ -26,8 +25,8 @@ public class LlmController : BaseController
private readonly IMcpService _mcpService;
private readonly ILogger<LlmController> _logger;
private readonly IMemoryCache _cache;
-private readonly IHubContext<LlmHub> _hubContext;
private readonly IServiceScopeFactory _serviceScopeFactory;
+private readonly IRedisConnectionService _redisService;
public LlmController(
ILlmService llmService,
@@ -35,38 +34,120 @@ public class LlmController : BaseController
IUserService userService,
ILogger<LlmController> logger,
IMemoryCache cache,
-IHubContext<LlmHub> hubContext,
-IServiceScopeFactory serviceScopeFactory) : base(userService)
+IServiceScopeFactory serviceScopeFactory,
+IRedisConnectionService redisService) : base(userService)
{
_llmService = llmService;
_mcpService = mcpService;
_logger = logger;
_cache = cache;
-_hubContext = hubContext;
_serviceScopeFactory = serviceScopeFactory;
+_redisService = redisService;
}
/// <summary>
-/// Sends a chat message to an LLM with streaming progress updates via SignalR.
-/// Provides real-time updates about iterations, tool calls, and progress similar to Cursor/Claude.
-/// Progress updates are sent via SignalR to the specified connectionId.
+/// SSE endpoint for streaming LLM progress updates.
+/// Subscribes to Redis pub/sub channel for the given streamId and streams updates to the client.
/// </summary>
-/// <param name="request">The chat request with messages, optional provider/API key, and SignalR connectionId</param>
-/// <returns>OK status - updates are sent via SignalR</returns>
+/// <param name="streamId">Unique stream identifier</param>
+/// <returns>SSE stream of progress updates</returns>
[HttpGet]
[Route("stream/{streamId}")]
[Produces("text/event-stream")]
public async Task StreamUpdates(string streamId, CancellationToken cancellationToken)
{
// Verify user is authenticated
try
{
var user = await GetUser();
if (user == null)
{
Response.StatusCode = 401;
await Response.WriteAsync("data: {\"type\":\"error\",\"message\":\"Unauthorized\",\"error\":\"Authentication required\"}\n\n", cancellationToken);
await Response.Body.FlushAsync(cancellationToken);
return;
}
}
catch (Exception ex)
{
_logger.LogError(ex, "Error authenticating user for SSE stream {StreamId}", streamId);
Response.StatusCode = 401;
await Response.WriteAsync("data: {\"type\":\"error\",\"message\":\"Authentication failed\",\"error\":\"Unable to authenticate user\"}\n\n", cancellationToken);
await Response.Body.FlushAsync(cancellationToken);
return;
}
Response.ContentType = "text/event-stream";
Response.Headers["Cache-Control"] = "no-cache";
Response.Headers["Connection"] = "keep-alive";
Response.Headers["X-Accel-Buffering"] = "no"; // Disable nginx buffering
var redis = _redisService.GetConnection();
if (redis == null)
{
await Response.WriteAsync("data: {\"type\":\"error\",\"message\":\"Redis unavailable\"}\n\n", cancellationToken);
await Response.Body.FlushAsync(cancellationToken);
return;
}
var subscriber = redis.GetSubscriber();
var channel = RedisChannel.Literal($"llm-stream:{streamId}");
// Subscribe to updates for this stream
await subscriber.SubscribeAsync(channel, async (redisChannel, message) =>
{
try
{
var json = message.ToString();
await Response.WriteAsync($"data: {json}\n\n", cancellationToken);
await Response.Body.FlushAsync(cancellationToken);
}
catch (Exception ex)
{
_logger.LogError(ex, "Error writing SSE message for stream {StreamId}", streamId);
}
});
// Send initial connection message
await Response.WriteAsync($"data: {{\"type\":\"connected\",\"streamId\":\"{streamId}\"}}\n\n", cancellationToken);
await Response.Body.FlushAsync(cancellationToken);
_logger.LogInformation("SSE connection established for stream {StreamId}", streamId);
// Keep connection alive until cancelled
try
{
await Task.Delay(Timeout.Infinite, cancellationToken);
}
catch (OperationCanceledException)
{
// Client disconnected
await subscriber.UnsubscribeAsync(channel);
_logger.LogInformation("SSE connection closed for stream {StreamId}", streamId);
}
}
/// <summary>
/// Sends a chat message to an LLM with streaming progress updates via SSE/Redis.
/// Provides real-time updates about iterations, tool calls, and progress similar to Cursor/Claude.
/// Progress updates are published to Redis pub/sub and streamed via SSE.
/// </summary>
/// <param name="request">The chat request with messages, optional provider/API key, and streamId</param>
/// <returns>OK status - updates are sent via SSE</returns>
[HttpPost]
[Route("ChatStream")]
[Consumes("application/json")]
[Produces("application/json")]
public async Task<ActionResult> ChatStream([FromBody] LlmChatStreamRequest request)
{
-if (request == null || string.IsNullOrWhiteSpace(request.ConnectionId))
+if (request == null || string.IsNullOrWhiteSpace(request.StreamId))
{
-return BadRequest("Chat request and connectionId are required");
+return BadRequest("Chat request and streamId are required");
}
if (request.Messages == null || !request.Messages.Any())
{
-await _hubContext.Clients.Client(request.ConnectionId).SendAsync("ProgressUpdate", new LlmProgressUpdate
+await PublishProgressUpdate(request.StreamId, new LlmProgressUpdate
{
Type = "error",
Message = "At least one message is required",
@@ -87,7 +168,7 @@ public class LlmController : BaseController
if (user == null)
{
-await _hubContext.Clients.Client(request.ConnectionId).SendAsync("ProgressUpdate", new LlmProgressUpdate
+await PublishProgressUpdate(request.StreamId, new LlmProgressUpdate
{
Type = "error",
Message = "Error authenticating user",
@@ -109,58 +190,56 @@ public class LlmController : BaseController
var mcpService = scope.ServiceProvider.GetRequiredService<IMcpService>();
var userService = scope.ServiceProvider.GetRequiredService<IUserService>();
var cache = scope.ServiceProvider.GetRequiredService<IMemoryCache>();
-var hubContext = scope.ServiceProvider.GetRequiredService<IHubContext<LlmHub>>();
+var redisService = scope.ServiceProvider.GetRequiredService<IRedisConnectionService>();
var logger = scope.ServiceProvider.GetRequiredService<ILogger<LlmController>>();
// Reload user from the scoped service to ensure we have a valid user object
var scopedUser = await userService.GetUserByIdAsync(userId);
if (scopedUser == null)
{
-await hubContext.Clients.Client(request.ConnectionId).SendAsync("ProgressUpdate",
-new LlmProgressUpdate
-{
-Type = "error",
-Message = "User not found",
-Error = "Unable to authenticate user"
-});
+await PublishProgressUpdate(request.StreamId, new LlmProgressUpdate
+{
+Type = "error",
+Message = "User not found",
+Error = "Unable to authenticate user"
+});
return;
}
-await ChatStreamInternal(request, scopedUser, request.ConnectionId, llmService, mcpService, cache,
-hubContext, logger);
+await ChatStreamInternal(request, scopedUser, request.StreamId, llmService, mcpService, cache,
+redisService, logger);
}
catch (Exception ex)
{
-_logger.LogError(ex, "Error processing chat stream for connection {ConnectionId}",
-request.ConnectionId);
+_logger.LogError(ex, "Error processing chat stream for stream {StreamId}",
+request.StreamId);
try
{
-await _hubContext.Clients.Client(request.ConnectionId).SendAsync("ProgressUpdate",
-new LlmProgressUpdate
-{
-Type = "error",
-Message = $"Error processing chat: {ex.Message}",
-Error = ex.Message
-});
+await PublishProgressUpdate(request.StreamId, new LlmProgressUpdate
+{
+Type = "error",
+Message = $"Error processing chat: {ex.Message}",
+Error = ex.Message
+});
}
-catch (Exception hubEx)
+catch (Exception redisEx)
{
-_logger.LogError(hubEx, "Error sending error message to SignalR client");
+_logger.LogError(redisEx, "Error publishing error message to Redis");
}
}
});
-return Ok(new { Message = "Chat stream started", ConnectionId = request.ConnectionId });
+return Ok(new { Message = "Chat stream started", StreamId = request.StreamId });
}
private async Task ChatStreamInternal(
LlmChatStreamRequest request,
User user,
-string connectionId,
+string streamId,
ILlmService llmService,
IMcpService mcpService,
IMemoryCache cache,
-IHubContext<LlmHub> hubContext,
+IRedisConnectionService redisService,
ILogger<LlmController> logger)
{
// Convert to LlmChatRequest for service calls
@@ -175,7 +254,7 @@ public class LlmController : BaseController
Tools = request.Tools
};
-await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate
+await PublishProgressUpdate(streamId, new LlmProgressUpdate
{
Type = "thinking",
Message = "Initializing conversation and loading available tools..."
@@ -189,7 +268,7 @@ public class LlmController : BaseController
});
chatRequest.Tools = availableTools;
-await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate
+await PublishProgressUpdate(streamId, new LlmProgressUpdate
{
Type = "thinking",
Message = $"Loaded {availableTools.Count} available tools. Preparing system context..."
@@ -225,7 +304,7 @@ public class LlmController : BaseController
const int DelayAfterToolCallsMs = 1000; // Additional delay after tool calls before next LLM call
const int MaxRedundantDetections = 2; // Maximum times we'll detect redundant calls before forcing final response
-await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate
+await PublishProgressUpdate(streamId, new LlmProgressUpdate
{
Type = "thinking",
Message = $"Starting analysis (up to {maxIterations} iterations may be needed)..."
@@ -238,7 +317,7 @@ public class LlmController : BaseController
// Get the last user question once per iteration to avoid scope conflicts
var lastUserQuestion = chatRequest.Messages.LastOrDefault(m => m.Role == "user")?.Content ?? "the user's question";
-await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate
+await PublishProgressUpdate(streamId, new LlmProgressUpdate
{
Type = "iteration_start",
Message = "Analyzing your request and determining next steps...",
@@ -252,7 +331,7 @@ public class LlmController : BaseController
// Add delay between iterations to avoid rapid bursts and rate limiting
if (iteration > 1)
{
-await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate
+await PublishProgressUpdate(streamId, new LlmProgressUpdate
{
Type = "thinking",
Message = "Waiting briefly to respect rate limits...",
@@ -266,7 +345,7 @@ public class LlmController : BaseController
TrimConversationContext(chatRequest);
// Send chat request to LLM with retry logic for rate limits
-await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate
+await PublishProgressUpdate(streamId, new LlmProgressUpdate
{
Type = "thinking",
Message = "Sending request to LLM...",
@@ -286,7 +365,7 @@ public class LlmController : BaseController
// Rate limit hit - wait longer before retrying
logger.LogWarning("Rate limit hit (429) in iteration {Iteration}. Waiting 10 seconds before retry...",
iteration);
-await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate
+await PublishProgressUpdate(streamId, new LlmProgressUpdate
{
Type = "thinking",
Message = "Rate limit reached. Waiting before retrying...",
@@ -317,7 +396,7 @@ public class LlmController : BaseController
logger.LogInformation("LLM provided final answer after {Iteration} iteration(s) for user {UserId}", logger.LogInformation("LLM provided final answer after {Iteration} iteration(s) for user {UserId}",
iteration, user.Id); iteration, user.Id);
await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate await PublishProgressUpdate(streamId, new LlmProgressUpdate
{ {
Type = "thinking", Type = "thinking",
Message = "Received final response. Preparing answer...", Message = "Received final response. Preparing answer...",
@@ -336,7 +415,7 @@ public class LlmController : BaseController
logger.LogWarning("LLM requested {Count} redundant tool calls in iteration {Iteration}: {Tools} (Detection #{DetectionCount})", logger.LogWarning("LLM requested {Count} redundant tool calls in iteration {Iteration}: {Tools} (Detection #{DetectionCount})",
redundantCalls.Count, iteration, string.Join(", ", redundantCalls.Select(t => t.Name)), redundantCallDetections); redundantCalls.Count, iteration, string.Join(", ", redundantCalls.Select(t => t.Name)), redundantCallDetections);
await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate await PublishProgressUpdate(streamId, new LlmProgressUpdate
{ {
Type = "thinking", Type = "thinking",
Message = "Detected redundant tool calls. Using cached data...", Message = "Detected redundant tool calls. Using cached data...",
@@ -350,7 +429,7 @@ public class LlmController : BaseController
logger.LogWarning("Reached maximum redundant call detections ({MaxDetections}). Removing tools to force final response.", logger.LogWarning("Reached maximum redundant call detections ({MaxDetections}). Removing tools to force final response.",
MaxRedundantDetections); MaxRedundantDetections);
await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate await PublishProgressUpdate(streamId, new LlmProgressUpdate
{ {
Type = "thinking", Type = "thinking",
Message = "Multiple redundant tool calls detected. Forcing final response...", Message = "Multiple redundant tool calls detected. Forcing final response...",
@@ -433,7 +512,7 @@ public class LlmController : BaseController
logger.LogInformation("LLM requested {Count} tool calls in iteration {Iteration} for user {UserId}", logger.LogInformation("LLM requested {Count} tool calls in iteration {Iteration} for user {UserId}",
response.ToolCalls.Count, iteration, user.Id); response.ToolCalls.Count, iteration, user.Id);
await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate await PublishProgressUpdate(streamId, new LlmProgressUpdate
{ {
Type = "thinking", Type = "thinking",
Message = $"LLM requested {response.ToolCalls.Count} tool call(s). Executing tools...", Message = $"LLM requested {response.ToolCalls.Count} tool call(s). Executing tools...",
@@ -445,7 +524,7 @@ public class LlmController : BaseController
var toolResults = new List<LlmMessage>();
foreach (var toolCall in response.ToolCalls)
{
-await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate
+await PublishProgressUpdate(streamId, new LlmProgressUpdate
{
Type = "tool_call",
Message = $"Calling tool: {toolCall.Name}",
@@ -465,7 +544,7 @@ public class LlmController : BaseController
toolCall.Name, iteration, user.Id);
var resultMessage = GenerateToolResultMessage(toolCall.Name, result);
-await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate
+await PublishProgressUpdate(streamId, new LlmProgressUpdate
{
Type = "tool_result",
Message = resultMessage,
@@ -487,7 +566,7 @@ public class LlmController : BaseController
"Error executing tool {ToolName} in iteration {Iteration} for user {UserId}: {Error}", "Error executing tool {ToolName} in iteration {Iteration} for user {UserId}: {Error}",
toolCall.Name, iteration, user.Id, error); toolCall.Name, iteration, user.Id, error);
await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate await PublishProgressUpdate(streamId, new LlmProgressUpdate
{ {
Type = "tool_result", Type = "tool_result",
Message = $"Tool {toolCall.Name} encountered an error: {error}", Message = $"Tool {toolCall.Name} encountered an error: {error}",
@@ -506,7 +585,7 @@ public class LlmController : BaseController
}
}
-await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate
+await PublishProgressUpdate(streamId, new LlmProgressUpdate
{
Type = "thinking",
Message = "All tools completed. Analyzing results...",
@@ -546,7 +625,7 @@ public class LlmController : BaseController
"Reached max iterations ({MaxIterations}) for user {UserId}. Forcing final response without tools.", "Reached max iterations ({MaxIterations}) for user {UserId}. Forcing final response without tools.",
maxIterations, user.Id); maxIterations, user.Id);
await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate await PublishProgressUpdate(streamId, new LlmProgressUpdate
{ {
Type = "thinking", Type = "thinking",
Message = "Reached maximum iterations. Preparing final response with available data...", Message = "Reached maximum iterations. Preparing final response with available data...",
@@ -617,7 +696,7 @@ public class LlmController : BaseController
user.Id, iteration, finalResponse.Content?.Length ?? 0, !string.IsNullOrWhiteSpace(finalResponse.Content));
// Send final response
-await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate
+await PublishProgressUpdate(streamId, new LlmProgressUpdate
{
Type = "final_response",
Message = "Analysis complete!",
@@ -626,7 +705,7 @@ public class LlmController : BaseController
MaxIterations = maxIterations
});
-logger.LogInformation("Final response sent successfully to connection {ConnectionId} for user {UserId}", connectionId, user.Id);
+logger.LogInformation("Final response sent successfully to stream {StreamId} for user {UserId}", streamId, user.Id);
}
/// <summary>
@@ -1562,18 +1641,28 @@ public class LlmController : BaseController
}
/// <summary>
-/// Helper method to send progress update via SignalR
+/// Helper method to publish progress update to Redis pub/sub
/// </summary>
-private async Task SendProgressUpdate(string connectionId, IHubContext<LlmHub> hubContext,
-ILogger<LlmController> logger, LlmProgressUpdate update)
+private async Task PublishProgressUpdate(string streamId, LlmProgressUpdate update)
{
try
{
-await hubContext.Clients.Client(connectionId).SendAsync("ProgressUpdate", update);
+var redis = _redisService.GetConnection();
+if (redis == null)
+{
+_logger.LogWarning("Redis not available, cannot publish progress update for stream {StreamId}", streamId);
+return;
+}
+var subscriber = redis.GetSubscriber();
+var channel = RedisChannel.Literal($"llm-stream:{streamId}");
+var message = JsonSerializer.Serialize(update);
+await subscriber.PublishAsync(channel, message);
}
catch (Exception ex)
{
-logger.LogError(ex, "Error sending progress update to connection {ConnectionId}", connectionId);
+_logger.LogError(ex, "Error publishing progress update to Redis for stream {StreamId}", streamId);
}
}
@@ -1653,12 +1742,12 @@ public class LlmController : BaseController
}
/// <summary>
-/// Request model for LLM chat streaming via SignalR
+/// Request model for LLM chat streaming via SSE/Redis
/// </summary>
public class LlmChatStreamRequest : LlmChatRequest
{
/// <summary>
-/// SignalR connection ID to send progress updates to
+/// Stream ID for SSE connection and Redis pub/sub channel
/// </summary>
-public string ConnectionId { get; set; } = string.Empty;
+public string StreamId { get; set; } = string.Empty;
}

View File

@@ -306,20 +306,26 @@ builder.Services
{
OnMessageReceived = context =>
{
+var path = context.Request.Path.Value ?? "";
+var pathLower = path.ToLowerInvariant();
// Skip token extraction for anonymous endpoints to avoid validation errors
-var path = context.Request.Path.Value?.ToLower() ?? "";
-if (!string.IsNullOrEmpty(path) && (path.EndsWith("/create-token") || path.EndsWith("/authenticate")))
+if (!string.IsNullOrEmpty(pathLower) && (pathLower.EndsWith("/create-token") || pathLower.EndsWith("/authenticate")))
{
// Clear any token to prevent validation on anonymous endpoints
context.Token = null;
return Task.CompletedTask;
}
-// Extract token from query string for SignalR connections
+// Extract token from query string for SignalR connections and SSE endpoints
// SignalR uses access_token query parameter for WebSocket connections
-if (path.Contains("/bothub") || path.Contains("/backtesthub") || path.Contains("/llmhub"))
+// SSE endpoints also use access_token since EventSource doesn't support custom headers
+if (pathLower.Contains("/bothub") ||
+pathLower.Contains("/backtesthub") ||
+pathLower.Contains("/llmhub") ||
+pathLower.Contains("/llm/stream"))
{
-var accessToken = context.Request.Query["access_token"];
+var accessToken = context.Request.Query["access_token"].FirstOrDefault();
if (!string.IsNullOrEmpty(accessToken))
{
context.Token = accessToken;

View File

@@ -16,3 +16,4 @@ public class PrivyRevokeAllApprovalsResponse
public string? Error { get; set; }
}

View File

@@ -1,4 +1,3 @@
-import { HubConnection, HubConnectionBuilder } from '@microsoft/signalr'
import { LlmClient } from '../generated/ManagingApi'
import { LlmChatRequest, LlmChatResponse, LlmMessage } from '../generated/ManagingApiTypes'
import { Cookies } from 'react-cookie'
@@ -19,7 +18,6 @@ export interface LlmProgressUpdate {
export class AiChatService {
private llmClient: LlmClient
private baseUrl: string
-private hubConnection: HubConnection | null = null
constructor(llmClient: LlmClient, baseUrl: string) {
this.llmClient = llmClient
@@ -27,102 +25,10 @@ export class AiChatService {
}
/**
-* Creates and connects to SignalR hub for LLM chat streaming
+* Generates a unique stream ID for SSE connection
*/
-async connectToHub(): Promise<HubConnection> {
-if (this.hubConnection?.state === 'Connected') {
-return this.hubConnection
-}
-// Clean up existing connection if any
-if (this.hubConnection) {
-try {
-await this.hubConnection.stop()
-} catch (e) {
-// Ignore stop errors
-}
-this.hubConnection = null
-}
-const cookies = new Cookies()
-const bearerToken = cookies.get('token')
-if (!bearerToken) {
-throw new Error('No authentication token found. Please log in first.')
-}
-// Ensure baseUrl doesn't have trailing slash
-const baseUrl = this.baseUrl.endsWith('/') ? this.baseUrl.slice(0, -1) : this.baseUrl
-const hubUrl = `${baseUrl}/llmhub`
-console.log('Connecting to SignalR hub:', hubUrl)
-const connection = new HubConnectionBuilder()
-.withUrl(hubUrl, {
-// Pass token via query string (standard for SignalR WebSocket connections)
-// SignalR will add this as ?access_token=xxx to the negotiation request
-accessTokenFactory: () => {
-const token = cookies.get('token')
-if (!token) {
-console.error('Token not available in accessTokenFactory')
-throw new Error('Token expired or not available')
-}
-console.log('Providing token for SignalR connection')
-return token
-}
-})
-.withAutomaticReconnect({
-nextRetryDelayInMilliseconds: (retryContext) => {
-// Exponential backoff: 0s, 2s, 10s, 30s
-if (retryContext.previousRetryCount === 0) return 2000
-if (retryContext.previousRetryCount === 1) return 10000
-return 30000
-}
-})
-.build()
-// Add connection event handlers for debugging
-connection.onclose((error) => {
-console.log('SignalR connection closed', error)
-this.hubConnection = null
-})
-connection.onreconnecting((error) => {
-console.log('SignalR reconnecting', error)
-})
-connection.onreconnected((connectionId) => {
-console.log('SignalR reconnected', connectionId)
-})
-try {
-console.log('Starting SignalR connection...')
-await connection.start()
-console.log('SignalR connected successfully. Connection ID:', connection.connectionId)
-this.hubConnection = connection
-return connection
-} catch (error: any) {
-console.error('Failed to connect to SignalR hub:', error)
-console.error('Error details:', {
-message: error?.message,
-stack: error?.stack,
-hubUrl: hubUrl,
-hasToken: !!bearerToken
-})
-// Clean up on failure
-this.hubConnection = null
-throw new Error(`Failed to connect to SignalR hub: ${error?.message || 'Unknown error'}. Check browser console for details.`)
-}
-}
-/**
-* Disconnects from SignalR hub
-*/
-async disconnectFromHub(): Promise<void> {
-if (this.hubConnection) {
-await this.hubConnection.stop()
-this.hubConnection = null
-}
+private generateStreamId(): string {
+return `stream-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`
}
/**
@@ -143,7 +49,7 @@ export class AiChatService {
}
/**
-* Send a chat message with streaming progress updates via SignalR
+* Send a chat message with streaming progress updates via SSE/Redis
* Returns an async generator that yields progress updates in real-time
*/
async *sendMessageStream(
@@ -151,67 +57,140 @@ export class AiChatService {
provider?: string,
apiKey?: string
): AsyncGenerator<LlmProgressUpdate, void, unknown> {
-// Connect to SignalR hub
-const connection = await this.connectToHub()
-const connectionId = connection.connectionId
-if (!connectionId) {
+const cookies = new Cookies()
+const bearerToken = cookies.get('token')
+if (!bearerToken) {
yield {
type: 'error',
-message: 'Failed to get SignalR connection ID',
-error: 'Connection ID not available'
+message: 'No authentication token found',
+error: 'Please log in first'
}
return
}
-const request = {
-messages,
-provider: provider || 'auto',
-apiKey: apiKey,
-stream: true,
-temperature: 0.7,
-maxTokens: 4096,
-tools: undefined, // Will be populated by backend
-connectionId: connectionId
-}
+// Generate unique stream ID
+const streamId = this.generateStreamId()
+// Ensure baseUrl doesn't have trailing slash
+const baseUrl = this.baseUrl.endsWith('/') ? this.baseUrl.slice(0, -1) : this.baseUrl
+const streamUrl = `${baseUrl}/Llm/stream/${streamId}?access_token=${encodeURIComponent(bearerToken)}`
+console.log('Opening SSE connection:', streamUrl)
// Queue for incoming updates
const updateQueue: LlmProgressUpdate[] = []
let isComplete = false
let resolver: ((update: LlmProgressUpdate) => void) | null = null
-// Set up progress update handler
-const handler = (update: LlmProgressUpdate) => {
-if (resolver) {
-resolver(update)
-resolver = null
-} else {
-updateQueue.push(update)
-}
-if (update.type === 'final_response' || update.type === 'error') {
-isComplete = true
-}
-}
-connection.on('ProgressUpdate', handler)
+let eventSource: EventSource | null = null
+// Set up SSE connection
try {
-// Send chat request to backend
-const cookies = new Cookies()
-const bearerToken = cookies.get('token')
-const response = await fetch(`${this.baseUrl}/Llm/ChatStream`, {
+eventSource = new EventSource(streamUrl)
+eventSource.onmessage = (event) => {
try {
const rawUpdate = JSON.parse(event.data)
// Normalize PascalCase from backend to camelCase for frontend
let normalizedResponse: LlmChatResponse | undefined
if (rawUpdate.Response || rawUpdate.response) {
const rawResponse = rawUpdate.Response || rawUpdate.response
normalizedResponse = {
content: rawResponse.Content || rawResponse.content || '',
provider: rawResponse.Provider || rawResponse.provider,
model: rawResponse.Model || rawResponse.model,
toolCalls: rawResponse.ToolCalls || rawResponse.toolCalls,
usage: rawResponse.Usage || rawResponse.usage ? {
promptTokens: rawResponse.Usage?.PromptTokens ?? rawResponse.usage?.promptTokens ?? 0,
completionTokens: rawResponse.Usage?.CompletionTokens ?? rawResponse.usage?.completionTokens ?? 0,
totalTokens: rawResponse.Usage?.TotalTokens ?? rawResponse.usage?.totalTokens ?? 0
} : undefined,
requiresToolExecution: rawResponse.RequiresToolExecution ?? rawResponse.requiresToolExecution ?? false
}
}
const update: LlmProgressUpdate = {
type: rawUpdate.Type || rawUpdate.type || '',
message: rawUpdate.Message || rawUpdate.message || '',
iteration: rawUpdate.Iteration ?? rawUpdate.iteration,
maxIterations: rawUpdate.MaxIterations ?? rawUpdate.maxIterations,
toolName: rawUpdate.ToolName || rawUpdate.toolName,
toolArguments: rawUpdate.ToolArguments || rawUpdate.toolArguments,
content: rawUpdate.Content || rawUpdate.content,
response: normalizedResponse,
error: rawUpdate.Error || rawUpdate.error,
timestamp: rawUpdate.Timestamp ? new Date(rawUpdate.Timestamp) : rawUpdate.timestamp
}
// Skip "connected" messages as they're just connection confirmations
if (update.type === 'connected') {
return
}
if (resolver) {
resolver(update)
resolver = null
} else {
updateQueue.push(update)
}
if (update.type === 'final_response' || update.type === 'error') {
isComplete = true
eventSource?.close()
}
} catch (e) {
console.error('Error parsing SSE message:', e)
}
}
eventSource.onerror = (error) => {
console.error('SSE connection error:', error)
if (resolver) {
resolver({
type: 'error',
message: 'SSE connection error',
error: 'Connection failed'
})
resolver = null
} else {
updateQueue.push({
type: 'error',
message: 'SSE connection error',
error: 'Connection failed'
})
}
isComplete = true
eventSource?.close()
}
// Wait a bit for connection to establish
await new Promise(resolve => setTimeout(resolve, 100))
// Send chat request to backend
const request = {
messages,
provider: provider || 'auto',
apiKey: apiKey,
stream: true,
temperature: 0.7,
maxTokens: 4096,
tools: undefined, // Will be populated by backend
streamId: streamId
}
const response = await fetch(`${baseUrl}/Llm/ChatStream`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
-...(bearerToken ? { Authorization: `Bearer ${bearerToken}` } : {})
+Authorization: `Bearer ${bearerToken}`
},
body: JSON.stringify(request)
})
if (!response.ok) {
const errorText = await response.text()
+eventSource?.close()
yield {
type: 'error',
message: `HTTP ${response.status}: ${errorText}`,
@@ -220,7 +199,7 @@ export class AiChatService {
return
}
-// Yield updates as they arrive via SignalR
+// Yield updates as they arrive via SSE
while (!isComplete) {
// Check if we have queued updates
if (updateQueue.length > 0) {
@@ -253,8 +232,11 @@ export class AiChatService {
error: error.message
}
} finally {
-// Clean up handler
-connection.off('ProgressUpdate', handler)
+// Clean up SSE connection
+if (eventSource) {
+eventSource.close()
+console.log('SSE connection closed')
+}
}
}
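
For reference, a hypothetical consumer of the updated sendMessageStream generator could look like the following; llmClient, baseUrl, messages, and the render* helpers are placeholders, not part of this commit:

const aiChat = new AiChatService(llmClient, baseUrl)
for await (const update of aiChat.sendMessageStream(messages)) {
  if (update.type === 'final_response') {
    renderAnswer(update.response?.content ?? '')   // placeholder UI helper
  } else if (update.type === 'error') {
    renderError(update.error)                      // placeholder UI helper
  } else {
    renderProgress(update.message)                 // thinking / tool_call / tool_result updates
  }
}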