From 48fedb124710ccd7d46a8e284b32da15e3a8c19d Mon Sep 17 00:00:00 2001 From: cryptooda Date: Wed, 7 Jan 2026 18:13:18 +0700 Subject: [PATCH] Refactor LlmController and AiChatService for SSE integration and Redis support - Updated LlmController to implement a new SSE endpoint for streaming LLM progress updates, utilizing Redis pub/sub for real-time communication. - Removed SignalR dependencies from AiChatService, replacing them with SSE logic for message streaming. - Enhanced error handling and logging for Redis interactions, ensuring robust feedback during streaming operations. - Adjusted request models and methods to accommodate the new streaming architecture, improving clarity and maintainability. --- docs/TASK_ENVIRONMENTS_SETUP.md | 1 + src/Managing.Api/Controllers/LlmController.cs | 219 ++++++++++----- src/Managing.Api/Program.cs | 16 +- .../Evm/PrivyRevokeAllApprovalsResponse.cs | 1 + .../src/services/aiChatService.ts | 258 ++++++++---------- 5 files changed, 287 insertions(+), 208 deletions(-) diff --git a/docs/TASK_ENVIRONMENTS_SETUP.md b/docs/TASK_ENVIRONMENTS_SETUP.md index 6393023c..e2994bdd 100644 --- a/docs/TASK_ENVIRONMENTS_SETUP.md +++ b/docs/TASK_ENVIRONMENTS_SETUP.md @@ -178,3 +178,4 @@ Task Environment (offset ports) + diff --git a/src/Managing.Api/Controllers/LlmController.cs b/src/Managing.Api/Controllers/LlmController.cs index 12993100..e8e34635 100644 --- a/src/Managing.Api/Controllers/LlmController.cs +++ b/src/Managing.Api/Controllers/LlmController.cs @@ -2,13 +2,12 @@ using System.Net.Http; using System.Text.Json; using System.Text.RegularExpressions; using Managing.Application.Abstractions.Services; -using Managing.Application.Hubs; using Managing.Domain.Users; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -using Microsoft.AspNetCore.SignalR; using Microsoft.Extensions.Caching.Memory; using Microsoft.Extensions.DependencyInjection; +using StackExchange.Redis; namespace Managing.Api.Controllers; @@ -26,8 +25,8 @@ public class LlmController : BaseController private readonly IMcpService _mcpService; private readonly ILogger _logger; private readonly IMemoryCache _cache; - private readonly IHubContext _hubContext; private readonly IServiceScopeFactory _serviceScopeFactory; + private readonly IRedisConnectionService _redisService; public LlmController( ILlmService llmService, @@ -35,38 +34,120 @@ public class LlmController : BaseController IUserService userService, ILogger logger, IMemoryCache cache, - IHubContext hubContext, - IServiceScopeFactory serviceScopeFactory) : base(userService) + IServiceScopeFactory serviceScopeFactory, + IRedisConnectionService redisService) : base(userService) { _llmService = llmService; _mcpService = mcpService; _logger = logger; _cache = cache; - _hubContext = hubContext; _serviceScopeFactory = serviceScopeFactory; + _redisService = redisService; } /// - /// Sends a chat message to an LLM with streaming progress updates via SignalR. - /// Provides real-time updates about iterations, tool calls, and progress similar to Cursor/Claude. - /// Progress updates are sent via SignalR to the specified connectionId. + /// SSE endpoint for streaming LLM progress updates. + /// Subscribes to Redis pub/sub channel for the given streamId and streams updates to the client. 
/// - /// The chat request with messages, optional provider/API key, and SignalR connectionId - /// OK status - updates are sent via SignalR + /// Unique stream identifier + /// SSE stream of progress updates + [HttpGet] + [Route("stream/{streamId}")] + [Produces("text/event-stream")] + public async Task StreamUpdates(string streamId, CancellationToken cancellationToken) + { + // Verify user is authenticated + try + { + var user = await GetUser(); + if (user == null) + { + Response.StatusCode = 401; + await Response.WriteAsync("data: {\"type\":\"error\",\"message\":\"Unauthorized\",\"error\":\"Authentication required\"}\n\n", cancellationToken); + await Response.Body.FlushAsync(cancellationToken); + return; + } + } + catch (Exception ex) + { + _logger.LogError(ex, "Error authenticating user for SSE stream {StreamId}", streamId); + Response.StatusCode = 401; + await Response.WriteAsync("data: {\"type\":\"error\",\"message\":\"Authentication failed\",\"error\":\"Unable to authenticate user\"}\n\n", cancellationToken); + await Response.Body.FlushAsync(cancellationToken); + return; + } + + Response.ContentType = "text/event-stream"; + Response.Headers["Cache-Control"] = "no-cache"; + Response.Headers["Connection"] = "keep-alive"; + Response.Headers["X-Accel-Buffering"] = "no"; // Disable nginx buffering + + var redis = _redisService.GetConnection(); + if (redis == null) + { + await Response.WriteAsync("data: {\"type\":\"error\",\"message\":\"Redis unavailable\"}\n\n", cancellationToken); + await Response.Body.FlushAsync(cancellationToken); + return; + } + + var subscriber = redis.GetSubscriber(); + var channel = RedisChannel.Literal($"llm-stream:{streamId}"); + + // Subscribe to updates for this stream + await subscriber.SubscribeAsync(channel, async (redisChannel, message) => + { + try + { + var json = message.ToString(); + await Response.WriteAsync($"data: {json}\n\n", cancellationToken); + await Response.Body.FlushAsync(cancellationToken); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error writing SSE message for stream {StreamId}", streamId); + } + }); + + // Send initial connection message + await Response.WriteAsync($"data: {{\"type\":\"connected\",\"streamId\":\"{streamId}\"}}\n\n", cancellationToken); + await Response.Body.FlushAsync(cancellationToken); + + _logger.LogInformation("SSE connection established for stream {StreamId}", streamId); + + // Keep connection alive until cancelled + try + { + await Task.Delay(Timeout.Infinite, cancellationToken); + } + catch (OperationCanceledException) + { + // Client disconnected + await subscriber.UnsubscribeAsync(channel); + _logger.LogInformation("SSE connection closed for stream {StreamId}", streamId); + } + } + + /// + /// Sends a chat message to an LLM with streaming progress updates via SSE/Redis. + /// Provides real-time updates about iterations, tool calls, and progress similar to Cursor/Claude. + /// Progress updates are published to Redis pub/sub and streamed via SSE. 
+ /// + /// The chat request with messages, optional provider/API key, and streamId + /// OK status - updates are sent via SSE [HttpPost] [Route("ChatStream")] [Consumes("application/json")] [Produces("application/json")] public async Task ChatStream([FromBody] LlmChatStreamRequest request) { - if (request == null || string.IsNullOrWhiteSpace(request.ConnectionId)) + if (request == null || string.IsNullOrWhiteSpace(request.StreamId)) { - return BadRequest("Chat request and connectionId are required"); + return BadRequest("Chat request and streamId are required"); } if (request.Messages == null || !request.Messages.Any()) { - await _hubContext.Clients.Client(request.ConnectionId).SendAsync("ProgressUpdate", new LlmProgressUpdate + await PublishProgressUpdate(request.StreamId, new LlmProgressUpdate { Type = "error", Message = "At least one message is required", @@ -87,7 +168,7 @@ public class LlmController : BaseController if (user == null) { - await _hubContext.Clients.Client(request.ConnectionId).SendAsync("ProgressUpdate", new LlmProgressUpdate + await PublishProgressUpdate(request.StreamId, new LlmProgressUpdate { Type = "error", Message = "Error authenticating user", @@ -109,58 +190,56 @@ public class LlmController : BaseController var mcpService = scope.ServiceProvider.GetRequiredService(); var userService = scope.ServiceProvider.GetRequiredService(); var cache = scope.ServiceProvider.GetRequiredService(); - var hubContext = scope.ServiceProvider.GetRequiredService>(); + var redisService = scope.ServiceProvider.GetRequiredService(); var logger = scope.ServiceProvider.GetRequiredService>(); // Reload user from the scoped service to ensure we have a valid user object var scopedUser = await userService.GetUserByIdAsync(userId); if (scopedUser == null) { - await hubContext.Clients.Client(request.ConnectionId).SendAsync("ProgressUpdate", - new LlmProgressUpdate - { - Type = "error", - Message = "User not found", - Error = "Unable to authenticate user" - }); + await PublishProgressUpdate(request.StreamId, new LlmProgressUpdate + { + Type = "error", + Message = "User not found", + Error = "Unable to authenticate user" + }); return; } - await ChatStreamInternal(request, scopedUser, request.ConnectionId, llmService, mcpService, cache, - hubContext, logger); + await ChatStreamInternal(request, scopedUser, request.StreamId, llmService, mcpService, cache, + redisService, logger); } catch (Exception ex) { - _logger.LogError(ex, "Error processing chat stream for connection {ConnectionId}", - request.ConnectionId); + _logger.LogError(ex, "Error processing chat stream for stream {StreamId}", + request.StreamId); try { - await _hubContext.Clients.Client(request.ConnectionId).SendAsync("ProgressUpdate", - new LlmProgressUpdate - { - Type = "error", - Message = $"Error processing chat: {ex.Message}", - Error = ex.Message - }); + await PublishProgressUpdate(request.StreamId, new LlmProgressUpdate + { + Type = "error", + Message = $"Error processing chat: {ex.Message}", + Error = ex.Message + }); } - catch (Exception hubEx) + catch (Exception redisEx) { - _logger.LogError(hubEx, "Error sending error message to SignalR client"); + _logger.LogError(redisEx, "Error publishing error message to Redis"); } } }); - return Ok(new { Message = "Chat stream started", ConnectionId = request.ConnectionId }); + return Ok(new { Message = "Chat stream started", StreamId = request.StreamId }); } private async Task ChatStreamInternal( LlmChatStreamRequest request, User user, - string connectionId, + string streamId, 
ILlmService llmService, IMcpService mcpService, IMemoryCache cache, - IHubContext hubContext, + IRedisConnectionService redisService, ILogger logger) { // Convert to LlmChatRequest for service calls @@ -175,7 +254,7 @@ public class LlmController : BaseController Tools = request.Tools }; - await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate + await PublishProgressUpdate(streamId, new LlmProgressUpdate { Type = "thinking", Message = "Initializing conversation and loading available tools..." @@ -189,7 +268,7 @@ public class LlmController : BaseController }); chatRequest.Tools = availableTools; - await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate + await PublishProgressUpdate(streamId, new LlmProgressUpdate { Type = "thinking", Message = $"Loaded {availableTools.Count} available tools. Preparing system context..." @@ -225,7 +304,7 @@ public class LlmController : BaseController const int DelayAfterToolCallsMs = 1000; // Additional delay after tool calls before next LLM call const int MaxRedundantDetections = 2; // Maximum times we'll detect redundant calls before forcing final response - await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate + await PublishProgressUpdate(streamId, new LlmProgressUpdate { Type = "thinking", Message = $"Starting analysis (up to {maxIterations} iterations may be needed)..." @@ -238,7 +317,7 @@ public class LlmController : BaseController // Get the last user question once per iteration to avoid scope conflicts var lastUserQuestion = chatRequest.Messages.LastOrDefault(m => m.Role == "user")?.Content ?? "the user's question"; - await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate + await PublishProgressUpdate(streamId, new LlmProgressUpdate { Type = "iteration_start", Message = "Analyzing your request and determining next steps...", @@ -252,7 +331,7 @@ public class LlmController : BaseController // Add delay between iterations to avoid rapid bursts and rate limiting if (iteration > 1) { - await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate + await PublishProgressUpdate(streamId, new LlmProgressUpdate { Type = "thinking", Message = "Waiting briefly to respect rate limits...", @@ -266,7 +345,7 @@ public class LlmController : BaseController TrimConversationContext(chatRequest); // Send chat request to LLM with retry logic for rate limits - await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate + await PublishProgressUpdate(streamId, new LlmProgressUpdate { Type = "thinking", Message = "Sending request to LLM...", @@ -286,7 +365,7 @@ public class LlmController : BaseController // Rate limit hit - wait longer before retrying logger.LogWarning("Rate limit hit (429) in iteration {Iteration}. Waiting 10 seconds before retry...", iteration); - await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate + await PublishProgressUpdate(streamId, new LlmProgressUpdate { Type = "thinking", Message = "Rate limit reached. Waiting before retrying...", @@ -317,7 +396,7 @@ public class LlmController : BaseController logger.LogInformation("LLM provided final answer after {Iteration} iteration(s) for user {UserId}", iteration, user.Id); - await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate + await PublishProgressUpdate(streamId, new LlmProgressUpdate { Type = "thinking", Message = "Received final response. 
Preparing answer...", @@ -336,7 +415,7 @@ public class LlmController : BaseController logger.LogWarning("LLM requested {Count} redundant tool calls in iteration {Iteration}: {Tools} (Detection #{DetectionCount})", redundantCalls.Count, iteration, string.Join(", ", redundantCalls.Select(t => t.Name)), redundantCallDetections); - await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate + await PublishProgressUpdate(streamId, new LlmProgressUpdate { Type = "thinking", Message = "Detected redundant tool calls. Using cached data...", @@ -350,7 +429,7 @@ public class LlmController : BaseController logger.LogWarning("Reached maximum redundant call detections ({MaxDetections}). Removing tools to force final response.", MaxRedundantDetections); - await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate + await PublishProgressUpdate(streamId, new LlmProgressUpdate { Type = "thinking", Message = "Multiple redundant tool calls detected. Forcing final response...", @@ -433,7 +512,7 @@ public class LlmController : BaseController logger.LogInformation("LLM requested {Count} tool calls in iteration {Iteration} for user {UserId}", response.ToolCalls.Count, iteration, user.Id); - await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate + await PublishProgressUpdate(streamId, new LlmProgressUpdate { Type = "thinking", Message = $"LLM requested {response.ToolCalls.Count} tool call(s). Executing tools...", @@ -445,7 +524,7 @@ public class LlmController : BaseController var toolResults = new List(); foreach (var toolCall in response.ToolCalls) { - await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate + await PublishProgressUpdate(streamId, new LlmProgressUpdate { Type = "tool_call", Message = $"Calling tool: {toolCall.Name}", @@ -465,7 +544,7 @@ public class LlmController : BaseController toolCall.Name, iteration, user.Id); var resultMessage = GenerateToolResultMessage(toolCall.Name, result); - await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate + await PublishProgressUpdate(streamId, new LlmProgressUpdate { Type = "tool_result", Message = resultMessage, @@ -487,7 +566,7 @@ public class LlmController : BaseController "Error executing tool {ToolName} in iteration {Iteration} for user {UserId}: {Error}", toolCall.Name, iteration, user.Id, error); - await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate + await PublishProgressUpdate(streamId, new LlmProgressUpdate { Type = "tool_result", Message = $"Tool {toolCall.Name} encountered an error: {error}", @@ -506,7 +585,7 @@ public class LlmController : BaseController } } - await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate + await PublishProgressUpdate(streamId, new LlmProgressUpdate { Type = "thinking", Message = "All tools completed. Analyzing results...", @@ -546,7 +625,7 @@ public class LlmController : BaseController "Reached max iterations ({MaxIterations}) for user {UserId}. Forcing final response without tools.", maxIterations, user.Id); - await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate + await PublishProgressUpdate(streamId, new LlmProgressUpdate { Type = "thinking", Message = "Reached maximum iterations. Preparing final response with available data...", @@ -617,7 +696,7 @@ public class LlmController : BaseController user.Id, iteration, finalResponse.Content?.Length ?? 
0, !string.IsNullOrWhiteSpace(finalResponse.Content)); // Send final response - await SendProgressUpdate(connectionId, hubContext, logger, new LlmProgressUpdate + await PublishProgressUpdate(streamId, new LlmProgressUpdate { Type = "final_response", Message = "Analysis complete!", @@ -626,7 +705,7 @@ public class LlmController : BaseController MaxIterations = maxIterations }); - logger.LogInformation("Final response sent successfully to connection {ConnectionId} for user {UserId}", connectionId, user.Id); + logger.LogInformation("Final response sent successfully to stream {StreamId} for user {UserId}", streamId, user.Id); } /// @@ -1562,18 +1641,28 @@ public class LlmController : BaseController } /// - /// Helper method to send progress update via SignalR + /// Helper method to publish progress update to Redis pub/sub /// - private async Task SendProgressUpdate(string connectionId, IHubContext hubContext, - ILogger logger, LlmProgressUpdate update) + private async Task PublishProgressUpdate(string streamId, LlmProgressUpdate update) { try { - await hubContext.Clients.Client(connectionId).SendAsync("ProgressUpdate", update); + var redis = _redisService.GetConnection(); + if (redis == null) + { + _logger.LogWarning("Redis not available, cannot publish progress update for stream {StreamId}", streamId); + return; + } + + var subscriber = redis.GetSubscriber(); + var channel = RedisChannel.Literal($"llm-stream:{streamId}"); + var message = JsonSerializer.Serialize(update); + + await subscriber.PublishAsync(channel, message); } catch (Exception ex) { - logger.LogError(ex, "Error sending progress update to connection {ConnectionId}", connectionId); + _logger.LogError(ex, "Error publishing progress update to Redis for stream {StreamId}", streamId); } } @@ -1653,12 +1742,12 @@ public class LlmController : BaseController } /// -/// Request model for LLM chat streaming via SignalR +/// Request model for LLM chat streaming via SSE/Redis /// public class LlmChatStreamRequest : LlmChatRequest { /// - /// SignalR connection ID to send progress updates to + /// Stream ID for SSE connection and Redis pub/sub channel /// - public string ConnectionId { get; set; } = string.Empty; + public string StreamId { get; set; } = string.Empty; } \ No newline at end of file diff --git a/src/Managing.Api/Program.cs b/src/Managing.Api/Program.cs index db271243..4a0920dc 100644 --- a/src/Managing.Api/Program.cs +++ b/src/Managing.Api/Program.cs @@ -306,20 +306,26 @@ builder.Services { OnMessageReceived = context => { + var path = context.Request.Path.Value ?? ""; + var pathLower = path.ToLowerInvariant(); + // Skip token extraction for anonymous endpoints to avoid validation errors - var path = context.Request.Path.Value?.ToLower() ?? 
""; - if (!string.IsNullOrEmpty(path) && (path.EndsWith("/create-token") || path.EndsWith("/authenticate"))) + if (!string.IsNullOrEmpty(pathLower) && (pathLower.EndsWith("/create-token") || pathLower.EndsWith("/authenticate"))) { // Clear any token to prevent validation on anonymous endpoints context.Token = null; return Task.CompletedTask; } - // Extract token from query string for SignalR connections + // Extract token from query string for SignalR connections and SSE endpoints // SignalR uses access_token query parameter for WebSocket connections - if (path.Contains("/bothub") || path.Contains("/backtesthub") || path.Contains("/llmhub")) + // SSE endpoints also use access_token since EventSource doesn't support custom headers + if (pathLower.Contains("/bothub") || + pathLower.Contains("/backtesthub") || + pathLower.Contains("/llmhub") || + pathLower.Contains("/llm/stream")) { - var accessToken = context.Request.Query["access_token"]; + var accessToken = context.Request.Query["access_token"].FirstOrDefault(); if (!string.IsNullOrEmpty(accessToken)) { context.Token = accessToken; diff --git a/src/Managing.Domain/Evm/PrivyRevokeAllApprovalsResponse.cs b/src/Managing.Domain/Evm/PrivyRevokeAllApprovalsResponse.cs index c7f366a2..efd8276c 100644 --- a/src/Managing.Domain/Evm/PrivyRevokeAllApprovalsResponse.cs +++ b/src/Managing.Domain/Evm/PrivyRevokeAllApprovalsResponse.cs @@ -16,3 +16,4 @@ public class PrivyRevokeAllApprovalsResponse public string? Error { get; set; } } + diff --git a/src/Managing.WebApp/src/services/aiChatService.ts b/src/Managing.WebApp/src/services/aiChatService.ts index 23a918ea..914e4fba 100644 --- a/src/Managing.WebApp/src/services/aiChatService.ts +++ b/src/Managing.WebApp/src/services/aiChatService.ts @@ -1,4 +1,3 @@ -import { HubConnection, HubConnectionBuilder } from '@microsoft/signalr' import { LlmClient } from '../generated/ManagingApi' import { LlmChatRequest, LlmChatResponse, LlmMessage } from '../generated/ManagingApiTypes' import { Cookies } from 'react-cookie' @@ -19,7 +18,6 @@ export interface LlmProgressUpdate { export class AiChatService { private llmClient: LlmClient private baseUrl: string - private hubConnection: HubConnection | null = null constructor(llmClient: LlmClient, baseUrl: string) { this.llmClient = llmClient @@ -27,102 +25,10 @@ export class AiChatService { } /** - * Creates and connects to SignalR hub for LLM chat streaming + * Generates a unique stream ID for SSE connection */ - async connectToHub(): Promise { - if (this.hubConnection?.state === 'Connected') { - return this.hubConnection - } - - // Clean up existing connection if any - if (this.hubConnection) { - try { - await this.hubConnection.stop() - } catch (e) { - // Ignore stop errors - } - this.hubConnection = null - } - - const cookies = new Cookies() - const bearerToken = cookies.get('token') - - if (!bearerToken) { - throw new Error('No authentication token found. Please log in first.') - } - - // Ensure baseUrl doesn't have trailing slash - const baseUrl = this.baseUrl.endsWith('/') ? 
this.baseUrl.slice(0, -1) : this.baseUrl - const hubUrl = `${baseUrl}/llmhub` - - console.log('Connecting to SignalR hub:', hubUrl) - - const connection = new HubConnectionBuilder() - .withUrl(hubUrl, { - // Pass token via query string (standard for SignalR WebSocket connections) - // SignalR will add this as ?access_token=xxx to the negotiation request - accessTokenFactory: () => { - const token = cookies.get('token') - if (!token) { - console.error('Token not available in accessTokenFactory') - throw new Error('Token expired or not available') - } - console.log('Providing token for SignalR connection') - return token - } - }) - .withAutomaticReconnect({ - nextRetryDelayInMilliseconds: (retryContext) => { - // Exponential backoff: 0s, 2s, 10s, 30s - if (retryContext.previousRetryCount === 0) return 2000 - if (retryContext.previousRetryCount === 1) return 10000 - return 30000 - } - }) - .build() - - // Add connection event handlers for debugging - connection.onclose((error) => { - console.log('SignalR connection closed', error) - this.hubConnection = null - }) - - connection.onreconnecting((error) => { - console.log('SignalR reconnecting', error) - }) - - connection.onreconnected((connectionId) => { - console.log('SignalR reconnected', connectionId) - }) - - try { - console.log('Starting SignalR connection...') - await connection.start() - console.log('SignalR connected successfully. Connection ID:', connection.connectionId) - this.hubConnection = connection - return connection - } catch (error: any) { - console.error('Failed to connect to SignalR hub:', error) - console.error('Error details:', { - message: error?.message, - stack: error?.stack, - hubUrl: hubUrl, - hasToken: !!bearerToken - }) - // Clean up on failure - this.hubConnection = null - throw new Error(`Failed to connect to SignalR hub: ${error?.message || 'Unknown error'}. Check browser console for details.`) - } - } - - /** - * Disconnects from SignalR hub - */ - async disconnectFromHub(): Promise { - if (this.hubConnection) { - await this.hubConnection.stop() - this.hubConnection = null - } + private generateStreamId(): string { + return `stream-${Date.now()}-${Math.random().toString(36).substr(2, 9)}` } /** @@ -143,7 +49,7 @@ export class AiChatService { } /** - * Send a chat message with streaming progress updates via SignalR + * Send a chat message with streaming progress updates via SSE/Redis * Returns an async generator that yields progress updates in real-time */ async *sendMessageStream( @@ -151,67 +57,140 @@ export class AiChatService { provider?: string, apiKey?: string ): AsyncGenerator { - // Connect to SignalR hub - const connection = await this.connectToHub() - const connectionId = connection.connectionId + const cookies = new Cookies() + const bearerToken = cookies.get('token') - if (!connectionId) { + if (!bearerToken) { yield { type: 'error', - message: 'Failed to get SignalR connection ID', - error: 'Connection ID not available' + message: 'No authentication token found', + error: 'Please log in first' } return } - const request = { - messages, - provider: provider || 'auto', - apiKey: apiKey, - stream: true, - temperature: 0.7, - maxTokens: 4096, - tools: undefined, // Will be populated by backend - connectionId: connectionId - } + // Generate unique stream ID + const streamId = this.generateStreamId() + + // Ensure baseUrl doesn't have trailing slash + const baseUrl = this.baseUrl.endsWith('/') ? 
this.baseUrl.slice(0, -1) : this.baseUrl + const streamUrl = `${baseUrl}/Llm/stream/${streamId}?access_token=${encodeURIComponent(bearerToken)}` + + console.log('Opening SSE connection:', streamUrl) // Queue for incoming updates const updateQueue: LlmProgressUpdate[] = [] let isComplete = false let resolver: ((update: LlmProgressUpdate) => void) | null = null + let eventSource: EventSource | null = null - // Set up progress update handler - const handler = (update: LlmProgressUpdate) => { - if (resolver) { - resolver(update) - resolver = null - } else { - updateQueue.push(update) - } - - if (update.type === 'final_response' || update.type === 'error') { - isComplete = true - } - } - - connection.on('ProgressUpdate', handler) - + // Set up SSE connection try { - // Send chat request to backend - const cookies = new Cookies() - const bearerToken = cookies.get('token') + eventSource = new EventSource(streamUrl) - const response = await fetch(`${this.baseUrl}/Llm/ChatStream`, { + eventSource.onmessage = (event) => { + try { + const rawUpdate = JSON.parse(event.data) + + // Normalize PascalCase from backend to camelCase for frontend + let normalizedResponse: LlmChatResponse | undefined + if (rawUpdate.Response || rawUpdate.response) { + const rawResponse = rawUpdate.Response || rawUpdate.response + normalizedResponse = { + content: rawResponse.Content || rawResponse.content || '', + provider: rawResponse.Provider || rawResponse.provider, + model: rawResponse.Model || rawResponse.model, + toolCalls: rawResponse.ToolCalls || rawResponse.toolCalls, + usage: rawResponse.Usage || rawResponse.usage ? { + promptTokens: rawResponse.Usage?.PromptTokens ?? rawResponse.usage?.promptTokens ?? 0, + completionTokens: rawResponse.Usage?.CompletionTokens ?? rawResponse.usage?.completionTokens ?? 0, + totalTokens: rawResponse.Usage?.TotalTokens ?? rawResponse.usage?.totalTokens ?? 0 + } : undefined, + requiresToolExecution: rawResponse.RequiresToolExecution ?? rawResponse.requiresToolExecution ?? false + } + } + + const update: LlmProgressUpdate = { + type: rawUpdate.Type || rawUpdate.type || '', + message: rawUpdate.Message || rawUpdate.message || '', + iteration: rawUpdate.Iteration ?? rawUpdate.iteration, + maxIterations: rawUpdate.MaxIterations ?? rawUpdate.maxIterations, + toolName: rawUpdate.ToolName || rawUpdate.toolName, + toolArguments: rawUpdate.ToolArguments || rawUpdate.toolArguments, + content: rawUpdate.Content || rawUpdate.content, + response: normalizedResponse, + error: rawUpdate.Error || rawUpdate.error, + timestamp: rawUpdate.Timestamp ? 
new Date(rawUpdate.Timestamp) : rawUpdate.timestamp + } + + // Skip "connected" messages as they're just connection confirmations + if (update.type === 'connected') { + return + } + + if (resolver) { + resolver(update) + resolver = null + } else { + updateQueue.push(update) + } + + if (update.type === 'final_response' || update.type === 'error') { + isComplete = true + eventSource?.close() + } + } catch (e) { + console.error('Error parsing SSE message:', e) + } + } + + eventSource.onerror = (error) => { + console.error('SSE connection error:', error) + if (resolver) { + resolver({ + type: 'error', + message: 'SSE connection error', + error: 'Connection failed' + }) + resolver = null + } else { + updateQueue.push({ + type: 'error', + message: 'SSE connection error', + error: 'Connection failed' + }) + } + isComplete = true + eventSource?.close() + } + + // Wait a bit for connection to establish + await new Promise(resolve => setTimeout(resolve, 100)) + + // Send chat request to backend + const request = { + messages, + provider: provider || 'auto', + apiKey: apiKey, + stream: true, + temperature: 0.7, + maxTokens: 4096, + tools: undefined, // Will be populated by backend + streamId: streamId + } + + const response = await fetch(`${baseUrl}/Llm/ChatStream`, { method: 'POST', headers: { 'Content-Type': 'application/json', - ...(bearerToken ? { Authorization: `Bearer ${bearerToken}` } : {}) + Authorization: `Bearer ${bearerToken}` }, body: JSON.stringify(request) }) if (!response.ok) { const errorText = await response.text() + eventSource?.close() yield { type: 'error', message: `HTTP ${response.status}: ${errorText}`, @@ -220,7 +199,7 @@ export class AiChatService { return } - // Yield updates as they arrive via SignalR + // Yield updates as they arrive via SSE while (!isComplete) { // Check if we have queued updates if (updateQueue.length > 0) { @@ -253,8 +232,11 @@ export class AiChatService { error: error.message } } finally { - // Clean up handler - connection.off('ProgressUpdate', handler) + // Clean up SSE connection + if (eventSource) { + eventSource.close() + console.log('SSE connection closed') + } } }
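
For context, a minimal sketch of how a caller might consume the new SSE-based generator (illustrative only, not part of this patch): it assumes `AiChatService` and `LlmMessage` are importable from the paths shown in the diff, that `LlmMessage` accepts a plain `{ role, content }` object, and that `runChat` is a hypothetical helper.

import { AiChatService } from './services/aiChatService'
import { LlmMessage } from './generated/ManagingApiTypes'

// Drives one chat turn and resolves with the final answer text.
export async function runChat(service: AiChatService, userText: string): Promise<string | undefined> {
  const messages: LlmMessage[] = [{ role: 'user', content: userText } as LlmMessage]

  for await (const update of service.sendMessageStream(messages)) {
    switch (update.type) {
      case 'tool_call':
        console.log(`calling tool: ${update.toolName}`)   // surfaced while the backend executes tools
        break
      case 'final_response':
        return update.response?.content                   // generator completes after this update
      case 'error':
        throw new Error(update.error ?? update.message)
      default:
        console.log(update.message)                       // thinking / iteration_start / tool_result updates
    }
  }
  return undefined
}

Per the diff, the generator itself opens an EventSource against /Llm/stream/{streamId} (token passed as access_token in the query string) and then POSTs /Llm/ChatStream with the same streamId, so a caller only needs the loop above; no SignalR setup remains on the client.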