Implement LLM provider configuration and update user settings

- Added functionality to update the default LLM provider for users via a new endpoint in UserController.
- Introduced LlmProvider enum to manage available LLM options: Auto, Gemini, OpenAI, and Claude.
- Updated User and UserEntity models to include DefaultLlmProvider property.
- Enhanced database context and migrations to support the new LLM provider configuration.
- Integrated LLM services into the application bootstrap for dependency injection.
- Updated TypeScript API client to include methods for managing LLM providers and chat requests.
This commit is contained in:
2026-01-03 21:55:55 +07:00
parent fb49190346
commit 6f55566db3
46 changed files with 7900 additions and 3 deletions

View File

@@ -0,0 +1,210 @@
using Managing.Application.Abstractions.Services;
using Managing.Application.LLM.Providers;
using Managing.Domain.Users;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.Logging;
using static Managing.Common.Enums;
namespace Managing.Application.LLM;
/// <summary>
/// Service for interacting with LLM providers with auto-selection and BYOK support
/// </summary>
public class LlmService : ILlmService
{
private readonly IConfiguration _configuration;
private readonly ILogger<LlmService> _logger;
private readonly Dictionary<string, ILlmProvider> _providers;
public LlmService(
IConfiguration configuration,
ILogger<LlmService> logger,
IHttpClientFactory httpClientFactory)
{
_configuration = configuration;
_logger = logger;
_providers = new Dictionary<string, ILlmProvider>(StringComparer.OrdinalIgnoreCase);
// Initialize providers
InitializeProviders(httpClientFactory);
}
private void InitializeProviders(IHttpClientFactory httpClientFactory)
{
// Gemini Provider
var geminiApiKey = _configuration["Llm:Gemini:ApiKey"];
var geminiModel = _configuration["Llm:Gemini:DefaultModel"];
if (!string.IsNullOrWhiteSpace(geminiApiKey))
{
var providerKey = ConvertLlmProviderToString(LlmProvider.Gemini);
_providers[providerKey] = new GeminiProvider(geminiApiKey, geminiModel, httpClientFactory, _logger);
_logger.LogInformation("Gemini provider initialized with model: {Model}", geminiModel ?? "default");
}
// OpenAI Provider
var openaiApiKey = _configuration["Llm:OpenAI:ApiKey"];
var openaiModel = _configuration["Llm:OpenAI:DefaultModel"];
if (!string.IsNullOrWhiteSpace(openaiApiKey))
{
var providerKey = ConvertLlmProviderToString(LlmProvider.OpenAI);
_providers[providerKey] = new OpenAiProvider(openaiApiKey, openaiModel, httpClientFactory, _logger);
_logger.LogInformation("OpenAI provider initialized with model: {Model}", openaiModel ?? "default");
}
// Claude Provider
var claudeApiKey = _configuration["Llm:Claude:ApiKey"];
var claudeModel = _configuration["Llm:Claude:DefaultModel"];
if (!string.IsNullOrWhiteSpace(claudeApiKey))
{
var providerKey = ConvertLlmProviderToString(LlmProvider.Claude);
_providers[providerKey] = new ClaudeProvider(claudeApiKey, claudeModel, httpClientFactory, _logger);
_logger.LogInformation("Claude provider initialized with model: {Model}", claudeModel ?? "default");
}
if (_providers.Count == 0)
{
_logger.LogWarning("No LLM providers configured. Please add API keys to configuration.");
}
}
public async Task<LlmChatResponse> ChatAsync(User user, LlmChatRequest request)
{
ILlmProvider provider;
// BYOK: If user provides their own API key
if (!string.IsNullOrWhiteSpace(request.ApiKey))
{
var requestedProvider = ParseProviderString(request.Provider) ?? LlmProvider.Claude; // Default to Claude for BYOK
var providerName = ConvertLlmProviderToString(requestedProvider);
provider = CreateProviderWithCustomKey(requestedProvider, request.ApiKey);
_logger.LogInformation("Using BYOK for provider: {Provider} for user: {UserId}", providerName, user.Id);
}
// Auto mode: Select provider automatically (use user's default if set, otherwise fallback to system default)
else if (string.IsNullOrWhiteSpace(request.Provider) ||
ParseProviderString(request.Provider) == LlmProvider.Auto)
{
// Check if user has a default provider preference (and it's not Auto)
if (user.DefaultLlmProvider.HasValue &&
user.DefaultLlmProvider.Value != LlmProvider.Auto)
{
var providerName = ConvertLlmProviderToString(user.DefaultLlmProvider.Value);
if (_providers.TryGetValue(providerName, out var userPreferredProvider))
{
provider = userPreferredProvider;
_logger.LogInformation("Using user's default provider: {Provider} for user: {UserId}", provider.Name, user.Id);
}
else
{
provider = SelectProvider();
_logger.LogInformation("Auto-selected provider: {Provider} for user: {UserId} (user default {UserDefault} not available)",
provider.Name, user.Id, user.DefaultLlmProvider.Value);
}
}
else
{
provider = SelectProvider();
_logger.LogInformation("Auto-selected provider: {Provider} for user: {UserId} (user default: {UserDefault})",
provider.Name, user.Id, user.DefaultLlmProvider?.ToString() ?? "not set");
}
}
// Explicit provider selection
else
{
var requestedProvider = ParseProviderString(request.Provider);
if (requestedProvider == null || requestedProvider == LlmProvider.Auto)
{
throw new InvalidOperationException($"Invalid provider '{request.Provider}'. Valid providers are: {string.Join(", ", Enum.GetNames<LlmProvider>())}");
}
var providerName = ConvertLlmProviderToString(requestedProvider.Value);
if (!_providers.TryGetValue(providerName, out provider!))
{
throw new InvalidOperationException($"Provider '{request.Provider}' is not available or not configured.");
}
_logger.LogInformation("Using specified provider: {Provider} for user: {UserId}", providerName, user.Id);
}
try
{
var response = await provider.ChatAsync(request);
return response;
}
catch (Exception ex)
{
_logger.LogError(ex, "Error calling LLM provider {Provider} for user {UserId}", provider.Name, user.Id);
throw;
}
}
public Task<IEnumerable<string>> GetAvailableProvidersAsync()
{
return Task.FromResult(_providers.Keys.AsEnumerable());
}
private ILlmProvider SelectProvider()
{
// Priority: OpenAI > Claude > Gemini
var openaiKey = ConvertLlmProviderToString(LlmProvider.OpenAI);
if (_providers.TryGetValue(openaiKey, out var openai))
return openai;
var claudeKey = ConvertLlmProviderToString(LlmProvider.Claude);
if (_providers.TryGetValue(claudeKey, out var claude))
return claude;
var geminiKey = ConvertLlmProviderToString(LlmProvider.Gemini);
if (_providers.TryGetValue(geminiKey, out var gemini))
return gemini;
throw new InvalidOperationException("No LLM providers are configured. Please add API keys to configuration.");
}
private ILlmProvider CreateProviderWithCustomKey(LlmProvider provider, string apiKey)
{
// This is a temporary instance with user's API key
// Get default models from configuration
var geminiModel = _configuration["Llm:Gemini:DefaultModel"];
var openaiModel = _configuration["Llm:OpenAI:DefaultModel"];
var claudeModel = _configuration["Llm:Claude:DefaultModel"];
return provider switch
{
LlmProvider.Gemini => new GeminiProvider(apiKey, geminiModel, null!, _logger),
LlmProvider.OpenAI => new OpenAiProvider(apiKey, openaiModel, null!, _logger),
LlmProvider.Claude => new ClaudeProvider(apiKey, claudeModel, null!, _logger),
_ => throw new InvalidOperationException($"Cannot create provider with custom key for: {provider}. Only Gemini, OpenAI, and Claude are supported for BYOK.")
};
}
private string ConvertLlmProviderToString(LlmProvider provider)
{
return provider switch
{
LlmProvider.Auto => "auto",
LlmProvider.Gemini => "gemini",
LlmProvider.OpenAI => "openai",
LlmProvider.Claude => "claude",
_ => throw new ArgumentException($"Unknown LlmProvider enum value: {provider}")
};
}
private LlmProvider? ParseProviderString(string? providerString)
{
if (string.IsNullOrWhiteSpace(providerString))
return null;
// Try parsing as enum (case-insensitive)
if (Enum.TryParse<LlmProvider>(providerString, ignoreCase: true, out var parsedProvider))
return parsedProvider;
// Fallback to lowercase string matching for backward compatibility
return providerString.ToLowerInvariant() switch
{
"auto" => LlmProvider.Auto,
"gemini" => LlmProvider.Gemini,
"openai" => LlmProvider.OpenAI,
"claude" => LlmProvider.Claude,
_ => null
};
}
}

View File

@@ -0,0 +1,236 @@
using Managing.Application.Abstractions.Services;
using Managing.Domain.Users;
using Managing.Mcp.Tools;
using Microsoft.Extensions.Logging;
using static Managing.Common.Enums;
namespace Managing.Application.LLM;
/// <summary>
/// Service for executing Model Context Protocol (MCP) tools
/// </summary>
public class McpService : IMcpService
{
private readonly BacktestTools _backtestTools;
private readonly ILogger<McpService> _logger;
public McpService(BacktestTools backtestTools, ILogger<McpService> logger)
{
_backtestTools = backtestTools;
_logger = logger;
}
public async Task<object> ExecuteToolAsync(User user, string toolName, Dictionary<string, object>? parameters = null)
{
_logger.LogInformation("Executing MCP tool: {ToolName} for user: {UserId}", toolName, user.Id);
try
{
return toolName.ToLowerInvariant() switch
{
"get_backtests_paginated" => await ExecuteGetBacktestsPaginated(user, parameters),
_ => throw new InvalidOperationException($"Unknown tool: {toolName}")
};
}
catch (Exception ex)
{
_logger.LogError(ex, "Error executing MCP tool {ToolName} for user {UserId}", toolName, user.Id);
throw;
}
}
public Task<IEnumerable<McpToolDefinition>> GetAvailableToolsAsync()
{
var tools = new List<McpToolDefinition>
{
new McpToolDefinition
{
Name = "get_backtests_paginated",
Description = "Retrieves paginated backtests with filtering and sorting capabilities. Supports filters for score, winrate, drawdown, tickers, indicators, duration, and trading type.",
Parameters = new Dictionary<string, McpParameterDefinition>
{
["page"] = new McpParameterDefinition
{
Type = "integer",
Description = "Page number (defaults to 1)",
Required = false,
DefaultValue = 1
},
["pageSize"] = new McpParameterDefinition
{
Type = "integer",
Description = "Number of items per page (defaults to 50, max 100)",
Required = false,
DefaultValue = 50
},
["sortBy"] = new McpParameterDefinition
{
Type = "string",
Description = "Field to sort by (Score, WinRate, GrowthPercentage, MaxDrawdown, SharpeRatio, FinalPnl, StartDate, EndDate, PositionCount)",
Required = false,
DefaultValue = "Score"
},
["sortOrder"] = new McpParameterDefinition
{
Type = "string",
Description = "Sort order - 'asc' or 'desc' (defaults to 'desc')",
Required = false,
DefaultValue = "desc"
},
["scoreMin"] = new McpParameterDefinition
{
Type = "number",
Description = "Minimum score filter (0-100)",
Required = false
},
["scoreMax"] = new McpParameterDefinition
{
Type = "number",
Description = "Maximum score filter (0-100)",
Required = false
},
["winrateMin"] = new McpParameterDefinition
{
Type = "integer",
Description = "Minimum winrate filter (0-100)",
Required = false
},
["winrateMax"] = new McpParameterDefinition
{
Type = "integer",
Description = "Maximum winrate filter (0-100)",
Required = false
},
["maxDrawdownMax"] = new McpParameterDefinition
{
Type = "number",
Description = "Maximum drawdown filter",
Required = false
},
["tickers"] = new McpParameterDefinition
{
Type = "string",
Description = "Comma-separated list of tickers to filter by (e.g., 'BTC,ETH,SOL')",
Required = false
},
["indicators"] = new McpParameterDefinition
{
Type = "string",
Description = "Comma-separated list of indicators to filter by",
Required = false
},
["durationMinDays"] = new McpParameterDefinition
{
Type = "number",
Description = "Minimum duration in days",
Required = false
},
["durationMaxDays"] = new McpParameterDefinition
{
Type = "number",
Description = "Maximum duration in days",
Required = false
},
["name"] = new McpParameterDefinition
{
Type = "string",
Description = "Filter by name (contains search)",
Required = false
},
["tradingType"] = new McpParameterDefinition
{
Type = "string",
Description = "Trading type filter (Spot, Futures, BacktestSpot, BacktestFutures, Paper)",
Required = false
}
}
}
};
return Task.FromResult<IEnumerable<McpToolDefinition>>(tools);
}
private async Task<object> ExecuteGetBacktestsPaginated(User user, Dictionary<string, object>? parameters)
{
var page = GetParameterValue<int>(parameters, "page", 1);
var pageSize = GetParameterValue<int>(parameters, "pageSize", 50);
var sortByString = GetParameterValue<string>(parameters, "sortBy", "Score");
var sortOrder = GetParameterValue<string>(parameters, "sortOrder", "desc");
var scoreMin = GetParameterValue<double?>(parameters, "scoreMin", null);
var scoreMax = GetParameterValue<double?>(parameters, "scoreMax", null);
var winrateMin = GetParameterValue<int?>(parameters, "winrateMin", null);
var winrateMax = GetParameterValue<int?>(parameters, "winrateMax", null);
var maxDrawdownMax = GetParameterValue<decimal?>(parameters, "maxDrawdownMax", null);
var tickers = GetParameterValue<string?>(parameters, "tickers", null);
var indicators = GetParameterValue<string?>(parameters, "indicators", null);
var durationMinDays = GetParameterValue<double?>(parameters, "durationMinDays", null);
var durationMaxDays = GetParameterValue<double?>(parameters, "durationMaxDays", null);
var name = GetParameterValue<string?>(parameters, "name", null);
var tradingTypeString = GetParameterValue<string?>(parameters, "tradingType", null);
// Parse sortBy enum
if (!Enum.TryParse<BacktestSortableColumn>(sortByString, true, out var sortBy))
{
sortBy = BacktestSortableColumn.Score;
}
// Parse tradingType enum
TradingType? tradingType = null;
if (!string.IsNullOrWhiteSpace(tradingTypeString) &&
Enum.TryParse<TradingType>(tradingTypeString, true, out var parsedTradingType))
{
tradingType = parsedTradingType;
}
return await _backtestTools.GetBacktestsPaginated(
user,
page,
pageSize,
sortBy,
sortOrder,
scoreMin,
scoreMax,
winrateMin,
winrateMax,
maxDrawdownMax,
tickers,
indicators,
durationMinDays,
durationMaxDays,
name,
tradingType);
}
private T GetParameterValue<T>(Dictionary<string, object>? parameters, string key, T defaultValue)
{
if (parameters == null || !parameters.ContainsKey(key))
{
return defaultValue;
}
try
{
var value = parameters[key];
if (value == null)
{
return defaultValue;
}
// Handle nullable types
var targetType = typeof(T);
var underlyingType = Nullable.GetUnderlyingType(targetType);
if (underlyingType != null)
{
// It's a nullable type
return (T)Convert.ChangeType(value, underlyingType);
}
return (T)Convert.ChangeType(value, targetType);
}
catch
{
return defaultValue;
}
}
}

View File

@@ -0,0 +1,165 @@
using System.Net.Http.Json;
using System.Text.Json;
using System.Text.Json.Serialization;
using Managing.Application.Abstractions.Services;
using Microsoft.Extensions.Logging;
namespace Managing.Application.LLM.Providers;
/// <summary>
/// Anthropic Claude API provider
/// </summary>
public class ClaudeProvider : ILlmProvider
{
private readonly string _apiKey;
private readonly string _defaultModel;
private readonly HttpClient _httpClient;
private readonly ILogger _logger;
private const string BaseUrl = "https://api.anthropic.com/v1";
private const string FallbackModel = "claude-3-5-sonnet-20241022";
private const string AnthropicVersion = "2023-06-01";
public string Name => "claude";
public ClaudeProvider(string apiKey, string? defaultModel, IHttpClientFactory? httpClientFactory, ILogger logger)
{
_apiKey = apiKey;
_defaultModel = defaultModel ?? FallbackModel;
_httpClient = httpClientFactory?.CreateClient() ?? new HttpClient();
_httpClient.DefaultRequestHeaders.Add("x-api-key", _apiKey);
_httpClient.DefaultRequestHeaders.Add("anthropic-version", AnthropicVersion);
_logger = logger;
}
public async Task<LlmChatResponse> ChatAsync(LlmChatRequest request)
{
var url = $"{BaseUrl}/messages";
// Extract system message
var systemMessage = request.Messages.FirstOrDefault(m => m.Role == "system")?.Content ?? "";
var messages = request.Messages.Where(m => m.Role != "system").ToList();
var claudeRequest = new
{
model = _defaultModel,
max_tokens = request.MaxTokens,
temperature = request.Temperature,
system = !string.IsNullOrWhiteSpace(systemMessage) ? systemMessage : null,
messages = messages.Select(m => new
{
role = m.Role == "assistant" ? "assistant" : "user",
content = m.Content
}).ToArray(),
tools = request.Tools?.Any() == true ? request.Tools.Select(t => new
{
name = t.Name,
description = t.Description,
input_schema = new
{
type = "object",
properties = t.Parameters.ToDictionary(
p => p.Key,
p => new
{
type = p.Value.Type,
description = p.Value.Description
}
),
required = t.Parameters.Where(p => p.Value.Required).Select(p => p.Key).ToArray()
}
}).ToArray() : null
};
var jsonOptions = new JsonSerializerOptions
{
PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower,
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull
};
var response = await _httpClient.PostAsJsonAsync(url, claudeRequest, jsonOptions);
if (!response.IsSuccessStatusCode)
{
var errorContent = await response.Content.ReadAsStringAsync();
_logger.LogError("Claude API error: {StatusCode} - {Error}", response.StatusCode, errorContent);
throw new HttpRequestException($"Claude API error: {response.StatusCode} - {errorContent}");
}
var claudeResponse = await response.Content.ReadFromJsonAsync<ClaudeResponse>(jsonOptions);
return ConvertFromClaudeResponse(claudeResponse!);
}
private LlmChatResponse ConvertFromClaudeResponse(ClaudeResponse response)
{
var textContent = response.Content?.FirstOrDefault(c => c.Type == "text");
var toolUseContents = response.Content?.Where(c => c.Type == "tool_use").ToList();
var llmResponse = new LlmChatResponse
{
Content = textContent?.Text ?? "",
Provider = Name,
Model = response.Model ?? _defaultModel,
Usage = response.Usage != null ? new LlmUsage
{
PromptTokens = response.Usage.InputTokens,
CompletionTokens = response.Usage.OutputTokens,
TotalTokens = response.Usage.InputTokens + response.Usage.OutputTokens
} : null
};
if (toolUseContents?.Any() == true)
{
llmResponse.ToolCalls = toolUseContents.Select(tc => new LlmToolCall
{
Id = tc.Id ?? Guid.NewGuid().ToString(),
Name = tc.Name ?? "",
Arguments = tc.Input ?? new Dictionary<string, object>()
}).ToList();
llmResponse.RequiresToolExecution = true;
}
return llmResponse;
}
private class ClaudeResponse
{
[JsonPropertyName("id")]
public string? Id { get; set; }
[JsonPropertyName("model")]
public string? Model { get; set; }
[JsonPropertyName("content")]
public List<ClaudeContent>? Content { get; set; }
[JsonPropertyName("usage")]
public ClaudeUsage? Usage { get; set; }
}
private class ClaudeContent
{
[JsonPropertyName("type")]
public string Type { get; set; } = "";
[JsonPropertyName("text")]
public string? Text { get; set; }
[JsonPropertyName("id")]
public string? Id { get; set; }
[JsonPropertyName("name")]
public string? Name { get; set; }
[JsonPropertyName("input")]
public Dictionary<string, object>? Input { get; set; }
}
private class ClaudeUsage
{
[JsonPropertyName("input_tokens")]
public int InputTokens { get; set; }
[JsonPropertyName("output_tokens")]
public int OutputTokens { get; set; }
}
}

View File

@@ -0,0 +1,210 @@
using System.Net.Http.Json;
using System.Text.Json;
using System.Text.Json.Serialization;
using Managing.Application.Abstractions.Services;
using Microsoft.Extensions.Logging;
namespace Managing.Application.LLM.Providers;
/// <summary>
/// Google Gemini API provider
/// </summary>
public class GeminiProvider : ILlmProvider
{
private readonly string _apiKey;
private readonly string _defaultModel;
private readonly HttpClient _httpClient;
private readonly ILogger _logger;
private const string BaseUrl = "https://generativelanguage.googleapis.com/v1beta";
private const string FallbackModel = "gemini-2.0-flash-exp";
public string Name => "gemini";
public GeminiProvider(string apiKey, string? defaultModel, IHttpClientFactory? httpClientFactory, ILogger logger)
{
_apiKey = apiKey;
_defaultModel = defaultModel ?? FallbackModel;
_httpClient = httpClientFactory?.CreateClient() ?? new HttpClient();
_logger = logger;
}
public async Task<LlmChatResponse> ChatAsync(LlmChatRequest request)
{
var model = _defaultModel;
var url = $"{BaseUrl}/models/{model}:generateContent?key={_apiKey}";
var geminiRequest = ConvertToGeminiRequest(request);
var jsonOptions = new JsonSerializerOptions
{
PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull
};
var response = await _httpClient.PostAsJsonAsync(url, geminiRequest, jsonOptions);
if (!response.IsSuccessStatusCode)
{
var errorContent = await response.Content.ReadAsStringAsync();
_logger.LogError("Gemini API error: {StatusCode} - {Error}", response.StatusCode, errorContent);
throw new HttpRequestException($"Gemini API error: {response.StatusCode} - {errorContent}");
}
var geminiResponse = await response.Content.ReadFromJsonAsync<GeminiResponse>(jsonOptions);
return ConvertFromGeminiResponse(geminiResponse!);
}
private object ConvertToGeminiRequest(LlmChatRequest request)
{
var contents = request.Messages
.Where(m => m.Role != "system") // Gemini doesn't support system messages in the same way
.Select(m => new
{
role = m.Role == "assistant" ? "model" : "user",
parts = new[]
{
new { text = m.Content }
}
}).ToList();
// Add system message as first user message if present
var systemMessage = request.Messages.FirstOrDefault(m => m.Role == "system");
if (systemMessage != null && !string.IsNullOrWhiteSpace(systemMessage.Content))
{
contents.Insert(0, new
{
role = "user",
parts = new[]
{
new { text = $"System instructions: {systemMessage.Content}" }
}
});
}
var geminiRequest = new
{
contents,
generationConfig = new
{
temperature = request.Temperature,
maxOutputTokens = request.MaxTokens
},
tools = request.Tools?.Any() == true
? new[]
{
new
{
functionDeclarations = request.Tools.Select(t => new
{
name = t.Name,
description = t.Description,
parameters = new
{
type = "object",
properties = t.Parameters.ToDictionary(
p => p.Key,
p => new
{
type = p.Value.Type,
description = p.Value.Description
}
),
required = t.Parameters.Where(p => p.Value.Required).Select(p => p.Key).ToArray()
}
}).ToArray()
}
}
: null
};
return geminiRequest;
}
private LlmChatResponse ConvertFromGeminiResponse(GeminiResponse response)
{
var candidate = response.Candidates?.FirstOrDefault();
if (candidate == null)
{
return new LlmChatResponse
{
Content = "",
Provider = Name,
Model = _defaultModel
};
}
var content = candidate.Content;
var textPart = content?.Parts?.FirstOrDefault(p => !string.IsNullOrWhiteSpace(p.Text));
var functionCallParts = content?.Parts?.Where(p => p.FunctionCall != null).ToList();
var llmResponse = new LlmChatResponse
{
Content = textPart?.Text ?? "",
Provider = Name,
Model = _defaultModel,
Usage = response.UsageMetadata != null
? new LlmUsage
{
PromptTokens = response.UsageMetadata.PromptTokenCount,
CompletionTokens = response.UsageMetadata.CandidatesTokenCount,
TotalTokens = response.UsageMetadata.TotalTokenCount
}
: null
};
// Handle function calls (tool calls)
if (functionCallParts?.Any() == true)
{
llmResponse.ToolCalls = functionCallParts.Select((fc, idx) => new LlmToolCall
{
Id = $"call_{idx}",
Name = fc.FunctionCall!.Name,
Arguments = fc.FunctionCall.Args ?? new Dictionary<string, object>()
}).ToList();
llmResponse.RequiresToolExecution = true;
}
return llmResponse;
}
// Gemini API response models
private class GeminiResponse
{
[JsonPropertyName("candidates")] public List<GeminiCandidate>? Candidates { get; set; }
[JsonPropertyName("usageMetadata")] public GeminiUsageMetadata? UsageMetadata { get; set; }
}
private class GeminiCandidate
{
[JsonPropertyName("content")] public GeminiContent? Content { get; set; }
}
private class GeminiContent
{
[JsonPropertyName("parts")] public List<GeminiPart>? Parts { get; set; }
}
private class GeminiPart
{
[JsonPropertyName("text")] public string? Text { get; set; }
[JsonPropertyName("functionCall")] public GeminiFunctionCall? FunctionCall { get; set; }
}
private class GeminiFunctionCall
{
[JsonPropertyName("name")] public string Name { get; set; } = "";
[JsonPropertyName("args")] public Dictionary<string, object>? Args { get; set; }
}
private class GeminiUsageMetadata
{
[JsonPropertyName("promptTokenCount")] public int PromptTokenCount { get; set; }
[JsonPropertyName("candidatesTokenCount")]
public int CandidatesTokenCount { get; set; }
[JsonPropertyName("totalTokenCount")] public int TotalTokenCount { get; set; }
}
}

View File

@@ -0,0 +1,21 @@
using Managing.Application.Abstractions.Services;
namespace Managing.Application.LLM.Providers;
/// <summary>
/// Interface for LLM provider implementations
/// </summary>
public interface ILlmProvider
{
/// <summary>
/// Gets the name of the provider (e.g., "gemini", "openai", "claude")
/// </summary>
string Name { get; }
/// <summary>
/// Sends a chat request to the provider
/// </summary>
/// <param name="request">The chat request</param>
/// <returns>The chat response</returns>
Task<LlmChatResponse> ChatAsync(LlmChatRequest request);
}

View File

@@ -0,0 +1,199 @@
using System.Net.Http.Json;
using System.Text.Json;
using System.Text.Json.Serialization;
using Managing.Application.Abstractions.Services;
using Microsoft.Extensions.Logging;
namespace Managing.Application.LLM.Providers;
/// <summary>
/// OpenAI API provider
/// </summary>
public class OpenAiProvider : ILlmProvider
{
private readonly string _apiKey;
private readonly string _defaultModel;
private readonly HttpClient _httpClient;
private readonly ILogger _logger;
private const string BaseUrl = "https://api.openai.com/v1";
private const string FallbackModel = "gpt-4o";
public string Name => "openai";
public OpenAiProvider(string apiKey, string? defaultModel, IHttpClientFactory? httpClientFactory, ILogger logger)
{
_apiKey = apiKey;
_defaultModel = defaultModel ?? FallbackModel;
_httpClient = httpClientFactory?.CreateClient() ?? new HttpClient();
_httpClient.DefaultRequestHeaders.Add("Authorization", $"Bearer {_apiKey}");
_logger = logger;
}
public async Task<LlmChatResponse> ChatAsync(LlmChatRequest request)
{
var url = $"{BaseUrl}/chat/completions";
var openAiRequest = new
{
model = _defaultModel,
messages = request.Messages.Select(m => new
{
role = m.Role,
content = m.Content,
tool_calls = m.ToolCalls?.Select(tc => new
{
id = tc.Id,
type = "function",
function = new
{
name = tc.Name,
arguments = JsonSerializer.Serialize(tc.Arguments)
}
}),
tool_call_id = m.ToolCallId
}).ToArray(),
temperature = request.Temperature,
max_tokens = request.MaxTokens,
tools = request.Tools?.Any() == true ? request.Tools.Select(t => new
{
type = "function",
function = new
{
name = t.Name,
description = t.Description,
parameters = new
{
type = "object",
properties = t.Parameters.ToDictionary(
p => p.Key,
p => new
{
type = p.Value.Type,
description = p.Value.Description
}
),
required = t.Parameters.Where(p => p.Value.Required).Select(p => p.Key).ToArray()
}
}
}).ToArray() : null
};
var jsonOptions = new JsonSerializerOptions
{
PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower,
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull
};
var response = await _httpClient.PostAsJsonAsync(url, openAiRequest, jsonOptions);
if (!response.IsSuccessStatusCode)
{
var errorContent = await response.Content.ReadAsStringAsync();
_logger.LogError("OpenAI API error: {StatusCode} - {Error}", response.StatusCode, errorContent);
throw new HttpRequestException($"OpenAI API error: {response.StatusCode} - {errorContent}");
}
var openAiResponse = await response.Content.ReadFromJsonAsync<OpenAiResponse>(jsonOptions);
return ConvertFromOpenAiResponse(openAiResponse!);
}
private LlmChatResponse ConvertFromOpenAiResponse(OpenAiResponse response)
{
var choice = response.Choices?.FirstOrDefault();
if (choice == null)
{
return new LlmChatResponse
{
Content = "",
Provider = Name,
Model = response.Model ?? _defaultModel
};
}
var llmResponse = new LlmChatResponse
{
Content = choice.Message?.Content ?? "",
Provider = Name,
Model = response.Model ?? _defaultModel,
Usage = response.Usage != null ? new LlmUsage
{
PromptTokens = response.Usage.PromptTokens,
CompletionTokens = response.Usage.CompletionTokens,
TotalTokens = response.Usage.TotalTokens
} : null
};
if (choice.Message?.ToolCalls?.Any() == true)
{
llmResponse.ToolCalls = choice.Message.ToolCalls.Select(tc => new LlmToolCall
{
Id = tc.Id,
Name = tc.Function.Name,
Arguments = JsonSerializer.Deserialize<Dictionary<string, object>>(tc.Function.Arguments) ?? new()
}).ToList();
llmResponse.RequiresToolExecution = true;
}
return llmResponse;
}
private class OpenAiResponse
{
[JsonPropertyName("id")]
public string? Id { get; set; }
[JsonPropertyName("model")]
public string? Model { get; set; }
[JsonPropertyName("choices")]
public List<OpenAiChoice>? Choices { get; set; }
[JsonPropertyName("usage")]
public OpenAiUsage? Usage { get; set; }
}
private class OpenAiChoice
{
[JsonPropertyName("message")]
public OpenAiMessage? Message { get; set; }
}
private class OpenAiMessage
{
[JsonPropertyName("content")]
public string? Content { get; set; }
[JsonPropertyName("tool_calls")]
public List<OpenAiToolCall>? ToolCalls { get; set; }
}
private class OpenAiToolCall
{
[JsonPropertyName("id")]
public string Id { get; set; } = "";
[JsonPropertyName("function")]
public OpenAiFunction Function { get; set; } = new();
}
private class OpenAiFunction
{
[JsonPropertyName("name")]
public string Name { get; set; } = "";
[JsonPropertyName("arguments")]
public string Arguments { get; set; } = "{}";
}
private class OpenAiUsage
{
[JsonPropertyName("prompt_tokens")]
public int PromptTokens { get; set; }
[JsonPropertyName("completion_tokens")]
public int CompletionTokens { get; set; }
[JsonPropertyName("total_tokens")]
public int TotalTokens { get; set; }
}
}

View File

@@ -36,6 +36,7 @@
<ProjectReference Include="..\Managing.Common\Managing.Common.csproj"/>
<ProjectReference Include="..\Managing.Domain\Managing.Domain.csproj"/>
<ProjectReference Include="..\Managing.Infrastructure.Database\Managing.Infrastructure.Databases.csproj"/>
<ProjectReference Include="..\Managing.Mcp\Managing.Mcp.csproj"/>
</ItemGroup>
</Project>

View File

@@ -339,6 +339,22 @@ public class UserService : IUserService
return user;
}
public async Task<User> UpdateDefaultLlmProvider(User user, LlmProvider defaultLlmProvider)
{
user = await GetUserByName(user.Name);
if (user.DefaultLlmProvider == defaultLlmProvider)
return user;
// Update the default LLM provider on the provided user object
user.DefaultLlmProvider = defaultLlmProvider;
await _userRepository.SaveOrUpdateUserAsync(user);
_logger.LogInformation("Updated default LLM provider to {Provider} for user {UserId}",
defaultLlmProvider, user.Id);
return user;
}
public async Task<User> UpdateUserSettings(User user, UserSettingsDto settings)
{
user = await GetUserByName(user.Name);