diff --git a/src-tauri/src/commands/provider.rs b/src-tauri/src/commands/provider.rs index 4de9f54..30580e7 100644 --- a/src-tauri/src/commands/provider.rs +++ b/src-tauri/src/commands/provider.rs @@ -2,46 +2,16 @@ use std::collections::HashMap; -use serde::Deserialize; use tauri::State; use crate::app_config::AppType; -use crate::codex_config; -use crate::config::get_claude_settings_path; use crate::error::AppError; -use crate::provider::{Provider, ProviderMeta}; -use crate::services::{EndpointLatency, ProviderService, SpeedtestService}; +use crate::provider::Provider; +use crate::services::{EndpointLatency, ProviderService, ProviderSortUpdate, SpeedtestService}; use crate::store::AppState; -fn validate_provider_settings(app_type: &AppType, provider: &Provider) -> Result<(), String> { - match app_type { - AppType::Claude => { - if !provider.settings_config.is_object() { - return Err("Claude 配置必须是 JSON 对象".to_string()); - } - } - AppType::Codex => { - let settings = provider - .settings_config - .as_object() - .ok_or_else(|| "Codex 配置必须是 JSON 对象".to_string())?; - let auth = settings - .get("auth") - .ok_or_else(|| "Codex 配置缺少 auth 字段".to_string())?; - if !auth.is_object() { - return Err("Codex auth 配置必须是 JSON 对象".to_string()); - } - if let Some(config_value) = settings.get("config") { - if !(config_value.is_string() || config_value.is_null()) { - return Err("Codex config 字段必须是字符串".to_string()); - } - if let Some(cfg_text) = config_value.as_str() { - codex_config::validate_config_toml(cfg_text)?; - } - } - } - } - Ok(()) +fn missing_param(param: &str) -> String { + format!("缺少 {} 参数 (Missing {} parameter)", param, param) } /// 获取所有供应商 @@ -57,16 +27,7 @@ pub async fn get_providers( .or_else(|| appType.as_deref().map(|s| s.into())) .unwrap_or(AppType::Claude); - let config = state - .config - .read() - .map_err(|e| format!("获取锁失败: {}", e))?; - - let manager = config - .get_manager(&app_type) - .ok_or_else(|| format!("应用类型不存在: {:?}", app_type))?; - - Ok(manager.get_all_providers().clone()) + ProviderService::list(state.inner(), app_type).map_err(|e| e.to_string()) } /// 获取当前供应商ID @@ -82,16 +43,7 @@ pub async fn get_current_provider( .or_else(|| appType.as_deref().map(|s| s.into())) .unwrap_or(AppType::Claude); - let config = state - .config - .read() - .map_err(|e| format!("获取锁失败: {}", e))?; - - let manager = config - .get_manager(&app_type) - .ok_or_else(|| format!("应用类型不存在: {:?}", app_type))?; - - Ok(manager.current.clone()) + ProviderService::current(state.inner(), app_type).map_err(|e| e.to_string()) } /// 添加供应商 @@ -108,54 +60,7 @@ pub async fn add_provider( .or_else(|| appType.as_deref().map(|s| s.into())) .unwrap_or(AppType::Claude); - validate_provider_settings(&app_type, &provider)?; - - let is_current = { - let config = state - .config - .read() - .map_err(|e| format!("获取锁失败: {}", e))?; - let manager = config - .get_manager(&app_type) - .ok_or_else(|| format!("应用类型不存在: {:?}", app_type))?; - manager.current == provider.id - }; - - if is_current { - match app_type { - AppType::Claude => { - let settings_path = crate::config::get_claude_settings_path(); - crate::config::write_json_file(&settings_path, &provider.settings_config)?; - } - AppType::Codex => { - let auth = provider - .settings_config - .get("auth") - .ok_or_else(|| "目标供应商缺少 auth 配置".to_string())?; - let cfg_text = provider - .settings_config - .get("config") - .and_then(|v| v.as_str()); - crate::codex_config::write_codex_live_atomic(auth, cfg_text)?; - } - } - } - - { - let mut config = state - .config - .write() - .map_err(|e| format!("获取锁失败: {}", e))?; - let manager = config - .get_manager_mut(&app_type) - .ok_or_else(|| format!("应用类型不存在: {:?}", app_type))?; - manager - .providers - .insert(provider.id.clone(), provider.clone()); - } - state.save()?; - - Ok(true) + ProviderService::add(state.inner(), app_type, provider).map_err(|e| e.to_string()) } /// 更新供应商 @@ -172,88 +77,7 @@ pub async fn update_provider( .or_else(|| appType.as_deref().map(|s| s.into())) .unwrap_or(AppType::Claude); - validate_provider_settings(&app_type, &provider)?; - - let (exists, is_current) = { - let config = state - .config - .read() - .map_err(|e| format!("获取锁失败: {}", e))?; - let manager = config - .get_manager(&app_type) - .ok_or_else(|| format!("应用类型不存在: {:?}", app_type))?; - ( - manager.providers.contains_key(&provider.id), - manager.current == provider.id, - ) - }; - if !exists { - return Err(format!("供应商不存在: {}", provider.id)); - } - - if is_current { - match app_type { - AppType::Claude => { - let settings_path = crate::config::get_claude_settings_path(); - crate::config::write_json_file(&settings_path, &provider.settings_config)?; - } - AppType::Codex => { - let auth = provider - .settings_config - .get("auth") - .ok_or_else(|| "目标供应商缺少 auth 配置".to_string())?; - let cfg_text = provider - .settings_config - .get("config") - .and_then(|v| v.as_str()); - crate::codex_config::write_codex_live_atomic(auth, cfg_text)?; - } - } - } - - { - let mut config = state - .config - .write() - .map_err(|e| format!("获取锁失败: {}", e))?; - let manager = config - .get_manager_mut(&app_type) - .ok_or_else(|| format!("应用类型不存在: {:?}", app_type))?; - - let merged_provider = if let Some(existing) = manager.providers.get(&provider.id) { - let mut updated = provider.clone(); - - match (existing.meta.as_ref(), updated.meta.take()) { - (Some(old_meta), None) => { - updated.meta = Some(old_meta.clone()); - } - (Some(old_meta), Some(mut new_meta)) => { - let mut merged_map = old_meta.custom_endpoints.clone(); - for (url, ep) in new_meta.custom_endpoints.drain() { - merged_map.entry(url).or_insert(ep); - } - updated.meta = Some(ProviderMeta { - custom_endpoints: merged_map, - usage_script: new_meta.usage_script.clone(), - }); - } - (None, maybe_new) => { - updated.meta = maybe_new; - } - } - - updated - } else { - provider.clone() - }; - - manager - .providers - .insert(merged_provider.id.clone(), merged_provider); - } - state.save()?; - - Ok(true) + ProviderService::update(state.inner(), app_type, provider).map_err(|e| e.to_string()) } /// 删除供应商 @@ -285,12 +109,7 @@ pub async fn delete_provider( /// 切换供应商 fn switch_provider_internal(state: &AppState, app_type: AppType, id: &str) -> Result<(), AppError> { - let mut config = state.config.write().map_err(AppError::from)?; - - ProviderService::switch(&mut config, app_type, id)?; - - drop(config); - state.save() + ProviderService::switch(state, app_type, id) } #[doc(hidden)] @@ -321,54 +140,7 @@ pub async fn switch_provider( } fn import_default_config_internal(state: &AppState, app_type: AppType) -> Result<(), AppError> { - { - let config = state.config.read()?; - if let Some(manager) = config.get_manager(&app_type) { - if !manager.get_all_providers().is_empty() { - // 已存在供应商则视为已导入,保持与原逻辑一致 - return Ok(()); - } - } - } - - let settings_config = match app_type { - AppType::Codex => { - let auth_path = codex_config::get_codex_auth_path(); - if !auth_path.exists() { - return Err(AppError::Message("Codex 配置文件不存在".to_string())); - } - let auth: serde_json::Value = crate::config::read_json_file(&auth_path)?; - let config_str = crate::codex_config::read_and_validate_codex_config_text()?; - serde_json::json!({ "auth": auth, "config": config_str }) - } - AppType::Claude => { - let settings_path = get_claude_settings_path(); - if !settings_path.exists() { - return Err(AppError::Message("Claude Code 配置文件不存在".to_string())); - } - crate::config::read_json_file(&settings_path)? - } - }; - - let provider = Provider::with_id( - "default".to_string(), - "default".to_string(), - settings_config, - None, - ); - - let mut config = state.config.write()?; - let manager = config - .get_manager_mut(&app_type) - .ok_or_else(|| AppError::Message(format!("应用类型不存在: {:?}", app_type)))?; - - manager.providers.insert(provider.id.clone(), provider); - manager.current = "default".to_string(); - - drop(config); - state.save()?; - - Ok(()) + ProviderService::import_default_config(state, app_type) } #[doc(hidden)] @@ -407,130 +179,18 @@ pub async fn query_provider_usage( app: Option, appType: Option, ) -> Result { - use crate::provider::{UsageData, UsageResult}; - - let provider_id = provider_id.or(providerId).ok_or("缺少 providerId 参数")?; + let provider_id = provider_id + .or(providerId) + .ok_or_else(|| missing_param("providerId"))?; let app_type = app_type .or_else(|| app.as_deref().map(|s| s.into())) .or_else(|| appType.as_deref().map(|s| s.into())) .unwrap_or(AppType::Claude); - let (api_key, base_url, usage_script_code, timeout) = { - let config = state - .config - .read() - .map_err(|e| format!("获取锁失败: {}", e))?; - - let manager = config.get_manager(&app_type).ok_or("应用类型不存在")?; - - let provider = manager.providers.get(&provider_id).ok_or("供应商不存在")?; - - let usage_script = provider - .meta - .as_ref() - .and_then(|m| m.usage_script.as_ref()) - .ok_or("未配置用量查询脚本")?; - - if !usage_script.enabled { - return Err("用量查询未启用".to_string()); - } - - let (api_key, base_url) = extract_credentials(provider, &app_type)?; - let timeout = usage_script.timeout.unwrap_or(10); - let code = usage_script.code.clone(); - - drop(config); - - (api_key, base_url, code, timeout) - }; - - let result = - crate::usage_script::execute_usage_script(&usage_script_code, &api_key, &base_url, timeout) - .await; - - match result { - Ok(data) => { - let usage_list: Vec = if data.is_array() { - serde_json::from_value(data).map_err(|e| format!("数据格式错误: {}", e))? - } else { - let single: UsageData = - serde_json::from_value(data).map_err(|e| format!("数据格式错误: {}", e))?; - vec![single] - }; - - Ok(UsageResult { - success: true, - data: Some(usage_list), - error: None, - }) - } - Err(e) => Ok(UsageResult { - success: false, - data: None, - error: Some(e.to_string()), - }), - } -} - -fn extract_credentials( - provider: &crate::provider::Provider, - app_type: &AppType, -) -> Result<(String, String), String> { - match app_type { - AppType::Claude => { - let env = provider - .settings_config - .get("env") - .and_then(|v| v.as_object()) - .ok_or("配置格式错误: 缺少 env")?; - - let api_key = env - .get("ANTHROPIC_AUTH_TOKEN") - .and_then(|v| v.as_str()) - .ok_or("缺少 API Key")? - .to_string(); - - let base_url = env - .get("ANTHROPIC_BASE_URL") - .and_then(|v| v.as_str()) - .ok_or("缺少 ANTHROPIC_BASE_URL 配置")? - .to_string(); - - Ok((api_key, base_url)) - } - AppType::Codex => { - let auth = provider - .settings_config - .get("auth") - .and_then(|v| v.as_object()) - .ok_or("配置格式错误: 缺少 auth")?; - - let api_key = auth - .get("OPENAI_API_KEY") - .and_then(|v| v.as_str()) - .ok_or("缺少 API Key")? - .to_string(); - - let config_toml = provider - .settings_config - .get("config") - .and_then(|v| v.as_str()) - .unwrap_or(""); - - let base_url = if config_toml.contains("base_url") { - let re = regex::Regex::new(r#"base_url\s*=\s*["']([^"']+)["']"#).unwrap(); - re.captures(config_toml) - .and_then(|caps| caps.get(1)) - .map(|m| m.as_str().to_string()) - .ok_or("config.toml 中 base_url 格式错误")? - } else { - return Err("config.toml 中缺少 base_url 配置".to_string()); - }; - - Ok((api_key, base_url)) - } - } + ProviderService::query_usage(state.inner(), app_type, &provider_id) + .await + .map_err(|e| e.to_string()) } /// 读取当前生效的配置内容 @@ -545,25 +205,7 @@ pub async fn read_live_provider_settings( .or_else(|| appType.as_deref().map(|s| s.into())) .unwrap_or(AppType::Claude); - match app_type { - AppType::Codex => { - let auth_path = crate::codex_config::get_codex_auth_path(); - if !auth_path.exists() { - return Err("Codex 配置文件不存在:缺少 auth.json".to_string()); - } - let auth: serde_json::Value = crate::config::read_json_file(&auth_path)?; - let cfg_text = crate::codex_config::read_and_validate_codex_config_text()?; - Ok(serde_json::json!({ "auth": auth, "config": cfg_text })) - } - AppType::Claude => { - let path = crate::config::get_claude_settings_path(); - if !path.exists() { - return Err("Claude Code 配置文件不存在".to_string()); - } - let v: serde_json::Value = crate::config::read_json_file(&path)?; - Ok(v) - } - } + ProviderService::read_live_settings(app_type).map_err(|e| e.to_string()) } /// 测试第三方/自定义供应商端点的网络延迟 @@ -593,28 +235,9 @@ pub async fn get_custom_endpoints( .unwrap_or(AppType::Claude); let provider_id = provider_id .or(providerId) - .ok_or_else(|| "缺少 providerId".to_string())?; - let mut cfg_guard = state - .config - .write() - .map_err(|e| format!("获取锁失败: {}", e))?; - - let manager = cfg_guard - .get_manager_mut(&app_type) - .ok_or_else(|| format!("应用类型不存在: {:?}", app_type))?; - - let Some(provider) = manager.providers.get_mut(&provider_id) else { - return Ok(vec![]); - }; - - let meta = provider.meta.get_or_insert_with(ProviderMeta::default); - if !meta.custom_endpoints.is_empty() { - let mut result: Vec<_> = meta.custom_endpoints.values().cloned().collect(); - result.sort_by(|a, b| b.added_at.cmp(&a.added_at)); - return Ok(result); - } - - Ok(vec![]) + .ok_or_else(|| missing_param("providerId"))?; + ProviderService::get_custom_endpoints(state.inner(), app_type, &provider_id) + .map_err(|e| e.to_string()) } /// 添加自定义端点 @@ -634,39 +257,9 @@ pub async fn add_custom_endpoint( .unwrap_or(AppType::Claude); let provider_id = provider_id .or(providerId) - .ok_or_else(|| "缺少 providerId".to_string())?; - let normalized = url.trim().trim_end_matches('/').to_string(); - if normalized.is_empty() { - return Err("URL 不能为空".to_string()); - } - - let mut cfg_guard = state - .config - .write() - .map_err(|e| format!("获取锁失败: {}", e))?; - let manager = cfg_guard - .get_manager_mut(&app_type) - .ok_or_else(|| format!("应用类型不存在: {:?}", app_type))?; - - let Some(provider) = manager.providers.get_mut(&provider_id) else { - return Err("供应商不存在或未选择".to_string()); - }; - let meta = provider.meta.get_or_insert_with(ProviderMeta::default); - - let timestamp = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - .as_millis() as i64; - - let endpoint = crate::settings::CustomEndpoint { - url: normalized.clone(), - added_at: timestamp, - last_used: None, - }; - meta.custom_endpoints.insert(normalized, endpoint); - drop(cfg_guard); - state.save()?; - Ok(()) + .ok_or_else(|| missing_param("providerId"))?; + ProviderService::add_custom_endpoint(state.inner(), app_type, &provider_id, url) + .map_err(|e| e.to_string()) } /// 删除自定义端点 @@ -686,25 +279,9 @@ pub async fn remove_custom_endpoint( .unwrap_or(AppType::Claude); let provider_id = provider_id .or(providerId) - .ok_or_else(|| "缺少 providerId".to_string())?; - let normalized = url.trim().trim_end_matches('/').to_string(); - - let mut cfg_guard = state - .config - .write() - .map_err(|e| format!("获取锁失败: {}", e))?; - let manager = cfg_guard - .get_manager_mut(&app_type) - .ok_or_else(|| format!("应用类型不存在: {:?}", app_type))?; - - if let Some(provider) = manager.providers.get_mut(&provider_id) { - if let Some(meta) = provider.meta.as_mut() { - meta.custom_endpoints.remove(&normalized); - } - } - drop(cfg_guard); - state.save()?; - Ok(()) + .ok_or_else(|| missing_param("providerId"))?; + ProviderService::remove_custom_endpoint(state.inner(), app_type, &provider_id, url) + .map_err(|e| e.to_string()) } /// 更新端点最后使用时间 @@ -724,38 +301,9 @@ pub async fn update_endpoint_last_used( .unwrap_or(AppType::Claude); let provider_id = provider_id .or(providerId) - .ok_or_else(|| "缺少 providerId".to_string())?; - let normalized = url.trim().trim_end_matches('/').to_string(); - - let mut cfg_guard = state - .config - .write() - .map_err(|e| format!("获取锁失败: {}", e))?; - let manager = cfg_guard - .get_manager_mut(&app_type) - .ok_or_else(|| format!("应用类型不存在: {:?}", app_type))?; - - if let Some(provider) = manager.providers.get_mut(&provider_id) { - if let Some(meta) = provider.meta.as_mut() { - if let Some(endpoint) = meta.custom_endpoints.get_mut(&normalized) { - let timestamp = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - .as_millis() as i64; - endpoint.last_used = Some(timestamp); - } - } - } - drop(cfg_guard); - state.save()?; - Ok(()) -} - -#[derive(Deserialize)] -pub struct ProviderSortUpdate { - pub id: String, - #[serde(rename = "sortIndex")] - pub sort_index: usize, + .ok_or_else(|| missing_param("providerId"))?; + ProviderService::update_endpoint_last_used(state.inner(), app_type, &provider_id, url) + .map_err(|e| e.to_string()) } /// 更新多个供应商的排序 @@ -772,23 +320,5 @@ pub async fn update_providers_sort_order( .or_else(|| appType.as_deref().map(|s| s.into())) .unwrap_or(AppType::Claude); - let mut config = state - .config - .write() - .map_err(|e| format!("获取锁失败: {}", e))?; - - let manager = config - .get_manager_mut(&app_type) - .ok_or_else(|| format!("应用类型不存在: {:?}", app_type))?; - - for update in updates { - if let Some(provider) = manager.providers.get_mut(&update.id) { - provider.sort_index = Some(update.sort_index); - } - } - - drop(config); - state.save()?; - - Ok(true) + ProviderService::update_sort_order(state.inner(), app_type, updates).map_err(|e| e.to_string()) } diff --git a/src-tauri/src/error.rs b/src-tauri/src/error.rs index 8fc4f7c..62d68e6 100644 --- a/src-tauri/src/error.rs +++ b/src-tauri/src/error.rs @@ -46,6 +46,12 @@ pub enum AppError { McpValidation(String), #[error("{0}")] Message(String), + #[error("{zh} ({en})")] + Localized { + key: &'static str, + zh: String, + en: String, + }, } impl AppError { @@ -69,6 +75,14 @@ impl AppError { source, } } + + pub fn localized(key: &'static str, zh: impl Into, en: impl Into) -> Self { + Self::Localized { + key, + zh: zh.into(), + en: en.into(), + } + } } impl From> for AppError { diff --git a/src-tauri/src/services/mod.rs b/src-tauri/src/services/mod.rs index 1e64211..9efe214 100644 --- a/src-tauri/src/services/mod.rs +++ b/src-tauri/src/services/mod.rs @@ -5,5 +5,5 @@ pub mod speedtest; pub use config::ConfigService; pub use mcp::McpService; -pub use provider::ProviderService; +pub use provider::{ProviderService, ProviderSortUpdate}; pub use speedtest::{EndpointLatency, SpeedtestService}; diff --git a/src-tauri/src/services/provider.rs b/src-tauri/src/services/provider.rs index bd8830b..f8987bd 100644 --- a/src-tauri/src/services/provider.rs +++ b/src-tauri/src/services/provider.rs @@ -1,60 +1,753 @@ +use regex::Regex; +use serde::Deserialize; use serde_json::{json, Value}; +use std::collections::HashMap; +use std::time::{SystemTime, UNIX_EPOCH}; use crate::app_config::{AppType, MultiAppConfig}; use crate::codex_config::{get_codex_auth_path, get_codex_config_path, write_codex_live_atomic}; use crate::config::{ delete_file, get_claude_settings_path, get_provider_config_path, read_json_file, - write_json_file, + write_json_file, write_text_file, }; use crate::error::AppError; use crate::mcp; +use crate::provider::{Provider, ProviderMeta, UsageData, UsageResult}; +use crate::settings::CustomEndpoint; +use crate::store::AppState; +use crate::usage_script; /// 供应商相关业务逻辑 pub struct ProviderService; +#[derive(Clone)] +enum LiveSnapshot { + Claude { + settings: Option, + }, + Codex { + auth: Option, + config: Option, + }, +} + +#[derive(Clone)] +struct PostCommitAction { + app_type: AppType, + provider: Provider, + backup: LiveSnapshot, + sync_mcp: bool, + refresh_snapshot: bool, +} + +impl LiveSnapshot { + fn restore(&self) -> Result<(), AppError> { + match self { + LiveSnapshot::Claude { settings } => { + let path = get_claude_settings_path(); + if let Some(value) = settings { + write_json_file(&path, value)?; + } else if path.exists() { + delete_file(&path)?; + } + } + LiveSnapshot::Codex { auth, config } => { + let auth_path = get_codex_auth_path(); + let config_path = get_codex_config_path(); + if let Some(value) = auth { + write_json_file(&auth_path, value)?; + } else if auth_path.exists() { + delete_file(&auth_path)?; + } + + if let Some(text) = config { + write_text_file(&config_path, text)?; + } else if config_path.exists() { + delete_file(&config_path)?; + } + } + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn validate_provider_settings_rejects_missing_auth() { + let provider = Provider::with_id( + "codex".into(), + "Codex".into(), + json!({ "config": "base_url = \"https://example.com\"" }), + None, + ); + let err = ProviderService::validate_provider_settings(&AppType::Codex, &provider) + .expect_err("missing auth should be rejected"); + assert!( + err.to_string().contains("auth"), + "expected auth error, got {err:?}" + ); + } + + #[test] + fn extract_credentials_returns_expected_values() { + let provider = Provider::with_id( + "claude".into(), + "Claude".into(), + json!({ + "env": { + "ANTHROPIC_AUTH_TOKEN": "token", + "ANTHROPIC_BASE_URL": "https://claude.example" + } + }), + None, + ); + let (api_key, base_url) = + ProviderService::extract_credentials(&provider, &AppType::Claude).unwrap(); + assert_eq!(api_key, "token"); + assert_eq!(base_url, "https://claude.example"); + } +} + impl ProviderService { - /// 切换指定应用的供应商 - pub fn switch( - config: &mut MultiAppConfig, - app_type: AppType, + fn run_transaction(state: &AppState, f: F) -> Result + where + F: FnOnce(&mut MultiAppConfig) -> Result<(R, Option), AppError>, + { + let mut guard = state.config.write().map_err(AppError::from)?; + let original = guard.clone(); + let (result, action) = match f(&mut guard) { + Ok(value) => value, + Err(err) => { + *guard = original; + return Err(err); + } + }; + drop(guard); + + if let Err(save_err) = state.save() { + if let Err(rollback_err) = Self::restore_config_only(state, original.clone()) { + return Err(AppError::localized( + "config.save.rollback_failed", + format!("保存配置失败: {};回滚失败: {}", save_err, rollback_err), + format!( + "Failed to save config: {}; rollback failed: {}", + save_err, rollback_err + ), + )); + } + return Err(save_err); + } + + if let Some(action) = action { + if let Err(err) = Self::apply_post_commit(state, &action) { + if let Err(rollback_err) = + Self::rollback_after_failure(state, original.clone(), action.backup.clone()) + { + return Err(AppError::localized( + "post_commit.rollback_failed", + format!("后置操作失败: {};回滚失败: {}", err, rollback_err), + format!( + "Post-commit step failed: {}; rollback failed: {}", + err, rollback_err + ), + )); + } + return Err(err); + } + } + + Ok(result) + } + + fn restore_config_only(state: &AppState, snapshot: MultiAppConfig) -> Result<(), AppError> { + { + let mut guard = state.config.write().map_err(AppError::from)?; + *guard = snapshot; + } + state.save() + } + + fn rollback_after_failure( + state: &AppState, + snapshot: MultiAppConfig, + backup: LiveSnapshot, + ) -> Result<(), AppError> { + Self::restore_config_only(state, snapshot)?; + backup.restore() + } + + fn apply_post_commit(state: &AppState, action: &PostCommitAction) -> Result<(), AppError> { + Self::write_live_snapshot(&action.app_type, &action.provider)?; + if action.sync_mcp { + let config_clone = { + let guard = state.config.read().map_err(AppError::from)?; + guard.clone() + }; + mcp::sync_enabled_to_codex(&config_clone)?; + } + if action.refresh_snapshot { + Self::refresh_provider_snapshot(state, &action.app_type, &action.provider.id)?; + } + Ok(()) + } + + fn refresh_provider_snapshot( + state: &AppState, + app_type: &AppType, provider_id: &str, ) -> Result<(), AppError> { match app_type { - AppType::Codex => Self::switch_codex(config, provider_id), - AppType::Claude => Self::switch_claude(config, provider_id), + AppType::Claude => { + let settings_path = get_claude_settings_path(); + if !settings_path.exists() { + return Err(AppError::localized( + "claude.live.missing", + "Claude 设置文件不存在,无法刷新快照", + "Claude settings file missing; cannot refresh snapshot", + )); + } + let live_after = read_json_file::(&settings_path)?; + { + let mut guard = state.config.write().map_err(AppError::from)?; + if let Some(manager) = guard.get_manager_mut(app_type) { + if let Some(target) = manager.providers.get_mut(provider_id) { + target.settings_config = live_after; + } + } + } + state.save()?; + } + AppType::Codex => { + let auth_path = get_codex_auth_path(); + if !auth_path.exists() { + return Err(AppError::localized( + "codex.live.missing", + "Codex auth.json 不存在,无法刷新快照", + "Codex auth.json missing; cannot refresh snapshot", + )); + } + let auth: Value = read_json_file(&auth_path)?; + let cfg_text = crate::codex_config::read_and_validate_codex_config_text()?; + + { + let mut guard = state.config.write().map_err(AppError::from)?; + if let Some(manager) = guard.get_manager_mut(app_type) { + if let Some(target) = manager.providers.get_mut(provider_id) { + let obj = target.settings_config.as_object_mut().ok_or_else(|| { + AppError::Config(format!( + "供应商 {} 的 Codex 配置必须是 JSON 对象", + provider_id + )) + })?; + obj.insert("auth".to_string(), auth.clone()); + obj.insert("config".to_string(), Value::String(cfg_text.clone())); + } + } + } + state.save()?; + } + } + Ok(()) + } + + fn capture_live_snapshot(app_type: &AppType) -> Result { + match app_type { + AppType::Claude => { + let path = get_claude_settings_path(); + let settings = if path.exists() { + Some(read_json_file::(&path)?) + } else { + None + }; + Ok(LiveSnapshot::Claude { settings }) + } + AppType::Codex => { + let auth_path = get_codex_auth_path(); + let config_path = get_codex_config_path(); + let auth = if auth_path.exists() { + Some(read_json_file::(&auth_path)?) + } else { + None + }; + let config = if config_path.exists() { + Some( + std::fs::read_to_string(&config_path) + .map_err(|e| AppError::io(&config_path, e))?, + ) + } else { + None + }; + Ok(LiveSnapshot::Codex { auth, config }) + } } } - fn switch_codex(config: &mut MultiAppConfig, provider_id: &str) -> Result<(), AppError> { + /// 列出指定应用下的所有供应商 + pub fn list( + state: &AppState, + app_type: AppType, + ) -> Result, AppError> { + let config = state.config.read().map_err(AppError::from)?; + let manager = config + .get_manager(&app_type) + .ok_or_else(|| Self::app_not_found(&app_type))?; + Ok(manager.get_all_providers().clone()) + } + + /// 获取当前供应商 ID + pub fn current(state: &AppState, app_type: AppType) -> Result { + let config = state.config.read().map_err(AppError::from)?; + let manager = config + .get_manager(&app_type) + .ok_or_else(|| Self::app_not_found(&app_type))?; + Ok(manager.current.clone()) + } + + /// 新增供应商 + pub fn add(state: &AppState, app_type: AppType, provider: Provider) -> Result { + Self::validate_provider_settings(&app_type, &provider)?; + + let app_type_clone = app_type.clone(); + let provider_clone = provider.clone(); + + Self::run_transaction(state, move |config| { + config.ensure_app(&app_type_clone); + let manager = config + .get_manager_mut(&app_type_clone) + .ok_or_else(|| Self::app_not_found(&app_type_clone))?; + + let is_current = manager.current == provider_clone.id; + manager + .providers + .insert(provider_clone.id.clone(), provider_clone.clone()); + + let action = if is_current { + let backup = Self::capture_live_snapshot(&app_type_clone)?; + Some(PostCommitAction { + app_type: app_type_clone.clone(), + provider: provider_clone.clone(), + backup, + sync_mcp: false, + refresh_snapshot: false, + }) + } else { + None + }; + + Ok((true, action)) + }) + } + + /// 更新供应商 + pub fn update( + state: &AppState, + app_type: AppType, + provider: Provider, + ) -> Result { + Self::validate_provider_settings(&app_type, &provider)?; + let provider_id = provider.id.clone(); + let app_type_clone = app_type.clone(); + let provider_clone = provider.clone(); + + Self::run_transaction(state, move |config| { + let manager = config + .get_manager_mut(&app_type_clone) + .ok_or_else(|| Self::app_not_found(&app_type_clone))?; + + if !manager.providers.contains_key(&provider_id) { + return Err(AppError::ProviderNotFound(provider_id.clone())); + } + + let is_current = manager.current == provider_id; + let merged = if let Some(existing) = manager.providers.get(&provider_id) { + let mut updated = provider_clone.clone(); + match (existing.meta.as_ref(), updated.meta.take()) { + (Some(old_meta), None) => { + updated.meta = Some(old_meta.clone()); + } + (Some(old_meta), Some(mut new_meta)) => { + let mut merged_map = old_meta.custom_endpoints.clone(); + for (url, ep) in new_meta.custom_endpoints.drain() { + merged_map.entry(url).or_insert(ep); + } + updated.meta = Some(ProviderMeta { + custom_endpoints: merged_map, + usage_script: new_meta.usage_script.clone(), + }); + } + (None, maybe_new) => { + updated.meta = maybe_new; + } + } + updated + } else { + provider_clone.clone() + }; + + manager.providers.insert(provider_id.clone(), merged); + + let action = if is_current { + let backup = Self::capture_live_snapshot(&app_type_clone)?; + Some(PostCommitAction { + app_type: app_type_clone.clone(), + provider: provider_clone.clone(), + backup, + sync_mcp: false, + refresh_snapshot: false, + }) + } else { + None + }; + + Ok((true, action)) + }) + } + + /// 导入当前 live 配置为默认供应商 + pub fn import_default_config(state: &AppState, app_type: AppType) -> Result<(), AppError> { + { + let config = state.config.read().map_err(AppError::from)?; + if let Some(manager) = config.get_manager(&app_type) { + if !manager.get_all_providers().is_empty() { + return Ok(()); + } + } + } + + let settings_config = match app_type { + AppType::Codex => { + let auth_path = get_codex_auth_path(); + if !auth_path.exists() { + return Err(AppError::localized( + "codex.live.missing", + "Codex 配置文件不存在", + "Codex configuration file is missing", + )); + } + let auth: Value = read_json_file(&auth_path)?; + let config_str = crate::codex_config::read_and_validate_codex_config_text()?; + json!({ "auth": auth, "config": config_str }) + } + AppType::Claude => { + let settings_path = get_claude_settings_path(); + if !settings_path.exists() { + return Err(AppError::localized( + "claude.live.missing", + "Claude Code 配置文件不存在", + "Claude settings file is missing", + )); + } + read_json_file(&settings_path)? + } + }; + + let provider = Provider::with_id( + "default".to_string(), + "default".to_string(), + settings_config, + None, + ); + + { + let mut config = state.config.write().map_err(AppError::from)?; + let manager = config + .get_manager_mut(&app_type) + .ok_or_else(|| Self::app_not_found(&app_type))?; + manager + .providers + .insert(provider.id.clone(), provider.clone()); + manager.current = provider.id.clone(); + } + + state.save()?; + Ok(()) + } + + /// 读取当前 live 配置 + pub fn read_live_settings(app_type: AppType) -> Result { + match app_type { + AppType::Codex => { + let auth_path = get_codex_auth_path(); + if !auth_path.exists() { + return Err(AppError::localized( + "codex.auth.missing", + "Codex 配置文件不存在:缺少 auth.json", + "Codex configuration missing: auth.json not found", + )); + } + let auth: Value = read_json_file(&auth_path)?; + let cfg_text = crate::codex_config::read_and_validate_codex_config_text()?; + Ok(json!({ "auth": auth, "config": cfg_text })) + } + AppType::Claude => { + let path = get_claude_settings_path(); + if !path.exists() { + return Err(AppError::localized( + "claude.live.missing", + "Claude Code 配置文件不存在", + "Claude settings file is missing", + )); + } + read_json_file(&path) + } + } + } + + /// 获取自定义端点列表 + pub fn get_custom_endpoints( + state: &AppState, + app_type: AppType, + provider_id: &str, + ) -> Result, AppError> { + let cfg = state.config.read().map_err(AppError::from)?; + let manager = cfg + .get_manager(&app_type) + .ok_or_else(|| Self::app_not_found(&app_type))?; + + let Some(provider) = manager.providers.get(provider_id) else { + return Ok(vec![]); + }; + let Some(meta) = provider.meta.as_ref() else { + return Ok(vec![]); + }; + if meta.custom_endpoints.is_empty() { + return Ok(vec![]); + } + + let mut result: Vec<_> = meta.custom_endpoints.values().cloned().collect(); + result.sort_by(|a, b| b.added_at.cmp(&a.added_at)); + Ok(result) + } + + /// 新增自定义端点 + pub fn add_custom_endpoint( + state: &AppState, + app_type: AppType, + provider_id: &str, + url: String, + ) -> Result<(), AppError> { + let normalized = url.trim().trim_end_matches('/').to_string(); + if normalized.is_empty() { + return Err(AppError::localized( + "provider.endpoint.url_required", + "URL 不能为空", + "URL cannot be empty", + )); + } + + { + let mut cfg = state.config.write().map_err(AppError::from)?; + let manager = cfg + .get_manager_mut(&app_type) + .ok_or_else(|| Self::app_not_found(&app_type))?; + let provider = manager + .providers + .get_mut(provider_id) + .ok_or_else(|| AppError::ProviderNotFound(provider_id.to_string()))?; + let meta = provider.meta.get_or_insert_with(ProviderMeta::default); + + let endpoint = CustomEndpoint { + url: normalized.clone(), + added_at: Self::now_millis(), + last_used: None, + }; + meta.custom_endpoints.insert(normalized, endpoint); + } + + state.save()?; + Ok(()) + } + + /// 删除自定义端点 + pub fn remove_custom_endpoint( + state: &AppState, + app_type: AppType, + provider_id: &str, + url: String, + ) -> Result<(), AppError> { + let normalized = url.trim().trim_end_matches('/').to_string(); + + { + let mut cfg = state.config.write().map_err(AppError::from)?; + if let Some(manager) = cfg.get_manager_mut(&app_type) { + if let Some(provider) = manager.providers.get_mut(provider_id) { + if let Some(meta) = provider.meta.as_mut() { + meta.custom_endpoints.remove(&normalized); + } + } + } + } + + state.save()?; + Ok(()) + } + + /// 更新端点最后使用时间 + pub fn update_endpoint_last_used( + state: &AppState, + app_type: AppType, + provider_id: &str, + url: String, + ) -> Result<(), AppError> { + let normalized = url.trim().trim_end_matches('/').to_string(); + + { + let mut cfg = state.config.write().map_err(AppError::from)?; + if let Some(manager) = cfg.get_manager_mut(&app_type) { + if let Some(provider) = manager.providers.get_mut(provider_id) { + if let Some(meta) = provider.meta.as_mut() { + if let Some(endpoint) = meta.custom_endpoints.get_mut(&normalized) { + endpoint.last_used = Some(Self::now_millis()); + } + } + } + } + } + + state.save()?; + Ok(()) + } + + /// 更新供应商排序 + pub fn update_sort_order( + state: &AppState, + app_type: AppType, + updates: Vec, + ) -> Result { + { + let mut cfg = state.config.write().map_err(AppError::from)?; + let manager = cfg + .get_manager_mut(&app_type) + .ok_or_else(|| Self::app_not_found(&app_type))?; + + for update in updates { + if let Some(provider) = manager.providers.get_mut(&update.id) { + provider.sort_index = Some(update.sort_index); + } + } + } + + state.save()?; + Ok(true) + } + + /// 查询供应商用量 + pub async fn query_usage( + state: &AppState, + app_type: AppType, + provider_id: &str, + ) -> Result { + let (provider, script_code, timeout) = { + let config = state.config.read().map_err(AppError::from)?; + let manager = config + .get_manager(&app_type) + .ok_or_else(|| Self::app_not_found(&app_type))?; + let provider = manager + .providers + .get(provider_id) + .cloned() + .ok_or_else(|| AppError::ProviderNotFound(provider_id.to_string()))?; + let (script_code, timeout) = { + let usage_script = provider + .meta + .as_ref() + .and_then(|m| m.usage_script.as_ref()) + .ok_or_else(|| { + AppError::localized( + "provider.usage.script.missing", + "未配置用量查询脚本", + "Usage script is not configured", + ) + })?; + if !usage_script.enabled { + return Err(AppError::localized( + "provider.usage.disabled", + "用量查询未启用", + "Usage query is disabled", + )); + } + ( + usage_script.code.clone(), + usage_script.timeout.unwrap_or(10), + ) + }; + + (provider, script_code, timeout) + }; + + let (api_key, base_url) = Self::extract_credentials(&provider, &app_type)?; + + match usage_script::execute_usage_script(&script_code, &api_key, &base_url, timeout).await { + Ok(data) => { + let usage_list: Vec = if data.is_array() { + serde_json::from_value(data) + .map_err(|e| AppError::Message(format!("数据格式错误: {}", e)))? + } else { + let single: UsageData = serde_json::from_value(data) + .map_err(|e| AppError::Message(format!("数据格式错误: {}", e)))?; + vec![single] + }; + + Ok(UsageResult { + success: true, + data: Some(usage_list), + error: None, + }) + } + Err(err) => Ok(UsageResult { + success: false, + data: None, + error: Some(err.to_string()), + }), + } + } + + /// 切换指定应用的供应商 + pub fn switch(state: &AppState, app_type: AppType, provider_id: &str) -> Result<(), AppError> { + let app_type_clone = app_type.clone(); + let provider_id_owned = provider_id.to_string(); + + Self::run_transaction(state, move |config| { + let backup = Self::capture_live_snapshot(&app_type_clone)?; + let provider = match app_type_clone { + AppType::Codex => Self::prepare_switch_codex(config, &provider_id_owned)?, + AppType::Claude => Self::prepare_switch_claude(config, &provider_id_owned)?, + }; + + let action = PostCommitAction { + app_type: app_type_clone.clone(), + provider, + backup, + sync_mcp: matches!(app_type_clone, AppType::Codex), + refresh_snapshot: true, + }; + + Ok(((), Some(action))) + }) + } + + fn prepare_switch_codex( + config: &mut MultiAppConfig, + provider_id: &str, + ) -> Result { let provider = config .get_manager(&AppType::Codex) - .ok_or_else(|| AppError::Message("应用类型不存在: Codex".into()))? + .ok_or_else(|| Self::app_not_found(&AppType::Codex))? .providers .get(provider_id) .cloned() .ok_or_else(|| AppError::ProviderNotFound(provider_id.to_string()))?; Self::backfill_codex_current(config, provider_id)?; - Self::write_codex_live(&provider)?; if let Some(manager) = config.get_manager_mut(&AppType::Codex) { manager.current = provider_id.to_string(); } - // 同步启用的 MCP 服务器 - mcp::sync_enabled_to_codex(config)?; - - // 更新持久化快照 - let cfg_text_after = crate::codex_config::read_and_validate_codex_config_text()?; - if let Some(manager) = config.get_manager_mut(&AppType::Codex) { - if let Some(target) = manager.providers.get_mut(provider_id) { - if let Some(obj) = target.settings_config.as_object_mut() { - obj.insert("config".to_string(), Value::String(cfg_text_after)); - } - } - } - - Ok(()) + Ok(provider) } fn backfill_codex_current( @@ -97,7 +790,7 @@ impl ProviderService { Ok(()) } - fn write_codex_live(provider: &crate::provider::Provider) -> Result<(), AppError> { + fn write_codex_live(provider: &Provider) -> Result<(), AppError> { let settings = provider .settings_config .as_object() @@ -117,32 +810,25 @@ impl ProviderService { Ok(()) } - fn switch_claude(config: &mut MultiAppConfig, provider_id: &str) -> Result<(), AppError> { - let provider = { - let manager = config - .get_manager(&AppType::Claude) - .ok_or_else(|| AppError::Message("应用类型不存在: Claude".into()))?; - manager - .providers - .get(provider_id) - .cloned() - .ok_or_else(|| AppError::ProviderNotFound(provider_id.to_string()))? - }; + fn prepare_switch_claude( + config: &mut MultiAppConfig, + provider_id: &str, + ) -> Result { + let provider = config + .get_manager(&AppType::Claude) + .ok_or_else(|| Self::app_not_found(&AppType::Claude))? + .providers + .get(provider_id) + .cloned() + .ok_or_else(|| AppError::ProviderNotFound(provider_id.to_string()))?; Self::backfill_claude_current(config, provider_id)?; - Self::write_claude_live(&provider)?; if let Some(manager) = config.get_manager_mut(&AppType::Claude) { manager.current = provider_id.to_string(); - - if let Some(target) = manager.providers.get_mut(provider_id) { - let settings_path = get_claude_settings_path(); - let live_after = read_json_file::(&settings_path)?; - target.settings_config = live_after; - } } - Ok(()) + Ok(provider) } fn backfill_claude_current( @@ -172,7 +858,7 @@ impl ProviderService { Ok(()) } - fn write_claude_live(provider: &crate::provider::Provider) -> Result<(), AppError> { + fn write_claude_live(provider: &Provider) -> Result<(), AppError> { let settings_path = get_claude_settings_path(); if let Some(parent) = settings_path.parent() { std::fs::create_dir_all(parent).map_err(|e| AppError::io(parent, e))?; @@ -182,6 +868,181 @@ impl ProviderService { Ok(()) } + fn write_live_snapshot(app_type: &AppType, provider: &Provider) -> Result<(), AppError> { + match app_type { + AppType::Codex => Self::write_codex_live(provider), + AppType::Claude => Self::write_claude_live(provider), + } + } + + fn validate_provider_settings(app_type: &AppType, provider: &Provider) -> Result<(), AppError> { + match app_type { + AppType::Claude => { + if !provider.settings_config.is_object() { + return Err(AppError::localized( + "provider.claude.settings.not_object", + "Claude 配置必须是 JSON 对象", + "Claude configuration must be a JSON object", + )); + } + } + AppType::Codex => { + let settings = provider.settings_config.as_object().ok_or_else(|| { + AppError::localized( + "provider.codex.settings.not_object", + "Codex 配置必须是 JSON 对象", + "Codex configuration must be a JSON object", + ) + })?; + + let auth = settings.get("auth").ok_or_else(|| { + AppError::localized( + "provider.codex.auth.missing", + format!("供应商 {} 缺少 auth 配置", provider.id), + format!("Provider {} is missing auth configuration", provider.id), + ) + })?; + if !auth.is_object() { + return Err(AppError::localized( + "provider.codex.auth.not_object", + format!("供应商 {} 的 auth 配置必须是 JSON 对象", provider.id), + format!( + "Provider {} auth configuration must be a JSON object", + provider.id + ), + )); + } + + if let Some(config_value) = settings.get("config") { + if !(config_value.is_string() || config_value.is_null()) { + return Err(AppError::localized( + "provider.codex.config.invalid_type", + "Codex config 字段必须是字符串", + "Codex config field must be a string", + )); + } + if let Some(cfg_text) = config_value.as_str() { + crate::codex_config::validate_config_toml(cfg_text)?; + } + } + } + } + + Ok(()) + } + + fn extract_credentials( + provider: &Provider, + app_type: &AppType, + ) -> Result<(String, String), AppError> { + match app_type { + AppType::Claude => { + let env = provider + .settings_config + .get("env") + .and_then(|v| v.as_object()) + .ok_or_else(|| { + AppError::localized( + "provider.claude.env.missing", + "配置格式错误: 缺少 env", + "Invalid configuration: missing env section", + ) + })?; + + let api_key = env + .get("ANTHROPIC_AUTH_TOKEN") + .and_then(|v| v.as_str()) + .ok_or_else(|| { + AppError::localized( + "provider.claude.api_key.missing", + "缺少 API Key", + "API key is missing", + ) + })? + .to_string(); + + let base_url = env + .get("ANTHROPIC_BASE_URL") + .and_then(|v| v.as_str()) + .ok_or_else(|| { + AppError::localized( + "provider.claude.base_url.missing", + "缺少 ANTHROPIC_BASE_URL 配置", + "Missing ANTHROPIC_BASE_URL configuration", + ) + })? + .to_string(); + + Ok((api_key, base_url)) + } + AppType::Codex => { + let auth = provider + .settings_config + .get("auth") + .and_then(|v| v.as_object()) + .ok_or_else(|| { + AppError::localized( + "provider.codex.auth.missing", + "配置格式错误: 缺少 auth", + "Invalid configuration: missing auth section", + ) + })?; + + let api_key = auth + .get("OPENAI_API_KEY") + .and_then(|v| v.as_str()) + .ok_or_else(|| { + AppError::localized( + "provider.codex.api_key.missing", + "缺少 API Key", + "API key is missing", + ) + })? + .to_string(); + + let config_toml = provider + .settings_config + .get("config") + .and_then(|v| v.as_str()) + .unwrap_or(""); + + let base_url = if config_toml.contains("base_url") { + let re = Regex::new(r#"base_url\s*=\s*["']([^"']+)["']"#) + .map_err(|e| AppError::Message(format!("正则初始化失败: {}", e)))?; + re.captures(config_toml) + .and_then(|caps| caps.get(1)) + .map(|m| m.as_str().to_string()) + .ok_or_else(|| { + AppError::localized( + "provider.codex.base_url.invalid", + "config.toml 中 base_url 格式错误", + "base_url in config.toml has invalid format", + ) + })? + } else { + return Err(AppError::localized( + "provider.codex.base_url.missing", + "config.toml 中缺少 base_url 配置", + "base_url is missing from config.toml", + )); + }; + + Ok((api_key, base_url)) + } + } + } + + fn app_not_found(app_type: &AppType) -> AppError { + AppError::Message(format!("应用类型不存在: {:?}", app_type)) + } + + fn now_millis() -> i64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_millis() as i64 + } + pub fn delete( config: &mut MultiAppConfig, app_type: AppType, @@ -192,12 +1053,16 @@ impl ProviderService { .map(|m| m.current == provider_id) .unwrap_or(false); if current_matches { - return Err(AppError::Config("不能删除当前正在使用的供应商".into())); + return Err(AppError::localized( + "provider.delete.current", + "不能删除当前正在使用的供应商", + "Cannot delete the provider currently in use", + )); } let provider = config .get_manager(&app_type) - .ok_or_else(|| AppError::Message(format!("应用类型不存在: {:?}", app_type)))? + .ok_or_else(|| Self::app_not_found(&app_type))? .providers .get(provider_id) .cloned() @@ -208,6 +1073,8 @@ impl ProviderService { crate::codex_config::delete_codex_provider_config(provider_id, &provider.name)?; } AppType::Claude => { + // 兼容旧版本:历史上会在 Claude 目录内为每个供应商生成 settings-*.json 副本 + // 这里继续清理这些遗留文件,避免堆积过期配置。 let by_name = get_provider_config_path(provider_id, Some(&provider.name)); let by_id = get_provider_config_path(provider_id, None); delete_file(&by_name)?; @@ -222,3 +1089,10 @@ impl ProviderService { Ok(()) } } + +#[derive(Debug, Clone, Deserialize)] +pub struct ProviderSortUpdate { + pub id: String, + #[serde(rename = "sortIndex")] + pub sort_index: usize, +} diff --git a/src-tauri/tests/mcp_commands.rs b/src-tauri/tests/mcp_commands.rs index f23d61a..4006bad 100644 --- a/src-tauri/tests/mcp_commands.rs +++ b/src-tauri/tests/mcp_commands.rs @@ -76,6 +76,10 @@ fn import_default_config_without_live_file_returns_error() { let err = import_default_config_test_hook(&state, AppType::Claude) .expect_err("missing live file should error"); match err { + AppError::Localized { zh, .. } => assert!( + zh.contains("Claude Code 配置文件不存在"), + "unexpected error message: {zh}" + ), AppError::Message(msg) => assert!( msg.contains("Claude Code 配置文件不存在"), "unexpected error message: {msg}" diff --git a/src-tauri/tests/provider_service.rs b/src-tauri/tests/provider_service.rs index 29a90dd..b96f20b 100644 --- a/src-tauri/tests/provider_service.rs +++ b/src-tauri/tests/provider_service.rs @@ -1,7 +1,8 @@ use serde_json::json; +use std::sync::RwLock; use cc_switch_lib::{ - get_claude_settings_path, read_json_file, write_codex_live_atomic, AppError, AppType, + get_claude_settings_path, read_json_file, write_codex_live_atomic, AppError, AppState, AppType, MultiAppConfig, Provider, ProviderService, }; @@ -33,9 +34,9 @@ command = "echo" write_codex_live_atomic(&legacy_auth, Some(legacy_config)) .expect("seed existing codex live config"); - let mut config = MultiAppConfig::default(); + let mut initial_config = MultiAppConfig::default(); { - let manager = config + let manager = initial_config .get_manager_mut(&AppType::Codex) .expect("codex manager"); manager.current = "old-provider".to_string(); @@ -68,7 +69,7 @@ command = "say" ); } - config.mcp.codex.servers.insert( + initial_config.mcp.codex.servers.insert( "echo-server".into(), json!({ "id": "echo-server", @@ -80,7 +81,11 @@ command = "say" }), ); - ProviderService::switch(&mut config, AppType::Codex, "new-provider") + let state = AppState { + config: RwLock::new(initial_config), + }; + + ProviderService::switch(&state, AppType::Codex, "new-provider") .expect("switch provider should succeed"); let auth_value: serde_json::Value = @@ -98,7 +103,8 @@ command = "say" "config.toml should contain synced MCP servers" ); - let manager = config + let guard = state.config.read().expect("read config after switch"); + let manager = guard .get_manager(&AppType::Codex) .expect("codex manager after switch"); assert_eq!(manager.current, "new-provider", "current provider updated"); @@ -188,7 +194,11 @@ fn provider_service_switch_claude_updates_live_and_state() { ); } - ProviderService::switch(&mut config, AppType::Claude, "new-provider") + let state = AppState { + config: RwLock::new(config), + }; + + ProviderService::switch(&state, AppType::Claude, "new-provider") .expect("switch provider should succeed"); let live_after: serde_json::Value = @@ -202,7 +212,11 @@ fn provider_service_switch_claude_updates_live_and_state() { "live settings.json should reflect new provider auth" ); - let manager = config + let guard = state + .config + .read() + .expect("read claude config after switch"); + let manager = guard .get_manager(&AppType::Claude) .expect("claude manager after switch"); assert_eq!(manager.current, "new-provider", "current provider updated"); @@ -219,9 +233,11 @@ fn provider_service_switch_claude_updates_live_and_state() { #[test] fn provider_service_switch_missing_provider_returns_error() { - let mut config = MultiAppConfig::default(); + let state = AppState { + config: RwLock::new(MultiAppConfig::default()), + }; - let err = ProviderService::switch(&mut config, AppType::Claude, "missing") + let err = ProviderService::switch(&state, AppType::Claude, "missing") .expect_err("switching missing provider should fail"); match err { AppError::ProviderNotFound(id) => assert_eq!(id, "missing"), @@ -249,7 +265,11 @@ fn provider_service_switch_codex_missing_auth_returns_error() { ); } - let err = ProviderService::switch(&mut config, AppType::Codex, "invalid") + let state = AppState { + config: RwLock::new(config), + }; + + let err = ProviderService::switch(&state, AppType::Codex, "invalid") .expect_err("switching should fail without auth"); match err { AppError::Config(msg) => assert!( @@ -404,6 +424,10 @@ fn provider_service_delete_current_provider_returns_error() { let err = ProviderService::delete(&mut config, AppType::Claude, "keep") .expect_err("deleting current provider should fail"); match err { + AppError::Localized { zh, .. } => assert!( + zh.contains("不能删除当前正在使用的供应商"), + "unexpected message: {zh}" + ), AppError::Config(msg) => assert!( msg.contains("不能删除当前正在使用的供应商"), "unexpected message: {msg}"