Implement LLM provider configuration and update user settings
- Added functionality to update the default LLM provider for users via a new endpoint in UserController. - Introduced LlmProvider enum to manage available LLM options: Auto, Gemini, OpenAI, and Claude. - Updated User and UserEntity models to include DefaultLlmProvider property. - Enhanced database context and migrations to support the new LLM provider configuration. - Integrated LLM services into the application bootstrap for dependency injection. - Updated TypeScript API client to include methods for managing LLM providers and chat requests.
This commit is contained in:
210
src/Managing.Application/LLM/LlmService.cs
Normal file
210
src/Managing.Application/LLM/LlmService.cs
Normal file
@@ -0,0 +1,210 @@
|
||||
using Managing.Application.Abstractions.Services;
|
||||
using Managing.Application.LLM.Providers;
|
||||
using Managing.Domain.Users;
|
||||
using Microsoft.Extensions.Configuration;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using static Managing.Common.Enums;
|
||||
|
||||
namespace Managing.Application.LLM;
|
||||
|
||||
/// <summary>
|
||||
/// Service for interacting with LLM providers with auto-selection and BYOK support
|
||||
/// </summary>
|
||||
public class LlmService : ILlmService
|
||||
{
|
||||
private readonly IConfiguration _configuration;
|
||||
private readonly ILogger<LlmService> _logger;
|
||||
private readonly Dictionary<string, ILlmProvider> _providers;
|
||||
|
||||
public LlmService(
|
||||
IConfiguration configuration,
|
||||
ILogger<LlmService> logger,
|
||||
IHttpClientFactory httpClientFactory)
|
||||
{
|
||||
_configuration = configuration;
|
||||
_logger = logger;
|
||||
_providers = new Dictionary<string, ILlmProvider>(StringComparer.OrdinalIgnoreCase);
|
||||
|
||||
// Initialize providers
|
||||
InitializeProviders(httpClientFactory);
|
||||
}
|
||||
|
||||
private void InitializeProviders(IHttpClientFactory httpClientFactory)
|
||||
{
|
||||
// Gemini Provider
|
||||
var geminiApiKey = _configuration["Llm:Gemini:ApiKey"];
|
||||
var geminiModel = _configuration["Llm:Gemini:DefaultModel"];
|
||||
if (!string.IsNullOrWhiteSpace(geminiApiKey))
|
||||
{
|
||||
var providerKey = ConvertLlmProviderToString(LlmProvider.Gemini);
|
||||
_providers[providerKey] = new GeminiProvider(geminiApiKey, geminiModel, httpClientFactory, _logger);
|
||||
_logger.LogInformation("Gemini provider initialized with model: {Model}", geminiModel ?? "default");
|
||||
}
|
||||
|
||||
// OpenAI Provider
|
||||
var openaiApiKey = _configuration["Llm:OpenAI:ApiKey"];
|
||||
var openaiModel = _configuration["Llm:OpenAI:DefaultModel"];
|
||||
if (!string.IsNullOrWhiteSpace(openaiApiKey))
|
||||
{
|
||||
var providerKey = ConvertLlmProviderToString(LlmProvider.OpenAI);
|
||||
_providers[providerKey] = new OpenAiProvider(openaiApiKey, openaiModel, httpClientFactory, _logger);
|
||||
_logger.LogInformation("OpenAI provider initialized with model: {Model}", openaiModel ?? "default");
|
||||
}
|
||||
|
||||
// Claude Provider
|
||||
var claudeApiKey = _configuration["Llm:Claude:ApiKey"];
|
||||
var claudeModel = _configuration["Llm:Claude:DefaultModel"];
|
||||
if (!string.IsNullOrWhiteSpace(claudeApiKey))
|
||||
{
|
||||
var providerKey = ConvertLlmProviderToString(LlmProvider.Claude);
|
||||
_providers[providerKey] = new ClaudeProvider(claudeApiKey, claudeModel, httpClientFactory, _logger);
|
||||
_logger.LogInformation("Claude provider initialized with model: {Model}", claudeModel ?? "default");
|
||||
}
|
||||
|
||||
if (_providers.Count == 0)
|
||||
{
|
||||
_logger.LogWarning("No LLM providers configured. Please add API keys to configuration.");
|
||||
}
|
||||
}
|
||||
|
||||
public async Task<LlmChatResponse> ChatAsync(User user, LlmChatRequest request)
|
||||
{
|
||||
ILlmProvider provider;
|
||||
|
||||
// BYOK: If user provides their own API key
|
||||
if (!string.IsNullOrWhiteSpace(request.ApiKey))
|
||||
{
|
||||
var requestedProvider = ParseProviderString(request.Provider) ?? LlmProvider.Claude; // Default to Claude for BYOK
|
||||
var providerName = ConvertLlmProviderToString(requestedProvider);
|
||||
provider = CreateProviderWithCustomKey(requestedProvider, request.ApiKey);
|
||||
_logger.LogInformation("Using BYOK for provider: {Provider} for user: {UserId}", providerName, user.Id);
|
||||
}
|
||||
// Auto mode: Select provider automatically (use user's default if set, otherwise fallback to system default)
|
||||
else if (string.IsNullOrWhiteSpace(request.Provider) ||
|
||||
ParseProviderString(request.Provider) == LlmProvider.Auto)
|
||||
{
|
||||
// Check if user has a default provider preference (and it's not Auto)
|
||||
if (user.DefaultLlmProvider.HasValue &&
|
||||
user.DefaultLlmProvider.Value != LlmProvider.Auto)
|
||||
{
|
||||
var providerName = ConvertLlmProviderToString(user.DefaultLlmProvider.Value);
|
||||
if (_providers.TryGetValue(providerName, out var userPreferredProvider))
|
||||
{
|
||||
provider = userPreferredProvider;
|
||||
_logger.LogInformation("Using user's default provider: {Provider} for user: {UserId}", provider.Name, user.Id);
|
||||
}
|
||||
else
|
||||
{
|
||||
provider = SelectProvider();
|
||||
_logger.LogInformation("Auto-selected provider: {Provider} for user: {UserId} (user default {UserDefault} not available)",
|
||||
provider.Name, user.Id, user.DefaultLlmProvider.Value);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
provider = SelectProvider();
|
||||
_logger.LogInformation("Auto-selected provider: {Provider} for user: {UserId} (user default: {UserDefault})",
|
||||
provider.Name, user.Id, user.DefaultLlmProvider?.ToString() ?? "not set");
|
||||
}
|
||||
}
|
||||
// Explicit provider selection
|
||||
else
|
||||
{
|
||||
var requestedProvider = ParseProviderString(request.Provider);
|
||||
if (requestedProvider == null || requestedProvider == LlmProvider.Auto)
|
||||
{
|
||||
throw new InvalidOperationException($"Invalid provider '{request.Provider}'. Valid providers are: {string.Join(", ", Enum.GetNames<LlmProvider>())}");
|
||||
}
|
||||
|
||||
var providerName = ConvertLlmProviderToString(requestedProvider.Value);
|
||||
if (!_providers.TryGetValue(providerName, out provider!))
|
||||
{
|
||||
throw new InvalidOperationException($"Provider '{request.Provider}' is not available or not configured.");
|
||||
}
|
||||
_logger.LogInformation("Using specified provider: {Provider} for user: {UserId}", providerName, user.Id);
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
var response = await provider.ChatAsync(request);
|
||||
return response;
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogError(ex, "Error calling LLM provider {Provider} for user {UserId}", provider.Name, user.Id);
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
public Task<IEnumerable<string>> GetAvailableProvidersAsync()
|
||||
{
|
||||
return Task.FromResult(_providers.Keys.AsEnumerable());
|
||||
}
|
||||
|
||||
private ILlmProvider SelectProvider()
|
||||
{
|
||||
// Priority: OpenAI > Claude > Gemini
|
||||
var openaiKey = ConvertLlmProviderToString(LlmProvider.OpenAI);
|
||||
if (_providers.TryGetValue(openaiKey, out var openai))
|
||||
return openai;
|
||||
|
||||
var claudeKey = ConvertLlmProviderToString(LlmProvider.Claude);
|
||||
if (_providers.TryGetValue(claudeKey, out var claude))
|
||||
return claude;
|
||||
|
||||
var geminiKey = ConvertLlmProviderToString(LlmProvider.Gemini);
|
||||
if (_providers.TryGetValue(geminiKey, out var gemini))
|
||||
return gemini;
|
||||
|
||||
throw new InvalidOperationException("No LLM providers are configured. Please add API keys to configuration.");
|
||||
}
|
||||
|
||||
private ILlmProvider CreateProviderWithCustomKey(LlmProvider provider, string apiKey)
|
||||
{
|
||||
// This is a temporary instance with user's API key
|
||||
// Get default models from configuration
|
||||
var geminiModel = _configuration["Llm:Gemini:DefaultModel"];
|
||||
var openaiModel = _configuration["Llm:OpenAI:DefaultModel"];
|
||||
var claudeModel = _configuration["Llm:Claude:DefaultModel"];
|
||||
|
||||
return provider switch
|
||||
{
|
||||
LlmProvider.Gemini => new GeminiProvider(apiKey, geminiModel, null!, _logger),
|
||||
LlmProvider.OpenAI => new OpenAiProvider(apiKey, openaiModel, null!, _logger),
|
||||
LlmProvider.Claude => new ClaudeProvider(apiKey, claudeModel, null!, _logger),
|
||||
_ => throw new InvalidOperationException($"Cannot create provider with custom key for: {provider}. Only Gemini, OpenAI, and Claude are supported for BYOK.")
|
||||
};
|
||||
}
|
||||
|
||||
private string ConvertLlmProviderToString(LlmProvider provider)
|
||||
{
|
||||
return provider switch
|
||||
{
|
||||
LlmProvider.Auto => "auto",
|
||||
LlmProvider.Gemini => "gemini",
|
||||
LlmProvider.OpenAI => "openai",
|
||||
LlmProvider.Claude => "claude",
|
||||
_ => throw new ArgumentException($"Unknown LlmProvider enum value: {provider}")
|
||||
};
|
||||
}
|
||||
|
||||
private LlmProvider? ParseProviderString(string? providerString)
|
||||
{
|
||||
if (string.IsNullOrWhiteSpace(providerString))
|
||||
return null;
|
||||
|
||||
// Try parsing as enum (case-insensitive)
|
||||
if (Enum.TryParse<LlmProvider>(providerString, ignoreCase: true, out var parsedProvider))
|
||||
return parsedProvider;
|
||||
|
||||
// Fallback to lowercase string matching for backward compatibility
|
||||
return providerString.ToLowerInvariant() switch
|
||||
{
|
||||
"auto" => LlmProvider.Auto,
|
||||
"gemini" => LlmProvider.Gemini,
|
||||
"openai" => LlmProvider.OpenAI,
|
||||
"claude" => LlmProvider.Claude,
|
||||
_ => null
|
||||
};
|
||||
}
|
||||
}
|
||||
236
src/Managing.Application/LLM/McpService.cs
Normal file
236
src/Managing.Application/LLM/McpService.cs
Normal file
@@ -0,0 +1,236 @@
|
||||
using Managing.Application.Abstractions.Services;
|
||||
using Managing.Domain.Users;
|
||||
using Managing.Mcp.Tools;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using static Managing.Common.Enums;
|
||||
|
||||
namespace Managing.Application.LLM;
|
||||
|
||||
/// <summary>
|
||||
/// Service for executing Model Context Protocol (MCP) tools
|
||||
/// </summary>
|
||||
public class McpService : IMcpService
|
||||
{
|
||||
private readonly BacktestTools _backtestTools;
|
||||
private readonly ILogger<McpService> _logger;
|
||||
|
||||
public McpService(BacktestTools backtestTools, ILogger<McpService> logger)
|
||||
{
|
||||
_backtestTools = backtestTools;
|
||||
_logger = logger;
|
||||
}
|
||||
|
||||
public async Task<object> ExecuteToolAsync(User user, string toolName, Dictionary<string, object>? parameters = null)
|
||||
{
|
||||
_logger.LogInformation("Executing MCP tool: {ToolName} for user: {UserId}", toolName, user.Id);
|
||||
|
||||
try
|
||||
{
|
||||
return toolName.ToLowerInvariant() switch
|
||||
{
|
||||
"get_backtests_paginated" => await ExecuteGetBacktestsPaginated(user, parameters),
|
||||
_ => throw new InvalidOperationException($"Unknown tool: {toolName}")
|
||||
};
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogError(ex, "Error executing MCP tool {ToolName} for user {UserId}", toolName, user.Id);
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
public Task<IEnumerable<McpToolDefinition>> GetAvailableToolsAsync()
|
||||
{
|
||||
var tools = new List<McpToolDefinition>
|
||||
{
|
||||
new McpToolDefinition
|
||||
{
|
||||
Name = "get_backtests_paginated",
|
||||
Description = "Retrieves paginated backtests with filtering and sorting capabilities. Supports filters for score, winrate, drawdown, tickers, indicators, duration, and trading type.",
|
||||
Parameters = new Dictionary<string, McpParameterDefinition>
|
||||
{
|
||||
["page"] = new McpParameterDefinition
|
||||
{
|
||||
Type = "integer",
|
||||
Description = "Page number (defaults to 1)",
|
||||
Required = false,
|
||||
DefaultValue = 1
|
||||
},
|
||||
["pageSize"] = new McpParameterDefinition
|
||||
{
|
||||
Type = "integer",
|
||||
Description = "Number of items per page (defaults to 50, max 100)",
|
||||
Required = false,
|
||||
DefaultValue = 50
|
||||
},
|
||||
["sortBy"] = new McpParameterDefinition
|
||||
{
|
||||
Type = "string",
|
||||
Description = "Field to sort by (Score, WinRate, GrowthPercentage, MaxDrawdown, SharpeRatio, FinalPnl, StartDate, EndDate, PositionCount)",
|
||||
Required = false,
|
||||
DefaultValue = "Score"
|
||||
},
|
||||
["sortOrder"] = new McpParameterDefinition
|
||||
{
|
||||
Type = "string",
|
||||
Description = "Sort order - 'asc' or 'desc' (defaults to 'desc')",
|
||||
Required = false,
|
||||
DefaultValue = "desc"
|
||||
},
|
||||
["scoreMin"] = new McpParameterDefinition
|
||||
{
|
||||
Type = "number",
|
||||
Description = "Minimum score filter (0-100)",
|
||||
Required = false
|
||||
},
|
||||
["scoreMax"] = new McpParameterDefinition
|
||||
{
|
||||
Type = "number",
|
||||
Description = "Maximum score filter (0-100)",
|
||||
Required = false
|
||||
},
|
||||
["winrateMin"] = new McpParameterDefinition
|
||||
{
|
||||
Type = "integer",
|
||||
Description = "Minimum winrate filter (0-100)",
|
||||
Required = false
|
||||
},
|
||||
["winrateMax"] = new McpParameterDefinition
|
||||
{
|
||||
Type = "integer",
|
||||
Description = "Maximum winrate filter (0-100)",
|
||||
Required = false
|
||||
},
|
||||
["maxDrawdownMax"] = new McpParameterDefinition
|
||||
{
|
||||
Type = "number",
|
||||
Description = "Maximum drawdown filter",
|
||||
Required = false
|
||||
},
|
||||
["tickers"] = new McpParameterDefinition
|
||||
{
|
||||
Type = "string",
|
||||
Description = "Comma-separated list of tickers to filter by (e.g., 'BTC,ETH,SOL')",
|
||||
Required = false
|
||||
},
|
||||
["indicators"] = new McpParameterDefinition
|
||||
{
|
||||
Type = "string",
|
||||
Description = "Comma-separated list of indicators to filter by",
|
||||
Required = false
|
||||
},
|
||||
["durationMinDays"] = new McpParameterDefinition
|
||||
{
|
||||
Type = "number",
|
||||
Description = "Minimum duration in days",
|
||||
Required = false
|
||||
},
|
||||
["durationMaxDays"] = new McpParameterDefinition
|
||||
{
|
||||
Type = "number",
|
||||
Description = "Maximum duration in days",
|
||||
Required = false
|
||||
},
|
||||
["name"] = new McpParameterDefinition
|
||||
{
|
||||
Type = "string",
|
||||
Description = "Filter by name (contains search)",
|
||||
Required = false
|
||||
},
|
||||
["tradingType"] = new McpParameterDefinition
|
||||
{
|
||||
Type = "string",
|
||||
Description = "Trading type filter (Spot, Futures, BacktestSpot, BacktestFutures, Paper)",
|
||||
Required = false
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
return Task.FromResult<IEnumerable<McpToolDefinition>>(tools);
|
||||
}
|
||||
|
||||
private async Task<object> ExecuteGetBacktestsPaginated(User user, Dictionary<string, object>? parameters)
|
||||
{
|
||||
var page = GetParameterValue<int>(parameters, "page", 1);
|
||||
var pageSize = GetParameterValue<int>(parameters, "pageSize", 50);
|
||||
var sortByString = GetParameterValue<string>(parameters, "sortBy", "Score");
|
||||
var sortOrder = GetParameterValue<string>(parameters, "sortOrder", "desc");
|
||||
var scoreMin = GetParameterValue<double?>(parameters, "scoreMin", null);
|
||||
var scoreMax = GetParameterValue<double?>(parameters, "scoreMax", null);
|
||||
var winrateMin = GetParameterValue<int?>(parameters, "winrateMin", null);
|
||||
var winrateMax = GetParameterValue<int?>(parameters, "winrateMax", null);
|
||||
var maxDrawdownMax = GetParameterValue<decimal?>(parameters, "maxDrawdownMax", null);
|
||||
var tickers = GetParameterValue<string?>(parameters, "tickers", null);
|
||||
var indicators = GetParameterValue<string?>(parameters, "indicators", null);
|
||||
var durationMinDays = GetParameterValue<double?>(parameters, "durationMinDays", null);
|
||||
var durationMaxDays = GetParameterValue<double?>(parameters, "durationMaxDays", null);
|
||||
var name = GetParameterValue<string?>(parameters, "name", null);
|
||||
var tradingTypeString = GetParameterValue<string?>(parameters, "tradingType", null);
|
||||
|
||||
// Parse sortBy enum
|
||||
if (!Enum.TryParse<BacktestSortableColumn>(sortByString, true, out var sortBy))
|
||||
{
|
||||
sortBy = BacktestSortableColumn.Score;
|
||||
}
|
||||
|
||||
// Parse tradingType enum
|
||||
TradingType? tradingType = null;
|
||||
if (!string.IsNullOrWhiteSpace(tradingTypeString) &&
|
||||
Enum.TryParse<TradingType>(tradingTypeString, true, out var parsedTradingType))
|
||||
{
|
||||
tradingType = parsedTradingType;
|
||||
}
|
||||
|
||||
return await _backtestTools.GetBacktestsPaginated(
|
||||
user,
|
||||
page,
|
||||
pageSize,
|
||||
sortBy,
|
||||
sortOrder,
|
||||
scoreMin,
|
||||
scoreMax,
|
||||
winrateMin,
|
||||
winrateMax,
|
||||
maxDrawdownMax,
|
||||
tickers,
|
||||
indicators,
|
||||
durationMinDays,
|
||||
durationMaxDays,
|
||||
name,
|
||||
tradingType);
|
||||
}
|
||||
|
||||
private T GetParameterValue<T>(Dictionary<string, object>? parameters, string key, T defaultValue)
|
||||
{
|
||||
if (parameters == null || !parameters.ContainsKey(key))
|
||||
{
|
||||
return defaultValue;
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
var value = parameters[key];
|
||||
if (value == null)
|
||||
{
|
||||
return defaultValue;
|
||||
}
|
||||
|
||||
// Handle nullable types
|
||||
var targetType = typeof(T);
|
||||
var underlyingType = Nullable.GetUnderlyingType(targetType);
|
||||
|
||||
if (underlyingType != null)
|
||||
{
|
||||
// It's a nullable type
|
||||
return (T)Convert.ChangeType(value, underlyingType);
|
||||
}
|
||||
|
||||
return (T)Convert.ChangeType(value, targetType);
|
||||
}
|
||||
catch
|
||||
{
|
||||
return defaultValue;
|
||||
}
|
||||
}
|
||||
}
|
||||
165
src/Managing.Application/LLM/Providers/ClaudeProvider.cs
Normal file
165
src/Managing.Application/LLM/Providers/ClaudeProvider.cs
Normal file
@@ -0,0 +1,165 @@
|
||||
using System.Net.Http.Json;
|
||||
using System.Text.Json;
|
||||
using System.Text.Json.Serialization;
|
||||
using Managing.Application.Abstractions.Services;
|
||||
using Microsoft.Extensions.Logging;
|
||||
|
||||
namespace Managing.Application.LLM.Providers;
|
||||
|
||||
/// <summary>
|
||||
/// Anthropic Claude API provider
|
||||
/// </summary>
|
||||
public class ClaudeProvider : ILlmProvider
|
||||
{
|
||||
private readonly string _apiKey;
|
||||
private readonly string _defaultModel;
|
||||
private readonly HttpClient _httpClient;
|
||||
private readonly ILogger _logger;
|
||||
private const string BaseUrl = "https://api.anthropic.com/v1";
|
||||
private const string FallbackModel = "claude-3-5-sonnet-20241022";
|
||||
private const string AnthropicVersion = "2023-06-01";
|
||||
|
||||
public string Name => "claude";
|
||||
|
||||
public ClaudeProvider(string apiKey, string? defaultModel, IHttpClientFactory? httpClientFactory, ILogger logger)
|
||||
{
|
||||
_apiKey = apiKey;
|
||||
_defaultModel = defaultModel ?? FallbackModel;
|
||||
_httpClient = httpClientFactory?.CreateClient() ?? new HttpClient();
|
||||
_httpClient.DefaultRequestHeaders.Add("x-api-key", _apiKey);
|
||||
_httpClient.DefaultRequestHeaders.Add("anthropic-version", AnthropicVersion);
|
||||
_logger = logger;
|
||||
}
|
||||
|
||||
public async Task<LlmChatResponse> ChatAsync(LlmChatRequest request)
|
||||
{
|
||||
var url = $"{BaseUrl}/messages";
|
||||
|
||||
// Extract system message
|
||||
var systemMessage = request.Messages.FirstOrDefault(m => m.Role == "system")?.Content ?? "";
|
||||
var messages = request.Messages.Where(m => m.Role != "system").ToList();
|
||||
|
||||
var claudeRequest = new
|
||||
{
|
||||
model = _defaultModel,
|
||||
max_tokens = request.MaxTokens,
|
||||
temperature = request.Temperature,
|
||||
system = !string.IsNullOrWhiteSpace(systemMessage) ? systemMessage : null,
|
||||
messages = messages.Select(m => new
|
||||
{
|
||||
role = m.Role == "assistant" ? "assistant" : "user",
|
||||
content = m.Content
|
||||
}).ToArray(),
|
||||
tools = request.Tools?.Any() == true ? request.Tools.Select(t => new
|
||||
{
|
||||
name = t.Name,
|
||||
description = t.Description,
|
||||
input_schema = new
|
||||
{
|
||||
type = "object",
|
||||
properties = t.Parameters.ToDictionary(
|
||||
p => p.Key,
|
||||
p => new
|
||||
{
|
||||
type = p.Value.Type,
|
||||
description = p.Value.Description
|
||||
}
|
||||
),
|
||||
required = t.Parameters.Where(p => p.Value.Required).Select(p => p.Key).ToArray()
|
||||
}
|
||||
}).ToArray() : null
|
||||
};
|
||||
|
||||
var jsonOptions = new JsonSerializerOptions
|
||||
{
|
||||
PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower,
|
||||
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull
|
||||
};
|
||||
|
||||
var response = await _httpClient.PostAsJsonAsync(url, claudeRequest, jsonOptions);
|
||||
|
||||
if (!response.IsSuccessStatusCode)
|
||||
{
|
||||
var errorContent = await response.Content.ReadAsStringAsync();
|
||||
_logger.LogError("Claude API error: {StatusCode} - {Error}", response.StatusCode, errorContent);
|
||||
throw new HttpRequestException($"Claude API error: {response.StatusCode} - {errorContent}");
|
||||
}
|
||||
|
||||
var claudeResponse = await response.Content.ReadFromJsonAsync<ClaudeResponse>(jsonOptions);
|
||||
return ConvertFromClaudeResponse(claudeResponse!);
|
||||
}
|
||||
|
||||
private LlmChatResponse ConvertFromClaudeResponse(ClaudeResponse response)
|
||||
{
|
||||
var textContent = response.Content?.FirstOrDefault(c => c.Type == "text");
|
||||
var toolUseContents = response.Content?.Where(c => c.Type == "tool_use").ToList();
|
||||
|
||||
var llmResponse = new LlmChatResponse
|
||||
{
|
||||
Content = textContent?.Text ?? "",
|
||||
Provider = Name,
|
||||
Model = response.Model ?? _defaultModel,
|
||||
Usage = response.Usage != null ? new LlmUsage
|
||||
{
|
||||
PromptTokens = response.Usage.InputTokens,
|
||||
CompletionTokens = response.Usage.OutputTokens,
|
||||
TotalTokens = response.Usage.InputTokens + response.Usage.OutputTokens
|
||||
} : null
|
||||
};
|
||||
|
||||
if (toolUseContents?.Any() == true)
|
||||
{
|
||||
llmResponse.ToolCalls = toolUseContents.Select(tc => new LlmToolCall
|
||||
{
|
||||
Id = tc.Id ?? Guid.NewGuid().ToString(),
|
||||
Name = tc.Name ?? "",
|
||||
Arguments = tc.Input ?? new Dictionary<string, object>()
|
||||
}).ToList();
|
||||
llmResponse.RequiresToolExecution = true;
|
||||
}
|
||||
|
||||
return llmResponse;
|
||||
}
|
||||
|
||||
private class ClaudeResponse
|
||||
{
|
||||
[JsonPropertyName("id")]
|
||||
public string? Id { get; set; }
|
||||
|
||||
[JsonPropertyName("model")]
|
||||
public string? Model { get; set; }
|
||||
|
||||
[JsonPropertyName("content")]
|
||||
public List<ClaudeContent>? Content { get; set; }
|
||||
|
||||
[JsonPropertyName("usage")]
|
||||
public ClaudeUsage? Usage { get; set; }
|
||||
}
|
||||
|
||||
private class ClaudeContent
|
||||
{
|
||||
[JsonPropertyName("type")]
|
||||
public string Type { get; set; } = "";
|
||||
|
||||
[JsonPropertyName("text")]
|
||||
public string? Text { get; set; }
|
||||
|
||||
[JsonPropertyName("id")]
|
||||
public string? Id { get; set; }
|
||||
|
||||
[JsonPropertyName("name")]
|
||||
public string? Name { get; set; }
|
||||
|
||||
[JsonPropertyName("input")]
|
||||
public Dictionary<string, object>? Input { get; set; }
|
||||
}
|
||||
|
||||
private class ClaudeUsage
|
||||
{
|
||||
[JsonPropertyName("input_tokens")]
|
||||
public int InputTokens { get; set; }
|
||||
|
||||
[JsonPropertyName("output_tokens")]
|
||||
public int OutputTokens { get; set; }
|
||||
}
|
||||
}
|
||||
210
src/Managing.Application/LLM/Providers/GeminiProvider.cs
Normal file
210
src/Managing.Application/LLM/Providers/GeminiProvider.cs
Normal file
@@ -0,0 +1,210 @@
|
||||
using System.Net.Http.Json;
|
||||
using System.Text.Json;
|
||||
using System.Text.Json.Serialization;
|
||||
using Managing.Application.Abstractions.Services;
|
||||
using Microsoft.Extensions.Logging;
|
||||
|
||||
namespace Managing.Application.LLM.Providers;
|
||||
|
||||
/// <summary>
|
||||
/// Google Gemini API provider
|
||||
/// </summary>
|
||||
public class GeminiProvider : ILlmProvider
|
||||
{
|
||||
private readonly string _apiKey;
|
||||
private readonly string _defaultModel;
|
||||
private readonly HttpClient _httpClient;
|
||||
private readonly ILogger _logger;
|
||||
private const string BaseUrl = "https://generativelanguage.googleapis.com/v1beta";
|
||||
private const string FallbackModel = "gemini-2.0-flash-exp";
|
||||
|
||||
public string Name => "gemini";
|
||||
|
||||
public GeminiProvider(string apiKey, string? defaultModel, IHttpClientFactory? httpClientFactory, ILogger logger)
|
||||
{
|
||||
_apiKey = apiKey;
|
||||
_defaultModel = defaultModel ?? FallbackModel;
|
||||
_httpClient = httpClientFactory?.CreateClient() ?? new HttpClient();
|
||||
_logger = logger;
|
||||
}
|
||||
|
||||
public async Task<LlmChatResponse> ChatAsync(LlmChatRequest request)
|
||||
{
|
||||
var model = _defaultModel;
|
||||
var url = $"{BaseUrl}/models/{model}:generateContent?key={_apiKey}";
|
||||
|
||||
var geminiRequest = ConvertToGeminiRequest(request);
|
||||
var jsonOptions = new JsonSerializerOptions
|
||||
{
|
||||
PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
|
||||
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull
|
||||
};
|
||||
|
||||
var response = await _httpClient.PostAsJsonAsync(url, geminiRequest, jsonOptions);
|
||||
|
||||
if (!response.IsSuccessStatusCode)
|
||||
{
|
||||
var errorContent = await response.Content.ReadAsStringAsync();
|
||||
_logger.LogError("Gemini API error: {StatusCode} - {Error}", response.StatusCode, errorContent);
|
||||
throw new HttpRequestException($"Gemini API error: {response.StatusCode} - {errorContent}");
|
||||
}
|
||||
|
||||
var geminiResponse = await response.Content.ReadFromJsonAsync<GeminiResponse>(jsonOptions);
|
||||
return ConvertFromGeminiResponse(geminiResponse!);
|
||||
}
|
||||
|
||||
private object ConvertToGeminiRequest(LlmChatRequest request)
|
||||
{
|
||||
var contents = request.Messages
|
||||
.Where(m => m.Role != "system") // Gemini doesn't support system messages in the same way
|
||||
.Select(m => new
|
||||
{
|
||||
role = m.Role == "assistant" ? "model" : "user",
|
||||
parts = new[]
|
||||
{
|
||||
new { text = m.Content }
|
||||
}
|
||||
}).ToList();
|
||||
|
||||
// Add system message as first user message if present
|
||||
var systemMessage = request.Messages.FirstOrDefault(m => m.Role == "system");
|
||||
if (systemMessage != null && !string.IsNullOrWhiteSpace(systemMessage.Content))
|
||||
{
|
||||
contents.Insert(0, new
|
||||
{
|
||||
role = "user",
|
||||
parts = new[]
|
||||
{
|
||||
new { text = $"System instructions: {systemMessage.Content}" }
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
var geminiRequest = new
|
||||
{
|
||||
contents,
|
||||
generationConfig = new
|
||||
{
|
||||
temperature = request.Temperature,
|
||||
maxOutputTokens = request.MaxTokens
|
||||
},
|
||||
tools = request.Tools?.Any() == true
|
||||
? new[]
|
||||
{
|
||||
new
|
||||
{
|
||||
functionDeclarations = request.Tools.Select(t => new
|
||||
{
|
||||
name = t.Name,
|
||||
description = t.Description,
|
||||
parameters = new
|
||||
{
|
||||
type = "object",
|
||||
properties = t.Parameters.ToDictionary(
|
||||
p => p.Key,
|
||||
p => new
|
||||
{
|
||||
type = p.Value.Type,
|
||||
description = p.Value.Description
|
||||
}
|
||||
),
|
||||
required = t.Parameters.Where(p => p.Value.Required).Select(p => p.Key).ToArray()
|
||||
}
|
||||
}).ToArray()
|
||||
}
|
||||
}
|
||||
: null
|
||||
};
|
||||
|
||||
return geminiRequest;
|
||||
}
|
||||
|
||||
private LlmChatResponse ConvertFromGeminiResponse(GeminiResponse response)
|
||||
{
|
||||
var candidate = response.Candidates?.FirstOrDefault();
|
||||
if (candidate == null)
|
||||
{
|
||||
return new LlmChatResponse
|
||||
{
|
||||
Content = "",
|
||||
Provider = Name,
|
||||
Model = _defaultModel
|
||||
};
|
||||
}
|
||||
|
||||
var content = candidate.Content;
|
||||
var textPart = content?.Parts?.FirstOrDefault(p => !string.IsNullOrWhiteSpace(p.Text));
|
||||
var functionCallParts = content?.Parts?.Where(p => p.FunctionCall != null).ToList();
|
||||
|
||||
var llmResponse = new LlmChatResponse
|
||||
{
|
||||
Content = textPart?.Text ?? "",
|
||||
Provider = Name,
|
||||
Model = _defaultModel,
|
||||
Usage = response.UsageMetadata != null
|
||||
? new LlmUsage
|
||||
{
|
||||
PromptTokens = response.UsageMetadata.PromptTokenCount,
|
||||
CompletionTokens = response.UsageMetadata.CandidatesTokenCount,
|
||||
TotalTokens = response.UsageMetadata.TotalTokenCount
|
||||
}
|
||||
: null
|
||||
};
|
||||
|
||||
// Handle function calls (tool calls)
|
||||
if (functionCallParts?.Any() == true)
|
||||
{
|
||||
llmResponse.ToolCalls = functionCallParts.Select((fc, idx) => new LlmToolCall
|
||||
{
|
||||
Id = $"call_{idx}",
|
||||
Name = fc.FunctionCall!.Name,
|
||||
Arguments = fc.FunctionCall.Args ?? new Dictionary<string, object>()
|
||||
}).ToList();
|
||||
llmResponse.RequiresToolExecution = true;
|
||||
}
|
||||
|
||||
return llmResponse;
|
||||
}
|
||||
|
||||
// Gemini API response models
|
||||
private class GeminiResponse
|
||||
{
|
||||
[JsonPropertyName("candidates")] public List<GeminiCandidate>? Candidates { get; set; }
|
||||
|
||||
[JsonPropertyName("usageMetadata")] public GeminiUsageMetadata? UsageMetadata { get; set; }
|
||||
}
|
||||
|
||||
private class GeminiCandidate
|
||||
{
|
||||
[JsonPropertyName("content")] public GeminiContent? Content { get; set; }
|
||||
}
|
||||
|
||||
private class GeminiContent
|
||||
{
|
||||
[JsonPropertyName("parts")] public List<GeminiPart>? Parts { get; set; }
|
||||
}
|
||||
|
||||
private class GeminiPart
|
||||
{
|
||||
[JsonPropertyName("text")] public string? Text { get; set; }
|
||||
|
||||
[JsonPropertyName("functionCall")] public GeminiFunctionCall? FunctionCall { get; set; }
|
||||
}
|
||||
|
||||
private class GeminiFunctionCall
|
||||
{
|
||||
[JsonPropertyName("name")] public string Name { get; set; } = "";
|
||||
|
||||
[JsonPropertyName("args")] public Dictionary<string, object>? Args { get; set; }
|
||||
}
|
||||
|
||||
private class GeminiUsageMetadata
|
||||
{
|
||||
[JsonPropertyName("promptTokenCount")] public int PromptTokenCount { get; set; }
|
||||
|
||||
[JsonPropertyName("candidatesTokenCount")]
|
||||
public int CandidatesTokenCount { get; set; }
|
||||
|
||||
[JsonPropertyName("totalTokenCount")] public int TotalTokenCount { get; set; }
|
||||
}
|
||||
}
|
||||
21
src/Managing.Application/LLM/Providers/ILlmProvider.cs
Normal file
21
src/Managing.Application/LLM/Providers/ILlmProvider.cs
Normal file
@@ -0,0 +1,21 @@
|
||||
using Managing.Application.Abstractions.Services;
|
||||
|
||||
namespace Managing.Application.LLM.Providers;
|
||||
|
||||
/// <summary>
|
||||
/// Interface for LLM provider implementations
|
||||
/// </summary>
|
||||
public interface ILlmProvider
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets the name of the provider (e.g., "gemini", "openai", "claude")
|
||||
/// </summary>
|
||||
string Name { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Sends a chat request to the provider
|
||||
/// </summary>
|
||||
/// <param name="request">The chat request</param>
|
||||
/// <returns>The chat response</returns>
|
||||
Task<LlmChatResponse> ChatAsync(LlmChatRequest request);
|
||||
}
|
||||
199
src/Managing.Application/LLM/Providers/OpenAiProvider.cs
Normal file
199
src/Managing.Application/LLM/Providers/OpenAiProvider.cs
Normal file
@@ -0,0 +1,199 @@
|
||||
using System.Net.Http.Json;
|
||||
using System.Text.Json;
|
||||
using System.Text.Json.Serialization;
|
||||
using Managing.Application.Abstractions.Services;
|
||||
using Microsoft.Extensions.Logging;
|
||||
|
||||
namespace Managing.Application.LLM.Providers;
|
||||
|
||||
/// <summary>
|
||||
/// OpenAI API provider
|
||||
/// </summary>
|
||||
public class OpenAiProvider : ILlmProvider
|
||||
{
|
||||
private readonly string _apiKey;
|
||||
private readonly string _defaultModel;
|
||||
private readonly HttpClient _httpClient;
|
||||
private readonly ILogger _logger;
|
||||
private const string BaseUrl = "https://api.openai.com/v1";
|
||||
private const string FallbackModel = "gpt-4o";
|
||||
|
||||
public string Name => "openai";
|
||||
|
||||
public OpenAiProvider(string apiKey, string? defaultModel, IHttpClientFactory? httpClientFactory, ILogger logger)
|
||||
{
|
||||
_apiKey = apiKey;
|
||||
_defaultModel = defaultModel ?? FallbackModel;
|
||||
_httpClient = httpClientFactory?.CreateClient() ?? new HttpClient();
|
||||
_httpClient.DefaultRequestHeaders.Add("Authorization", $"Bearer {_apiKey}");
|
||||
_logger = logger;
|
||||
}
|
||||
|
||||
public async Task<LlmChatResponse> ChatAsync(LlmChatRequest request)
|
||||
{
|
||||
var url = $"{BaseUrl}/chat/completions";
|
||||
|
||||
var openAiRequest = new
|
||||
{
|
||||
model = _defaultModel,
|
||||
messages = request.Messages.Select(m => new
|
||||
{
|
||||
role = m.Role,
|
||||
content = m.Content,
|
||||
tool_calls = m.ToolCalls?.Select(tc => new
|
||||
{
|
||||
id = tc.Id,
|
||||
type = "function",
|
||||
function = new
|
||||
{
|
||||
name = tc.Name,
|
||||
arguments = JsonSerializer.Serialize(tc.Arguments)
|
||||
}
|
||||
}),
|
||||
tool_call_id = m.ToolCallId
|
||||
}).ToArray(),
|
||||
temperature = request.Temperature,
|
||||
max_tokens = request.MaxTokens,
|
||||
tools = request.Tools?.Any() == true ? request.Tools.Select(t => new
|
||||
{
|
||||
type = "function",
|
||||
function = new
|
||||
{
|
||||
name = t.Name,
|
||||
description = t.Description,
|
||||
parameters = new
|
||||
{
|
||||
type = "object",
|
||||
properties = t.Parameters.ToDictionary(
|
||||
p => p.Key,
|
||||
p => new
|
||||
{
|
||||
type = p.Value.Type,
|
||||
description = p.Value.Description
|
||||
}
|
||||
),
|
||||
required = t.Parameters.Where(p => p.Value.Required).Select(p => p.Key).ToArray()
|
||||
}
|
||||
}
|
||||
}).ToArray() : null
|
||||
};
|
||||
|
||||
var jsonOptions = new JsonSerializerOptions
|
||||
{
|
||||
PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower,
|
||||
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull
|
||||
};
|
||||
|
||||
var response = await _httpClient.PostAsJsonAsync(url, openAiRequest, jsonOptions);
|
||||
|
||||
if (!response.IsSuccessStatusCode)
|
||||
{
|
||||
var errorContent = await response.Content.ReadAsStringAsync();
|
||||
_logger.LogError("OpenAI API error: {StatusCode} - {Error}", response.StatusCode, errorContent);
|
||||
throw new HttpRequestException($"OpenAI API error: {response.StatusCode} - {errorContent}");
|
||||
}
|
||||
|
||||
var openAiResponse = await response.Content.ReadFromJsonAsync<OpenAiResponse>(jsonOptions);
|
||||
return ConvertFromOpenAiResponse(openAiResponse!);
|
||||
}
|
||||
|
||||
private LlmChatResponse ConvertFromOpenAiResponse(OpenAiResponse response)
|
||||
{
|
||||
var choice = response.Choices?.FirstOrDefault();
|
||||
if (choice == null)
|
||||
{
|
||||
return new LlmChatResponse
|
||||
{
|
||||
Content = "",
|
||||
Provider = Name,
|
||||
Model = response.Model ?? _defaultModel
|
||||
};
|
||||
}
|
||||
|
||||
var llmResponse = new LlmChatResponse
|
||||
{
|
||||
Content = choice.Message?.Content ?? "",
|
||||
Provider = Name,
|
||||
Model = response.Model ?? _defaultModel,
|
||||
Usage = response.Usage != null ? new LlmUsage
|
||||
{
|
||||
PromptTokens = response.Usage.PromptTokens,
|
||||
CompletionTokens = response.Usage.CompletionTokens,
|
||||
TotalTokens = response.Usage.TotalTokens
|
||||
} : null
|
||||
};
|
||||
|
||||
if (choice.Message?.ToolCalls?.Any() == true)
|
||||
{
|
||||
llmResponse.ToolCalls = choice.Message.ToolCalls.Select(tc => new LlmToolCall
|
||||
{
|
||||
Id = tc.Id,
|
||||
Name = tc.Function.Name,
|
||||
Arguments = JsonSerializer.Deserialize<Dictionary<string, object>>(tc.Function.Arguments) ?? new()
|
||||
}).ToList();
|
||||
llmResponse.RequiresToolExecution = true;
|
||||
}
|
||||
|
||||
return llmResponse;
|
||||
}
|
||||
|
||||
private class OpenAiResponse
|
||||
{
|
||||
[JsonPropertyName("id")]
|
||||
public string? Id { get; set; }
|
||||
|
||||
[JsonPropertyName("model")]
|
||||
public string? Model { get; set; }
|
||||
|
||||
[JsonPropertyName("choices")]
|
||||
public List<OpenAiChoice>? Choices { get; set; }
|
||||
|
||||
[JsonPropertyName("usage")]
|
||||
public OpenAiUsage? Usage { get; set; }
|
||||
}
|
||||
|
||||
private class OpenAiChoice
|
||||
{
|
||||
[JsonPropertyName("message")]
|
||||
public OpenAiMessage? Message { get; set; }
|
||||
}
|
||||
|
||||
private class OpenAiMessage
|
||||
{
|
||||
[JsonPropertyName("content")]
|
||||
public string? Content { get; set; }
|
||||
|
||||
[JsonPropertyName("tool_calls")]
|
||||
public List<OpenAiToolCall>? ToolCalls { get; set; }
|
||||
}
|
||||
|
||||
private class OpenAiToolCall
|
||||
{
|
||||
[JsonPropertyName("id")]
|
||||
public string Id { get; set; } = "";
|
||||
|
||||
[JsonPropertyName("function")]
|
||||
public OpenAiFunction Function { get; set; } = new();
|
||||
}
|
||||
|
||||
private class OpenAiFunction
|
||||
{
|
||||
[JsonPropertyName("name")]
|
||||
public string Name { get; set; } = "";
|
||||
|
||||
[JsonPropertyName("arguments")]
|
||||
public string Arguments { get; set; } = "{}";
|
||||
}
|
||||
|
||||
private class OpenAiUsage
|
||||
{
|
||||
[JsonPropertyName("prompt_tokens")]
|
||||
public int PromptTokens { get; set; }
|
||||
|
||||
[JsonPropertyName("completion_tokens")]
|
||||
public int CompletionTokens { get; set; }
|
||||
|
||||
[JsonPropertyName("total_tokens")]
|
||||
public int TotalTokens { get; set; }
|
||||
}
|
||||
}
|
||||
@@ -36,6 +36,7 @@
|
||||
<ProjectReference Include="..\Managing.Common\Managing.Common.csproj"/>
|
||||
<ProjectReference Include="..\Managing.Domain\Managing.Domain.csproj"/>
|
||||
<ProjectReference Include="..\Managing.Infrastructure.Database\Managing.Infrastructure.Databases.csproj"/>
|
||||
<ProjectReference Include="..\Managing.Mcp\Managing.Mcp.csproj"/>
|
||||
</ItemGroup>
|
||||
|
||||
</Project>
|
||||
|
||||
@@ -339,6 +339,22 @@ public class UserService : IUserService
|
||||
return user;
|
||||
}
|
||||
|
||||
public async Task<User> UpdateDefaultLlmProvider(User user, LlmProvider defaultLlmProvider)
|
||||
{
|
||||
user = await GetUserByName(user.Name);
|
||||
if (user.DefaultLlmProvider == defaultLlmProvider)
|
||||
return user;
|
||||
|
||||
// Update the default LLM provider on the provided user object
|
||||
user.DefaultLlmProvider = defaultLlmProvider;
|
||||
await _userRepository.SaveOrUpdateUserAsync(user);
|
||||
|
||||
_logger.LogInformation("Updated default LLM provider to {Provider} for user {UserId}",
|
||||
defaultLlmProvider, user.Id);
|
||||
|
||||
return user;
|
||||
}
|
||||
|
||||
public async Task<User> UpdateUserSettings(User user, UserSettingsDto settings)
|
||||
{
|
||||
user = await GetUserByName(user.Name);
|
||||
|
||||
Reference in New Issue
Block a user