diff --git a/docs/BACKEND_REFACTOR_PLAN.md b/docs/BACKEND_REFACTOR_PLAN.md index 4199dc1..2227789 100644 --- a/docs/BACKEND_REFACTOR_PLAN.md +++ b/docs/BACKEND_REFACTOR_PLAN.md @@ -76,7 +76,7 @@ - 已将单一 `src-tauri/src/commands.rs` 拆分为 `commands/{provider,mcp,config,settings,misc,plugin}.rs` 并通过 `commands/mod.rs` 统一导出,保持对外 API 不变。 - 每个文件聚焦单一功能域(供应商、MCP、配置、设置、杂项、插件),命令函数平均 150-250 行,可读性与后续维护性显著提升。 - 相关依赖调整后 `cargo check` 通过,静态巡检确认无重复定义或未注册命令。 -- **阶段 3:补充测试 🚧** +- **阶段 3:补充测试 ✅** - `tests/import_export_sync.rs` 集成测试涵盖配置备份、Claude/Codex live 同步、MCP 投影与 Codex/Claude 双向导入流程,并新增启用项清理、非法 TOML 抛错等失败场景验证;统一使用隔离 HOME 目录避免污染真实用户环境。 - 扩展 `lib.rs` re-export,暴露 `AppType`、`MultiAppConfig`、`AppError`、配置 IO 以及 Codex/Claude MCP 路径与同步函数,方便服务层及测试直接复用核心逻辑。 - 新增负向测试验证 Codex 供应商缺少 `auth` 字段时的错误返回,并补充备份数量上限测试;顺带修复 `create_backup` 采用内存读写避免拷贝继承旧的修改时间,确保最新备份不会在清理阶段被误删。 @@ -85,9 +85,10 @@ - 补充 Claude 切换集成测试,验证 live `settings.json` 覆写、新旧供应商快照回填以及 `.cc-switch/config.json` 持久化结果,确保阶段四提取服务层时拥有可回归的用例。 - 增加 Codex 缺失 `auth` 场景测试,确认 `switch_provider_internal` 在关键字段缺失时返回带上下文的 `AppError`,同时保持内存状态未被污染。 - 为配置导入命令抽取复用逻辑 `import_config_from_path` 并补充成功/失败集成测试,校验备份生成、状态同步、JSON 解析与文件缺失等错误回退路径;`export_config_to_file` 亦具备成功/缺失源文件的命令级回归。 - - 当前已覆盖配置、Codex/Claude MCP 核心路径及关键错误分支,后续仍需补齐命令层边界与导入导出异常回滚测试。 + - 新增 `tests/mcp_commands.rs`,通过测试钩子覆盖 `import_default_config`、`import_mcp_from_claude`、`set_mcp_enabled` 等命令层行为,验证缺失文件/非法 JSON 的错误回滚以及成功路径落盘效果;阶段三目标达成,命令层关键边界已具备回归保障。 - **阶段 4:服务层抽象 🚧** - 新增 `services/provider.rs` 并实现 `ProviderService::switch`,负责供应商切换时的业务流程(live 回填、持久化、MCP 同步),命令层通过薄封装调用并负责状态持久化。 + - 扩展 `ProviderService` 提供 `delete` 能力,统一 Codex/Claude 清理逻辑;`tests/provider_service.rs` 校验切换与删除在成功/失败场景(包括缺失供应商、缺少 auth、删除当前供应商)下的行为,确保命令/托盘复用时拥有回归护栏。 ## 渐进式重构路线 diff --git a/src-tauri/src/app_config.rs b/src-tauri/src/app_config.rs index 3393a0d..3d9295c 100644 --- a/src-tauri/src/app_config.rs +++ b/src-tauri/src/app_config.rs @@ -90,8 +90,8 @@ impl MultiAppConfig { } // 尝试读取文件 - let content = std::fs::read_to_string(&config_path) - .map_err(|e| AppError::io(&config_path, e))?; + let content = + std::fs::read_to_string(&config_path).map_err(|e| AppError::io(&config_path, e))?; // 检查是否是旧版本格式(v1) if let Ok(v1_config) = serde_json::from_str::(&content) { diff --git a/src-tauri/src/app_store.rs b/src-tauri/src/app_store.rs index d4b82b4..fe16e1f 100644 --- a/src-tauri/src/app_store.rs +++ b/src-tauri/src/app_store.rs @@ -66,7 +66,10 @@ pub fn get_app_config_dir_from_store(app: &tauri::AppHandle) -> Option Some(path) } Some(_) => { - log::warn!("Store 中的 {} 类型不正确,应为字符串", STORE_KEY_APP_CONFIG_DIR); + log::warn!( + "Store 中的 {} 类型不正确,应为字符串", + STORE_KEY_APP_CONFIG_DIR + ); None } None => None, diff --git a/src-tauri/src/claude_mcp.rs b/src-tauri/src/claude_mcp.rs index 6eff697..603a73d 100644 --- a/src-tauri/src/claude_mcp.rs +++ b/src-tauri/src/claude_mcp.rs @@ -66,8 +66,7 @@ fn read_json_value(path: &Path) -> Result { return Ok(serde_json::json!({})); } let content = fs::read_to_string(path).map_err(|e| AppError::io(path, e))?; - let value: Value = - serde_json::from_str(&content).map_err(|e| AppError::json(path, e))?; + let value: Value = serde_json::from_str(&content).map_err(|e| AppError::json(path, e))?; Ok(value) } @@ -108,9 +107,7 @@ pub fn read_mcp_json() -> Result, AppError> { pub fn upsert_mcp_server(id: &str, spec: Value) -> Result { if id.trim().is_empty() { - return Err(AppError::InvalidInput( - "MCP 服务器 ID 不能为空".into(), - )); + return Err(AppError::InvalidInput("MCP 服务器 ID 不能为空".into())); } // 基础字段校验(尽量宽松) if !spec.is_object() { @@ -179,9 +176,7 @@ pub fn upsert_mcp_server(id: &str, spec: Value) -> Result { pub fn delete_mcp_server(id: &str) -> Result { if id.trim().is_empty() { - return Err(AppError::InvalidInput( - "MCP 服务器 ID 不能为空".into(), - )); + return Err(AppError::InvalidInput("MCP 服务器 ID 不能为空".into())); } let path = user_config_path(); if !path.exists() { @@ -261,15 +256,9 @@ pub fn set_mcp_servers_map( }; if let Some(server_val) = obj.remove("server") { - let server_obj = server_val - .as_object() - .cloned() - .ok_or_else(|| { - AppError::McpValidation(format!( - "MCP 服务器 '{}' server 字段不是对象", - id - )) - })?; + let server_obj = server_val.as_object().cloned().ok_or_else(|| { + AppError::McpValidation(format!("MCP 服务器 '{}' server 字段不是对象", id)) + })?; obj = server_obj; } diff --git a/src-tauri/src/claude_plugin.rs b/src-tauri/src/claude_plugin.rs index 56e872f..08acf19 100644 --- a/src-tauri/src/claude_plugin.rs +++ b/src-tauri/src/claude_plugin.rs @@ -11,8 +11,7 @@ fn claude_dir() -> Result { if let Some(dir) = crate::settings::get_claude_override_dir() { return Ok(dir); } - let home = dirs::home_dir() - .ok_or_else(|| AppError::Config("无法获取用户主目录".into()))?; + let home = dirs::home_dir().ok_or_else(|| AppError::Config("无法获取用户主目录".into()))?; Ok(home.join(CLAUDE_DIR)) } @@ -81,8 +80,7 @@ pub fn write_claude_config() -> Result { if changed || !path.exists() { let serialized = serde_json::to_string_pretty(&obj) .map_err(|e| AppError::JsonSerialize { source: e })?; - fs::write(&path, format!("{}\n", serialized)) - .map_err(|e| AppError::io(&path, e))?; + fs::write(&path, format!("{}\n", serialized)).map_err(|e| AppError::io(&path, e))?; Ok(true) } else { Ok(false) @@ -114,8 +112,8 @@ pub fn clear_claude_config() -> Result { return Ok(false); } - let serialized = serde_json::to_string_pretty(&value) - .map_err(|e| AppError::JsonSerialize { source: e })?; + let serialized = + serde_json::to_string_pretty(&value).map_err(|e| AppError::JsonSerialize { source: e })?; fs::write(&path, format!("{}\n", serialized)).map_err(|e| AppError::io(&path, e))?; Ok(true) } diff --git a/src-tauri/src/codex_config.rs b/src-tauri/src/codex_config.rs index c1ab9b3..3d5f9c7 100644 --- a/src-tauri/src/codex_config.rs +++ b/src-tauri/src/codex_config.rs @@ -44,7 +44,10 @@ pub fn get_codex_provider_paths( } /// 删除 Codex 供应商配置文件 -pub fn delete_codex_provider_config(provider_id: &str, provider_name: &str) -> Result<(), AppError> { +pub fn delete_codex_provider_config( + provider_id: &str, + provider_name: &str, +) -> Result<(), AppError> { let (auth_path, config_path) = get_codex_provider_paths(provider_id, Some(provider_name)); delete_file(&auth_path).ok(); @@ -56,7 +59,10 @@ pub fn delete_codex_provider_config(provider_id: &str, provider_name: &str) -> R //(移除未使用的备份/保存/恢复/导入函数,避免 dead_code 告警) /// 原子写 Codex 的 `auth.json` 与 `config.toml`,在第二步失败时回滚第一步 -pub fn write_codex_live_atomic(auth: &Value, config_text_opt: Option<&str>) -> Result<(), AppError> { +pub fn write_codex_live_atomic( + auth: &Value, + config_text_opt: Option<&str>, +) -> Result<(), AppError> { let auth_path = get_codex_auth_path(); let config_path = get_codex_config_path(); @@ -70,12 +76,11 @@ pub fn write_codex_live_atomic(auth: &Value, config_text_opt: Option<&str>) -> R } else { None }; - let _old_config = - if config_path.exists() { - Some(fs::read(&config_path).map_err(|e| AppError::io(&config_path, e))?) - } else { - None - }; + let _old_config = if config_path.exists() { + Some(fs::read(&config_path).map_err(|e| AppError::io(&config_path, e))?) + } else { + None + }; // 准备写入内容 let cfg_text = match config_text_opt { @@ -83,8 +88,7 @@ pub fn write_codex_live_atomic(auth: &Value, config_text_opt: Option<&str>) -> R None => String::new(), }; if !cfg_text.trim().is_empty() { - toml::from_str::(&cfg_text) - .map_err(|e| AppError::toml(&config_path, e))?; + toml::from_str::(&cfg_text).map_err(|e| AppError::toml(&config_path, e))?; } // 第一步:写 auth.json diff --git a/src-tauri/src/commands/mcp.rs b/src-tauri/src/commands/mcp.rs index 519d9c2..7cc2ed7 100644 --- a/src-tauri/src/commands/mcp.rs +++ b/src-tauri/src/commands/mcp.rs @@ -7,6 +7,7 @@ use tauri::State; use crate::app_config::AppType; use crate::claude_mcp; +use crate::error::AppError; use crate::mcp; use crate::store::AppState; @@ -159,15 +160,8 @@ pub async fn set_mcp_enabled( id: String, enabled: bool, ) -> Result { - let mut cfg = state - .config - .lock() - .map_err(|e| format!("获取锁失败: {}", e))?; let app_ty = AppType::from(app.as_deref().unwrap_or("claude")); - let changed = mcp::set_enabled_and_sync_for(&mut cfg, &app_ty, &id, enabled)?; - drop(cfg); - state.save()?; - Ok(changed) + set_mcp_enabled_internal(&*state, app_ty, &id, enabled).map_err(Into::into) } /// 手动同步:将启用的 MCP 投影到 ~/.claude.json @@ -207,10 +201,40 @@ pub async fn sync_enabled_mcp_to_codex(state: State<'_, AppState>) -> Result) -> Result { - let mut cfg = state - .config - .lock() - .map_err(|e| format!("获取锁失败: {}", e))?; + import_mcp_from_claude_internal(&*state).map_err(Into::into) +} + +/// 从 ~/.codex/config.toml 导入 MCP 定义到 config.json +#[tauri::command] +pub async fn import_mcp_from_codex(state: State<'_, AppState>) -> Result { + import_mcp_from_codex_internal(&*state).map_err(Into::into) +} + +fn set_mcp_enabled_internal( + state: &AppState, + app_ty: AppType, + id: &str, + enabled: bool, +) -> Result { + let mut cfg = state.config.lock()?; + let changed = mcp::set_enabled_and_sync_for(&mut cfg, &app_ty, id, enabled)?; + drop(cfg); + state.save()?; + Ok(changed) +} + +#[doc(hidden)] +pub fn set_mcp_enabled_test_hook( + state: &AppState, + app_ty: AppType, + id: &str, + enabled: bool, +) -> Result { + set_mcp_enabled_internal(state, app_ty, id, enabled) +} + +fn import_mcp_from_claude_internal(state: &AppState) -> Result { + let mut cfg = state.config.lock()?; let changed = mcp::import_from_claude(&mut cfg)?; drop(cfg); if changed > 0 { @@ -219,13 +243,13 @@ pub async fn import_mcp_from_claude(state: State<'_, AppState>) -> Result) -> Result { - let mut cfg = state - .config - .lock() - .map_err(|e| format!("获取锁失败: {}", e))?; +#[doc(hidden)] +pub fn import_mcp_from_claude_test_hook(state: &AppState) -> Result { + import_mcp_from_claude_internal(state) +} + +fn import_mcp_from_codex_internal(state: &AppState) -> Result { + let mut cfg = state.config.lock()?; let changed = mcp::import_from_codex(&mut cfg)?; drop(cfg); if changed > 0 { @@ -233,3 +257,8 @@ pub async fn import_mcp_from_codex(state: State<'_, AppState>) -> Result Result { + import_mcp_from_codex_internal(state) +} diff --git a/src-tauri/src/commands/provider.rs b/src-tauri/src/commands/provider.rs index 91a3a47..3051ba7 100644 --- a/src-tauri/src/commands/provider.rs +++ b/src-tauri/src/commands/provider.rs @@ -271,41 +271,14 @@ pub async fn delete_provider( .or_else(|| appType.as_deref().map(|s| s.into())) .unwrap_or(AppType::Claude); - let mut config = state - .config - .lock() - .map_err(|e| format!("获取锁失败: {}", e))?; - - let manager = config - .get_manager_mut(&app_type) - .ok_or_else(|| format!("应用类型不存在: {:?}", app_type))?; - - if manager.current == id { - return Err("不能删除当前正在使用的供应商".to_string()); + { + let mut config = state + .config + .lock() + .map_err(|e| format!("获取锁失败: {}", e))?; + ProviderService::delete(&mut config, app_type, &id).map_err(|e| e.to_string())?; } - let provider = manager - .providers - .get(&id) - .ok_or_else(|| format!("供应商不存在: {}", id))? - .clone(); - - match app_type { - AppType::Codex => { - codex_config::delete_codex_provider_config(&id, &provider.name)?; - } - AppType::Claude => { - use crate::config::{delete_file, get_provider_config_path}; - let by_name = get_provider_config_path(&id, Some(&provider.name)); - let by_id = get_provider_config_path(&id, None); - delete_file(&by_name)?; - delete_file(&by_id)?; - } - } - - manager.providers.remove(&id); - - drop(config); state.save()?; Ok(true) @@ -313,10 +286,7 @@ pub async fn delete_provider( /// 切换供应商 fn switch_provider_internal(state: &AppState, app_type: AppType, id: &str) -> Result<(), AppError> { - let mut config = state - .config - .lock() - .map_err(AppError::from)?; + let mut config = state.config.lock().map_err(AppError::from)?; ProviderService::switch(&mut config, app_type, id)?; @@ -351,6 +321,65 @@ pub async fn switch_provider( .map_err(|e| e.to_string()) } +fn import_default_config_internal(state: &AppState, app_type: AppType) -> Result<(), AppError> { + { + let config = state.config.lock()?; + if let Some(manager) = config.get_manager(&app_type) { + if !manager.get_all_providers().is_empty() { + // 已存在供应商则视为已导入,保持与原逻辑一致 + return Ok(()); + } + } + } + + let settings_config = match app_type { + AppType::Codex => { + let auth_path = codex_config::get_codex_auth_path(); + if !auth_path.exists() { + return Err(AppError::Message("Codex 配置文件不存在".to_string())); + } + let auth: serde_json::Value = crate::config::read_json_file(&auth_path)?; + let config_str = crate::codex_config::read_and_validate_codex_config_text()?; + serde_json::json!({ "auth": auth, "config": config_str }) + } + AppType::Claude => { + let settings_path = get_claude_settings_path(); + if !settings_path.exists() { + return Err(AppError::Message("Claude Code 配置文件不存在".to_string())); + } + crate::config::read_json_file(&settings_path)? + } + }; + + let provider = Provider::with_id( + "default".to_string(), + "default".to_string(), + settings_config, + None, + ); + + let mut config = state.config.lock()?; + let manager = config + .get_manager_mut(&app_type) + .ok_or_else(|| AppError::Message(format!("应用类型不存在: {:?}", app_type)))?; + + manager.providers.insert(provider.id.clone(), provider); + manager.current = "default".to_string(); + + drop(config); + state.save()?; + + Ok(()) +} + +#[doc(hidden)] +pub fn import_default_config_test_hook( + state: &AppState, + app_type: AppType, +) -> Result<(), AppError> { + import_default_config_internal(state, app_type) +} + /// 导入当前配置为默认供应商 #[tauri::command] pub async fn import_default_config( @@ -364,62 +393,9 @@ pub async fn import_default_config( .or_else(|| appType.as_deref().map(|s| s.into())) .unwrap_or(AppType::Claude); - { - let config = state - .config - .lock() - .map_err(|e| format!("获取锁失败: {}", e))?; - - if let Some(manager) = config.get_manager(&app_type) { - if !manager.get_all_providers().is_empty() { - return Ok(true); - } - } - } - - let settings_config = match app_type { - AppType::Codex => { - let auth_path = codex_config::get_codex_auth_path(); - if !auth_path.exists() { - return Err("Codex 配置文件不存在".to_string()); - } - let auth: serde_json::Value = - crate::config::read_json_file::(&auth_path)?; - let config_str = crate::codex_config::read_and_validate_codex_config_text()?; - serde_json::json!({ "auth": auth, "config": config_str }) - } - AppType::Claude => { - let settings_path = get_claude_settings_path(); - if !settings_path.exists() { - return Err("Claude Code 配置文件不存在".to_string()); - } - crate::config::read_json_file::(&settings_path)? - } - }; - - let provider = Provider::with_id( - "default".to_string(), - "default".to_string(), - settings_config, - None, - ); - - let mut config = state - .config - .lock() - .map_err(|e| format!("获取锁失败: {}", e))?; - - let manager = config - .get_manager_mut(&app_type) - .ok_or_else(|| format!("应用类型不存在: {:?}", app_type))?; - - manager.providers.insert(provider.id.clone(), provider); - manager.current = "default".to_string(); - - drop(config); - state.save()?; - - Ok(true) + import_default_config_internal(&*state, app_type) + .map(|_| true) + .map_err(Into::into) } /// 查询供应商用量 diff --git a/src-tauri/src/commands/settings.rs b/src-tauri/src/commands/settings.rs index a569b8c..7f7292e 100644 --- a/src-tauri/src/commands/settings.rs +++ b/src-tauri/src/commands/settings.rs @@ -11,8 +11,7 @@ pub async fn get_settings() -> Result { /// 保存设置 #[tauri::command] pub async fn save_settings(settings: crate::settings::AppSettings) -> Result { - crate::settings::update_settings(settings) - .map_err(|e| e.to_string())?; + crate::settings::update_settings(settings).map_err(|e| e.to_string())?; Ok(true) } diff --git a/src-tauri/src/config.rs b/src-tauri/src/config.rs index e8ce235..99192bb 100644 --- a/src-tauri/src/config.rs +++ b/src-tauri/src/config.rs @@ -150,10 +150,7 @@ pub fn get_provider_config_path(provider_id: &str, provider_name: Option<&str>) /// 读取 JSON 配置文件 pub fn read_json_file Deserialize<'a>>(path: &Path) -> Result { if !path.exists() { - return Err(AppError::Config(format!( - "文件不存在: {}", - path.display() - ))); + return Err(AppError::Config(format!("文件不存在: {}", path.display()))); } let content = fs::read_to_string(path).map_err(|e| AppError::io(path, e))?; @@ -168,8 +165,8 @@ pub fn write_json_file(path: &Path, data: &T) -> Result<(), AppErr fs::create_dir_all(parent).map_err(|e| AppError::io(parent, e))?; } - let json = serde_json::to_string_pretty(data) - .map_err(|e| AppError::JsonSerialize { source: e })?; + let json = + serde_json::to_string_pretty(data).map_err(|e| AppError::JsonSerialize { source: e })?; atomic_write(path, json.as_bytes()) } @@ -204,8 +201,7 @@ pub fn atomic_write(path: &Path, data: &[u8]) -> Result<(), AppError> { tmp.push(format!("{}.tmp.{}", file_name, ts)); { - let mut f = - fs::File::create(&tmp).map_err(|e| AppError::io(&tmp, e))?; + let mut f = fs::File::create(&tmp).map_err(|e| AppError::io(&tmp, e))?; f.write_all(data).map_err(|e| AppError::io(&tmp, e))?; f.flush().map_err(|e| AppError::io(&tmp, e))?; } @@ -226,11 +222,7 @@ pub fn atomic_write(path: &Path, data: &[u8]) -> Result<(), AppError> { let _ = fs::remove_file(path); } fs::rename(&tmp, path).map_err(|e| AppError::IoContext { - context: format!( - "原子替换失败: {} -> {}", - tmp.display(), - path.display() - ), + context: format!("原子替换失败: {} -> {}", tmp.display(), path.display()), source: e, })?; } @@ -238,11 +230,7 @@ pub fn atomic_write(path: &Path, data: &[u8]) -> Result<(), AppError> { #[cfg(not(windows))] { fs::rename(&tmp, path).map_err(|e| AppError::IoContext { - context: format!( - "原子替换失败: {} -> {}", - tmp.display(), - path.display() - ), + context: format!("原子替换失败: {} -> {}", tmp.display(), path.display()), source: e, })?; } @@ -287,11 +275,7 @@ mod tests { /// 复制文件 pub fn copy_file(from: &Path, to: &Path) -> Result<(), AppError> { fs::copy(from, to).map_err(|e| AppError::IoContext { - context: format!( - "复制文件失败 ({} -> {})", - from.display(), - to.display() - ), + context: format!("复制文件失败 ({} -> {})", from.display(), to.display()), source: e, })?; Ok(()) diff --git a/src-tauri/src/import_export.rs b/src-tauri/src/import_export.rs index e50da4f..80e0272 100644 --- a/src-tauri/src/import_export.rs +++ b/src-tauri/src/import_export.rs @@ -131,23 +131,15 @@ fn sync_codex_live( ) -> Result<(), AppError> { use serde_json::Value; - let settings = provider - .settings_config - .as_object() - .ok_or_else(|| { - AppError::Config(format!( - "供应商 {} 的 Codex 配置必须是对象", - provider_id - )) - })?; - let auth = settings - .get("auth") - .ok_or_else(|| { - AppError::Config(format!( - "供应商 {} 的 Codex 配置缺少 auth 字段", - provider_id - )) - })?; + let settings = provider.settings_config.as_object().ok_or_else(|| { + AppError::Config(format!("供应商 {} 的 Codex 配置必须是对象", provider_id)) + })?; + let auth = settings.get("auth").ok_or_else(|| { + AppError::Config(format!( + "供应商 {} 的 Codex 配置缺少 auth 字段", + provider_id + )) + })?; if !auth.is_object() { return Err(AppError::Config(format!( "供应商 {} 的 Codex auth 配置必须是 JSON 对象", @@ -203,12 +195,11 @@ fn sync_claude_live( pub async fn export_config_to_file(file_path: String) -> Result { // 读取当前配置文件 let config_path = crate::config::get_app_config_path(); - let config_content = fs::read_to_string(&config_path) - .map_err(|e| AppError::io(&config_path, e).to_string())?; + let config_content = + fs::read_to_string(&config_path).map_err(|e| AppError::io(&config_path, e).to_string())?; // 写入到指定文件 - fs::write(&file_path, &config_content) - .map_err(|e| AppError::io(&file_path, e).to_string())?; + fs::write(&file_path, &config_content).map_err(|e| AppError::io(&file_path, e).to_string())?; Ok(json!({ "success": true, @@ -239,18 +230,15 @@ pub fn import_config_from_path( file_path: &Path, state: &crate::store::AppState, ) -> Result { - let import_content = - fs::read_to_string(file_path).map_err(|e| AppError::io(file_path, e))?; + let import_content = fs::read_to_string(file_path).map_err(|e| AppError::io(file_path, e))?; let new_config: crate::app_config::MultiAppConfig = - serde_json::from_str(&import_content) - .map_err(|e| AppError::json(file_path, e))?; + serde_json::from_str(&import_content).map_err(|e| AppError::json(file_path, e))?; let config_path = crate::config::get_app_config_path(); let backup_id = create_backup(&config_path)?; - fs::write(&config_path, &import_content) - .map_err(|e| AppError::io(&config_path, e))?; + fs::write(&config_path, &import_content).map_err(|e| AppError::io(&config_path, e))?; { let mut guard = state.config.lock().map_err(AppError::from)?; diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index b8c9c1d..de5c2a3 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -10,25 +10,27 @@ mod import_export; mod mcp; mod migration; mod provider; +mod services; mod settings; mod speedtest; mod store; -mod services; mod usage_script; pub use app_config::{AppType, MultiAppConfig}; pub use codex_config::{get_codex_auth_path, get_codex_config_path, write_codex_live_atomic}; +pub use commands::*; pub use config::{get_claude_mcp_path, get_claude_settings_path, read_json_file}; +pub use error::AppError; pub use import_export::{ create_backup, export_config_to_file, import_config_from_path, sync_current_providers_to_live, }; +pub use mcp::{ + import_from_claude, import_from_codex, sync_enabled_to_claude, sync_enabled_to_codex, +}; pub use provider::Provider; -pub use settings::{update_settings, AppSettings}; -pub use mcp::{import_from_claude, import_from_codex, sync_enabled_to_claude, sync_enabled_to_codex}; -pub use error::AppError; -pub use store::AppState; pub use services::ProviderService; -pub use commands::*; +pub use settings::{update_settings, AppSettings}; +pub use store::AppState; use tauri::{ menu::{CheckMenuItem, Menu, MenuBuilder, MenuItem}, @@ -43,10 +45,7 @@ fn create_tray_menu( app: &tauri::AppHandle, app_state: &AppState, ) -> Result, AppError> { - let config = app_state - .config - .lock() - .map_err(AppError::from)?; + let config = app_state.config.lock().map_err(AppError::from)?; let mut menu_builder = MenuBuilder::new(app); diff --git a/src-tauri/src/mcp.rs b/src-tauri/src/mcp.rs index c563b08..5d28774 100644 --- a/src-tauri/src/mcp.rs +++ b/src-tauri/src/mcp.rs @@ -44,15 +44,11 @@ fn validate_server_spec(spec: &Value) -> Result<(), AppError> { fn validate_mcp_entry(entry: &Value) -> Result<(), AppError> { let obj = entry .as_object() - .ok_or_else(|| { - AppError::McpValidation("MCP 服务器条目必须为 JSON 对象".into()) - })?; + .ok_or_else(|| AppError::McpValidation("MCP 服务器条目必须为 JSON 对象".into()))?; let server = obj .get("server") - .ok_or_else(|| { - AppError::McpValidation("MCP 服务器条目缺少 server 字段".into()) - })?; + .ok_or_else(|| AppError::McpValidation("MCP 服务器条目缺少 server 字段".into()))?; validate_server_spec(server)?; for key in ["name", "description", "homepage", "docs"] { @@ -67,9 +63,9 @@ fn validate_mcp_entry(entry: &Value) -> Result<(), AppError> { } if let Some(tags) = obj.get("tags") { - let arr = tags.as_array().ok_or_else(|| { - AppError::McpValidation("MCP 服务器 tags 必须为字符串数组".into()) - })?; + let arr = tags + .as_array() + .ok_or_else(|| AppError::McpValidation("MCP 服务器 tags 必须为字符串数组".into()))?; if !arr.iter().all(|item| item.is_string()) { return Err(AppError::McpValidation( "MCP 服务器 tags 必须为字符串数组".into(), @@ -182,14 +178,10 @@ pub fn normalize_servers_for(config: &mut MultiAppConfig, app: &AppType) -> usiz fn extract_server_spec(entry: &Value) -> Result { let obj = entry .as_object() - .ok_or_else(|| { - AppError::McpValidation("MCP 服务器条目必须为 JSON 对象".into()) - })?; + .ok_or_else(|| AppError::McpValidation("MCP 服务器条目必须为 JSON 对象".into()))?; let server = obj .get("server") - .ok_or_else(|| { - AppError::McpValidation("MCP 服务器条目缺少 server 字段".into()) - })?; + .ok_or_else(|| AppError::McpValidation("MCP 服务器条目缺少 server 字段".into()))?; if !server.is_object() { return Err(AppError::McpValidation( @@ -255,9 +247,7 @@ pub fn upsert_in_config_for( spec: Value, ) -> Result { if id.trim().is_empty() { - return Err(AppError::InvalidInput( - "MCP 服务器 ID 不能为空".into(), - )); + return Err(AppError::InvalidInput("MCP 服务器 ID 不能为空".into())); } normalize_servers_for(config, app); validate_mcp_entry(&spec)?; @@ -265,14 +255,10 @@ pub fn upsert_in_config_for( let mut entry_obj = spec .as_object() .cloned() - .ok_or_else(|| { - AppError::McpValidation("MCP 服务器条目必须为 JSON 对象".into()) - })?; + .ok_or_else(|| AppError::McpValidation("MCP 服务器条目必须为 JSON 对象".into()))?; if let Some(existing_id) = entry_obj.get("id") { let Some(existing_id_str) = existing_id.as_str() else { - return Err(AppError::McpValidation( - "MCP 服务器 id 必须为字符串".into(), - )); + return Err(AppError::McpValidation("MCP 服务器 id 必须为字符串".into())); }; if existing_id_str != id { return Err(AppError::McpValidation(format!( @@ -299,9 +285,7 @@ pub fn delete_in_config_for( id: &str, ) -> Result { if id.trim().is_empty() { - return Err(AppError::InvalidInput( - "MCP 服务器 ID 不能为空".into(), - )); + return Err(AppError::InvalidInput("MCP 服务器 ID 不能为空".into())); } normalize_servers_for(config, app); let existed = config.mcp_for_mut(app).servers.remove(id).is_some(); @@ -316,9 +300,7 @@ pub fn set_enabled_and_sync_for( enabled: bool, ) -> Result { if id.trim().is_empty() { - return Err(AppError::InvalidInput( - "MCP 服务器 ID 不能为空".into(), - )); + return Err(AppError::InvalidInput("MCP 服务器 ID 不能为空".into())); } normalize_servers_for(config, app); if let Some(spec) = config.mcp_for_mut(app).servers.get_mut(id) { @@ -326,9 +308,7 @@ pub fn set_enabled_and_sync_for( let mut obj = spec .as_object() .cloned() - .ok_or_else(|| { - AppError::McpValidation("MCP 服务器定义必须为 JSON 对象".into()) - })?; + .ok_or_else(|| AppError::McpValidation("MCP 服务器定义必须为 JSON 对象".into()))?; obj.insert("enabled".into(), json!(enabled)); *spec = Value::Object(obj); } else { @@ -362,9 +342,8 @@ pub fn import_from_claude(config: &mut MultiAppConfig) -> Result Result } let mut changed_total = normalize_servers_for(config, &AppType::Codex); - let root: toml::Table = toml::from_str(&text).map_err(|e| { - AppError::McpValidation(format!( - "解析 ~/.codex/config.toml 失败: {}", - e - )) - })?; + let root: toml::Table = toml::from_str(&text) + .map_err(|e| AppError::McpValidation(format!("解析 ~/.codex/config.toml 失败: {}", e)))?; // helper:处理一组 servers 表 let mut import_servers_tbl = |servers_tbl: &toml::value::Table| { @@ -619,9 +594,8 @@ pub fn sync_enabled_to_codex(config: &MultiAppConfig) -> Result<(), AppError> { let mut root: TomlTable = if base_text.trim().is_empty() { TomlTable::new() } else { - toml::from_str::(&base_text).map_err(|e| { - AppError::McpValidation(format!("解析 config.toml 失败: {}", e)) - })? + toml::from_str::(&base_text) + .map_err(|e| AppError::McpValidation(format!("解析 config.toml 失败: {}", e)))? }; // 3) 写入 servers 表(支持 mcp.servers 与 mcp_servers;优先沿用已有风格,默认 mcp_servers) @@ -767,9 +741,8 @@ pub fn sync_enabled_to_codex(config: &MultiAppConfig) -> Result<(), AppError> { } // 4) 序列化并写回 config.toml(仅改 TOML,不触碰 auth.json) - let new_text = toml::to_string(&TomlValue::Table(root)).map_err(|e| { - AppError::McpValidation(format!("序列化 config.toml 失败: {}", e)) - })?; + let new_text = toml::to_string(&TomlValue::Table(root)) + .map_err(|e| AppError::McpValidation(format!("序列化 config.toml 失败: {}", e)))?; let path = crate::codex_config::get_codex_config_path(); crate::config::write_text_file(&path, &new_text)?; diff --git a/src-tauri/src/migration.rs b/src-tauri/src/migration.rs index b281cde..f253e2e 100644 --- a/src-tauri/src/migration.rs +++ b/src-tauri/src/migration.rs @@ -149,8 +149,7 @@ pub fn migrate_copies_into_config(config: &mut MultiAppConfig) -> Result Result<(), AppError> { + let current_matches = config + .get_manager(&app_type) + .map(|m| m.current == provider_id) + .unwrap_or(false); + if current_matches { + return Err(AppError::Config("不能删除当前正在使用的供应商".into())); + } + + let provider = config + .get_manager(&app_type) + .ok_or_else(|| AppError::Message(format!("应用类型不存在: {:?}", app_type)))? + .providers + .get(provider_id) + .cloned() + .ok_or_else(|| AppError::ProviderNotFound(provider_id.to_string()))?; + + match app_type { + AppType::Codex => { + crate::codex_config::delete_codex_provider_config(provider_id, &provider.name)?; + } + AppType::Claude => { + let by_name = get_provider_config_path(provider_id, Some(&provider.name)); + let by_id = get_provider_config_path(provider_id, None); + delete_file(&by_name)?; + delete_file(&by_id)?; + } + } + + if let Some(manager) = config.get_manager_mut(&app_type) { + manager.providers.remove(provider_id); + } + + Ok(()) + } } diff --git a/src-tauri/src/usage_script.rs b/src-tauri/src/usage_script.rs index ed9e34b..29e041d 100644 --- a/src-tauri/src/usage_script.rs +++ b/src-tauri/src/usage_script.rs @@ -1,5 +1,5 @@ use reqwest::Client; -use rquickjs::{Context, Runtime, Function}; +use rquickjs::{Context, Function, Runtime}; use serde_json::Value; use std::collections::HashMap; use std::time::Duration; @@ -20,8 +20,8 @@ pub async fn execute_usage_script( // 2. 在独立作用域中提取 request 配置(确保 Runtime/Context 在 await 前释放) let request_config = { - let runtime = Runtime::new() - .map_err(|e| AppError::Message(format!("创建 JS 运行时失败: {}", e)))?; + let runtime = + Runtime::new().map_err(|e| AppError::Message(format!("创建 JS 运行时失败: {}", e)))?; let context = Context::full(&runtime) .map_err(|e| AppError::Message(format!("创建 JS 上下文失败: {}", e)))?; @@ -57,8 +57,8 @@ pub async fn execute_usage_script( // 5. 在独立作用域中执行 extractor(确保 Runtime/Context 在函数结束前释放) let result: Value = { - let runtime = Runtime::new() - .map_err(|e| AppError::Message(format!("创建 JS 运行时失败: {}", e)))?; + let runtime = + Runtime::new().map_err(|e| AppError::Message(format!("创建 JS 运行时失败: {}", e)))?; let context = Context::full(&runtime) .map_err(|e| AppError::Message(format!("创建 JS 上下文失败: {}", e)))?; @@ -121,10 +121,7 @@ async fn send_http_request(config: &RequestConfig, timeout_secs: u64) -> Result< .build() .map_err(|e| AppError::Message(format!("创建客户端失败: {}", e)))?; - let method = config - .method - .parse() - .unwrap_or(reqwest::Method::GET); + let method = config.method.parse().unwrap_or(reqwest::Method::GET); let mut req = client.request(method.clone(), &config.url); @@ -171,9 +168,7 @@ fn validate_result(result: &Value) -> Result<(), AppError> { } for (idx, item) in arr.iter().enumerate() { validate_single_usage(item) - .map_err(|e| { - AppError::InvalidInput(format!("数组索引[{}]验证失败: {}", idx, e)) - })?; + .map_err(|e| AppError::InvalidInput(format!("数组索引[{}]验证失败: {}", idx, e)))?; } return Ok(()); } @@ -184,18 +179,16 @@ fn validate_result(result: &Value) -> Result<(), AppError> { /// 验证单个用量数据对象 fn validate_single_usage(result: &Value) -> Result<(), AppError> { - let obj = result.as_object().ok_or_else(|| { - AppError::InvalidInput("脚本必须返回对象或对象数组".into()) - })?; + let obj = result + .as_object() + .ok_or_else(|| AppError::InvalidInput("脚本必须返回对象或对象数组".into()))?; // 所有字段均为可选,只进行类型检查 if obj.contains_key("isValid") && !result["isValid"].is_null() && !result["isValid"].is_boolean() { - return Err(AppError::InvalidInput( - "isValid 必须是布尔值或 null".into(), - )); + return Err(AppError::InvalidInput("isValid 必须是布尔值或 null".into())); } if obj.contains_key("invalidMessage") && !result["invalidMessage"].is_null() @@ -209,33 +202,16 @@ fn validate_single_usage(result: &Value) -> Result<(), AppError> { && !result["remaining"].is_null() && !result["remaining"].is_number() { - return Err(AppError::InvalidInput( - "remaining 必须是数字或 null".into(), - )); + return Err(AppError::InvalidInput("remaining 必须是数字或 null".into())); } - if obj.contains_key("unit") - && !result["unit"].is_null() - && !result["unit"].is_string() - { - return Err(AppError::InvalidInput( - "unit 必须是字符串或 null".into(), - )); + if obj.contains_key("unit") && !result["unit"].is_null() && !result["unit"].is_string() { + return Err(AppError::InvalidInput("unit 必须是字符串或 null".into())); } - if obj.contains_key("total") - && !result["total"].is_null() - && !result["total"].is_number() - { - return Err(AppError::InvalidInput( - "total 必须是数字或 null".into(), - )); + if obj.contains_key("total") && !result["total"].is_null() && !result["total"].is_number() { + return Err(AppError::InvalidInput("total 必须是数字或 null".into())); } - if obj.contains_key("used") - && !result["used"].is_null() - && !result["used"].is_number() - { - return Err(AppError::InvalidInput( - "used 必须是数字或 null".into(), - )); + if obj.contains_key("used") && !result["used"].is_null() && !result["used"].is_number() { + return Err(AppError::InvalidInput("used 必须是数字或 null".into())); } if obj.contains_key("planName") && !result["planName"].is_null() @@ -245,13 +221,8 @@ fn validate_single_usage(result: &Value) -> Result<(), AppError> { "planName 必须是字符串或 null".into(), )); } - if obj.contains_key("extra") - && !result["extra"].is_null() - && !result["extra"].is_string() - { - return Err(AppError::InvalidInput( - "extra 必须是字符串或 null".into(), - )); + if obj.contains_key("extra") && !result["extra"].is_null() && !result["extra"].is_string() { + return Err(AppError::InvalidInput("extra 必须是字符串或 null".into())); } Ok(()) diff --git a/src-tauri/tests/import_export_sync.rs b/src-tauri/tests/import_export_sync.rs index 1b4f2e9..c6acbbe 100644 --- a/src-tauri/tests/import_export_sync.rs +++ b/src-tauri/tests/import_export_sync.rs @@ -1,6 +1,6 @@ +use serde_json::json; use std::{fs, path::Path, sync::Mutex}; use tauri::async_runtime; -use serde_json::json; use cc_switch_lib::{ create_backup, get_claude_settings_path, import_config_from_path, read_json_file, @@ -77,22 +77,18 @@ fn sync_codex_provider_writes_auth_and_config() { let mut config = MultiAppConfig::default(); // 添加入测 MCP 启用项,确保 sync_enabled_to_codex 会写入 TOML - config - .mcp - .codex - .servers - .insert( - "echo-server".into(), - json!({ - "id": "echo-server", - "enabled": true, - "server": { - "type": "stdio", - "command": "echo", - "args": ["hello"] - } - }), - ); + config.mcp.codex.servers.insert( + "echo-server".into(), + json!({ + "id": "echo-server", + "enabled": true, + "server": { + "type": "stdio", + "command": "echo", + "args": ["hello"] + } + }), + ); let provider_config = json!({ "auth": { @@ -143,9 +139,7 @@ fn sync_codex_provider_writes_auth_and_config() { ); // 当前供应商应同步最新 config 文本 - let manager = config - .get_manager(&AppType::Codex) - .expect("codex manager"); + let manager = config.get_manager(&AppType::Codex).expect("codex manager"); let synced = manager.providers.get("codex-1").expect("codex provider"); let synced_cfg = synced .settings_config @@ -161,22 +155,18 @@ fn sync_enabled_to_codex_writes_enabled_servers() { reset_test_fs(); let mut config = MultiAppConfig::default(); - config - .mcp - .codex - .servers - .insert( - "stdio-enabled".into(), - json!({ - "id": "stdio-enabled", - "enabled": true, - "server": { - "type": "stdio", - "command": "echo", - "args": ["ok"], - } - }), - ); + config.mcp.codex.servers.insert( + "stdio-enabled".into(), + json!({ + "id": "stdio-enabled", + "enabled": true, + "server": { + "type": "stdio", + "command": "echo", + "args": ["ok"], + } + }), + ); cc_switch_lib::sync_enabled_to_codex(&config).expect("sync codex"); @@ -241,10 +231,16 @@ fn sync_enabled_to_codex_returns_error_on_invalid_toml() { let err = cc_switch_lib::sync_enabled_to_codex(&config).expect_err("sync should fail"); match err { cc_switch_lib::AppError::Toml { path, .. } => { - assert!(path.ends_with("config.toml"), "path should reference config.toml"); + assert!( + path.ends_with("config.toml"), + "path should reference config.toml" + ); } cc_switch_lib::AppError::McpValidation(msg) => { - assert!(msg.contains("config.toml"), "error message should mention config.toml"); + assert!( + msg.contains("config.toml"), + "error message should mention config.toml" + ); } other => panic!("unexpected error: {other:?}"), } @@ -400,16 +396,31 @@ url = "https://example.com" assert!(changed >= 2, "should import both servers"); let servers = &config.mcp.codex.servers; - let echo = servers.get("echo_server").and_then(|v| v.as_object()).expect("echo server"); + let echo = servers + .get("echo_server") + .and_then(|v| v.as_object()) + .expect("echo server"); assert_eq!(echo.get("enabled").and_then(|v| v.as_bool()), Some(true)); - let server_spec = echo.get("server").and_then(|v| v.as_object()).expect("server spec"); + let server_spec = echo + .get("server") + .and_then(|v| v.as_object()) + .expect("server spec"); assert_eq!( - server_spec.get("command").and_then(|v| v.as_str()).unwrap_or(""), + server_spec + .get("command") + .and_then(|v| v.as_str()) + .unwrap_or(""), "echo" ); - let http = servers.get("http_server").and_then(|v| v.as_object()).expect("http server"); - let http_spec = http.get("server").and_then(|v| v.as_object()).expect("http spec"); + let http = servers + .get("http_server") + .and_then(|v| v.as_object()) + .expect("http server"); + let http_spec = http + .get("server") + .and_then(|v| v.as_object()) + .expect("http spec"); assert_eq!( http_spec.get("url").and_then(|v| v.as_str()).unwrap_or(""), "https://example.com" @@ -434,22 +445,18 @@ command = "echo" .expect("write codex config"); let mut config = MultiAppConfig::default(); - config - .mcp - .codex - .servers - .insert( - "existing".into(), - json!({ - "id": "existing", - "name": "existing", - "enabled": false, - "server": { - "type": "stdio", - "command": "prev" - } - }), - ); + config.mcp.codex.servers.insert( + "existing".into(), + json!({ + "id": "existing", + "name": "existing", + "enabled": false, + "server": { + "type": "stdio", + "command": "prev" + } + }), + ); let changed = cc_switch_lib::import_from_codex(&mut config).expect("import codex"); assert!(changed >= 1, "should mark change for enabled flag"); @@ -462,7 +469,10 @@ command = "echo" .and_then(|v| v.as_object()) .expect("existing entry"); assert_eq!(entry.get("enabled").and_then(|v| v.as_bool()), Some(true)); - let spec = entry.get("server").and_then(|v| v.as_object()).expect("server spec"); + let spec = entry + .get("server") + .and_then(|v| v.as_object()) + .expect("server spec"); // 保留原 command,确保导入不会覆盖现有 server 细节 assert_eq!(spec.get("command").and_then(|v| v.as_str()), Some("prev")); } @@ -542,11 +552,9 @@ fn import_from_claude_merges_into_config() { .expect("write claude json"); let mut config = MultiAppConfig::default(); - config - .mcp - .claude - .servers - .insert("stdio-enabled".into(), json!({ + config.mcp.claude.servers.insert( + "stdio-enabled".into(), + json!({ "id": "stdio-enabled", "name": "stdio-enabled", "enabled": false, @@ -554,7 +562,8 @@ fn import_from_claude_merges_into_config() { "type": "stdio", "command": "prev" } - })); + }), + ); let changed = cc_switch_lib::import_from_claude(&mut config).expect("import from claude"); assert!(changed >= 1, "should mark at least one change"); @@ -567,7 +576,10 @@ fn import_from_claude_merges_into_config() { .and_then(|v| v.as_object()) .expect("entry exists"); assert_eq!(entry.get("enabled").and_then(|v| v.as_bool()), Some(true)); - let server = entry.get("server").and_then(|v| v.as_object()).expect("server obj"); + let server = entry + .get("server") + .and_then(|v| v.as_object()) + .expect("server obj"); assert_eq!( server.get("command").and_then(|v| v.as_str()).unwrap_or(""), "prev", @@ -745,10 +757,7 @@ fn import_config_from_path_overwrites_state_and_creates_backup() { "saved config should record new current provider" ); - let guard = app_state - .config - .lock() - .expect("lock state after import"); + let guard = app_state.config.lock().expect("lock state after import"); let claude_manager = guard .get_manager(&AppType::Claude) .expect("claude manager in state"); @@ -778,8 +787,7 @@ fn import_config_from_path_invalid_json_returns_error() { config: Mutex::new(MultiAppConfig::default()), }; - let err = - import_config_from_path(&invalid_path, &app_state).expect_err("import should fail"); + let err = import_config_from_path(&invalid_path, &app_state).expect_err("import should fail"); match err { AppError::Json { .. } => {} other => panic!("expected json error, got {other:?}"), @@ -825,10 +833,7 @@ fn export_config_to_file_writes_target_path() { export_path.to_string_lossy().to_string(), )) .expect("export should succeed"); - assert_eq!( - result.get("success").and_then(|v| v.as_bool()), - Some(true) - ); + assert_eq!(result.get("success").and_then(|v| v.as_bool()), Some(true)); let exported = fs::read_to_string(&export_path).expect("read exported file"); assert!( diff --git a/src-tauri/tests/mcp_commands.rs b/src-tauri/tests/mcp_commands.rs new file mode 100644 index 0000000..945bc69 --- /dev/null +++ b/src-tauri/tests/mcp_commands.rs @@ -0,0 +1,232 @@ +use std::fs; + +use serde_json::json; + +use cc_switch_lib::{ + get_claude_mcp_path, get_claude_settings_path, import_default_config_test_hook, + import_mcp_from_claude_test_hook, set_mcp_enabled_test_hook, AppError, AppState, AppType, + MultiAppConfig, +}; + +#[path = "support.rs"] +mod support; +use support::{ensure_test_home, reset_test_fs, test_mutex}; + +#[test] +fn import_default_config_claude_persists_provider() { + let _guard = test_mutex().lock().expect("acquire test mutex"); + reset_test_fs(); + let home = ensure_test_home(); + + let settings_path = get_claude_settings_path(); + if let Some(parent) = settings_path.parent() { + fs::create_dir_all(parent).expect("create claude settings dir"); + } + let settings = json!({ + "env": { + "ANTHROPIC_AUTH_TOKEN": "test-key", + "ANTHROPIC_BASE_URL": "https://api.test" + } + }); + fs::write( + &settings_path, + serde_json::to_string_pretty(&settings).expect("serialize settings"), + ) + .expect("seed claude settings.json"); + + let mut config = MultiAppConfig::default(); + config.ensure_app(&AppType::Claude); + let state = AppState { + config: std::sync::Mutex::new(config), + }; + + import_default_config_test_hook(&state, AppType::Claude) + .expect("import default config succeeds"); + + // 验证内存状态 + let guard = state.config.lock().expect("lock config"); + let manager = guard + .get_manager(&AppType::Claude) + .expect("claude manager present"); + assert_eq!(manager.current, "default"); + let default_provider = manager.providers.get("default").expect("default provider"); + assert_eq!( + default_provider.settings_config, settings, + "default provider should capture live settings" + ); + drop(guard); + + // 验证配置已持久化 + let config_path = home.join(".cc-switch").join("config.json"); + assert!( + config_path.exists(), + "importing default config should persist config.json" + ); +} + +#[test] +fn import_default_config_without_live_file_returns_error() { + let _guard = test_mutex().lock().expect("acquire test mutex"); + reset_test_fs(); + let home = ensure_test_home(); + + let state = AppState { + config: std::sync::Mutex::new(MultiAppConfig::default()), + }; + + let err = import_default_config_test_hook(&state, AppType::Claude) + .expect_err("missing live file should error"); + match err { + AppError::Message(msg) => assert!( + msg.contains("Claude Code 配置文件不存在"), + "unexpected error message: {msg}" + ), + other => panic!("unexpected error variant: {other:?}"), + } + + let config_path = home.join(".cc-switch").join("config.json"); + assert!( + !config_path.exists(), + "failed import should not create config.json" + ); +} + +#[test] +fn import_mcp_from_claude_creates_config_and_enables_servers() { + let _guard = test_mutex().lock().expect("acquire test mutex"); + reset_test_fs(); + let home = ensure_test_home(); + + let mcp_path = get_claude_mcp_path(); + let claude_json = json!({ + "mcpServers": { + "echo": { + "type": "stdio", + "command": "echo" + } + } + }); + fs::write( + &mcp_path, + serde_json::to_string_pretty(&claude_json).expect("serialize claude mcp"), + ) + .expect("seed ~/.claude.json"); + + let state = AppState { + config: std::sync::Mutex::new(MultiAppConfig::default()), + }; + + let changed = + import_mcp_from_claude_test_hook(&state).expect("import mcp from claude succeeds"); + assert!( + changed > 0, + "import should report inserted or normalized entries" + ); + + let guard = state.config.lock().expect("lock config"); + let claude_servers = &guard.mcp.claude.servers; + let entry = claude_servers + .get("echo") + .expect("server imported into config.json"); + assert!( + entry + .get("enabled") + .and_then(|v| v.as_bool()) + .unwrap_or(false), + "imported server should be marked enabled" + ); + drop(guard); + + let config_path = home.join(".cc-switch").join("config.json"); + assert!( + config_path.exists(), + "state.save should persist config.json when changes detected" + ); +} + +#[test] +fn import_mcp_from_claude_invalid_json_preserves_state() { + let _guard = test_mutex().lock().expect("acquire test mutex"); + reset_test_fs(); + let home = ensure_test_home(); + + let mcp_path = get_claude_mcp_path(); + fs::write(&mcp_path, "{\"mcpServers\":") // 不完整 JSON + .expect("seed invalid ~/.claude.json"); + + let state = AppState { + config: std::sync::Mutex::new(MultiAppConfig::default()), + }; + + let err = + import_mcp_from_claude_test_hook(&state).expect_err("invalid json should bubble up error"); + match err { + AppError::McpValidation(msg) => assert!( + msg.contains("解析 ~/.claude.json 失败"), + "unexpected error message: {msg}" + ), + other => panic!("unexpected error variant: {other:?}"), + } + + let config_path = home.join(".cc-switch").join("config.json"); + assert!( + !config_path.exists(), + "failed import should not persist config.json" + ); +} + +#[test] +fn set_mcp_enabled_for_codex_writes_live_config() { + let _guard = test_mutex().lock().expect("acquire test mutex"); + reset_test_fs(); + ensure_test_home(); + + let mut config = MultiAppConfig::default(); + config.ensure_app(&AppType::Codex); + config.mcp.codex.servers.insert( + "codex-server".into(), + json!({ + "id": "codex-server", + "name": "Codex Server", + "server": { + "type": "stdio", + "command": "echo" + }, + "enabled": false + }), + ); + + let state = AppState { + config: std::sync::Mutex::new(config), + }; + + set_mcp_enabled_test_hook(&state, AppType::Codex, "codex-server", true) + .expect("set enabled should succeed"); + + let guard = state.config.lock().expect("lock config"); + let entry = guard + .mcp + .codex + .servers + .get("codex-server") + .expect("codex server exists"); + assert!( + entry + .get("enabled") + .and_then(|v| v.as_bool()) + .unwrap_or(false), + "server should be marked enabled after command" + ); + drop(guard); + + let toml_path = cc_switch_lib::get_codex_config_path(); + assert!( + toml_path.exists(), + "enabling server should trigger sync to ~/.codex/config.toml" + ); + let toml_text = fs::read_to_string(&toml_path).expect("read codex config"); + assert!( + toml_text.contains("codex-server"), + "codex config should include the enabled server definition" + ); +} diff --git a/src-tauri/tests/provider_commands.rs b/src-tauri/tests/provider_commands.rs index f00551e..d0e5fdb 100644 --- a/src-tauri/tests/provider_commands.rs +++ b/src-tauri/tests/provider_commands.rs @@ -88,17 +88,13 @@ command = "say" "live auth.json should reflect new provider" ); - let config_text = - std::fs::read_to_string(get_codex_config_path()).expect("read config.toml"); + let config_text = std::fs::read_to_string(get_codex_config_path()).expect("read config.toml"); assert!( config_text.contains("mcp_servers.echo-server"), "config.toml should contain synced MCP servers" ); - let locked = app_state - .config - .lock() - .expect("lock config after switch"); + let locked = app_state.config.lock().expect("lock config after switch"); let manager = locked .get_manager(&AppType::Codex) .expect("codex manager after switch"); @@ -231,10 +227,7 @@ fn switch_provider_updates_claude_live_and_state() { "live settings.json should reflect new provider auth" ); - let locked = app_state - .config - .lock() - .expect("lock config after switch"); + let locked = app_state.config.lock().expect("lock config after switch"); let manager = locked .get_manager(&AppType::Claude) .expect("claude manager after switch"); @@ -265,8 +258,7 @@ fn switch_provider_updates_claude_live_and_state() { drop(locked); - let home_dir = - std::env::var("HOME").expect("HOME should be set by ensure_test_home"); + let home_dir = std::env::var("HOME").expect("HOME should be set by ensure_test_home"); let config_path = std::path::Path::new(&home_dir) .join(".cc-switch") .join("config.json"); @@ -325,13 +317,8 @@ fn switch_provider_codex_missing_auth_returns_error_and_keeps_state() { other => panic!("expected config error, got {other:?}"), } - let locked = app_state - .config - .lock() - .expect("lock config after failure"); - let manager = locked - .get_manager(&AppType::Codex) - .expect("codex manager"); + let locked = app_state.config.lock().expect("lock config after failure"); + let manager = locked.get_manager(&AppType::Codex).expect("codex manager"); assert!( manager.current.is_empty(), "current provider should remain empty on failure" diff --git a/src-tauri/tests/provider_service.rs b/src-tauri/tests/provider_service.rs new file mode 100644 index 0000000..29a90dd --- /dev/null +++ b/src-tauri/tests/provider_service.rs @@ -0,0 +1,413 @@ +use serde_json::json; + +use cc_switch_lib::{ + get_claude_settings_path, read_json_file, write_codex_live_atomic, AppError, AppType, + MultiAppConfig, Provider, ProviderService, +}; + +#[path = "support.rs"] +mod support; +use support::{ensure_test_home, reset_test_fs, test_mutex}; + +fn sanitize_provider_name(name: &str) -> String { + name.chars() + .map(|c| match c { + '<' | '>' | ':' | '"' | '/' | '\\' | '|' | '?' | '*' => '-', + _ => c, + }) + .collect::() + .to_lowercase() +} + +#[test] +fn provider_service_switch_codex_updates_live_and_config() { + let _guard = test_mutex().lock().expect("acquire test mutex"); + reset_test_fs(); + let _home = ensure_test_home(); + + let legacy_auth = json!({ "OPENAI_API_KEY": "legacy-key" }); + let legacy_config = r#"[mcp_servers.legacy] +type = "stdio" +command = "echo" +"#; + write_codex_live_atomic(&legacy_auth, Some(legacy_config)) + .expect("seed existing codex live config"); + + let mut config = MultiAppConfig::default(); + { + let manager = config + .get_manager_mut(&AppType::Codex) + .expect("codex manager"); + manager.current = "old-provider".to_string(); + manager.providers.insert( + "old-provider".to_string(), + Provider::with_id( + "old-provider".to_string(), + "Legacy".to_string(), + json!({ + "auth": {"OPENAI_API_KEY": "stale"}, + "config": "stale-config" + }), + None, + ), + ); + manager.providers.insert( + "new-provider".to_string(), + Provider::with_id( + "new-provider".to_string(), + "Latest".to_string(), + json!({ + "auth": {"OPENAI_API_KEY": "fresh-key"}, + "config": r#"[mcp_servers.latest] +type = "stdio" +command = "say" +"# + }), + None, + ), + ); + } + + config.mcp.codex.servers.insert( + "echo-server".into(), + json!({ + "id": "echo-server", + "enabled": true, + "server": { + "type": "stdio", + "command": "echo" + } + }), + ); + + ProviderService::switch(&mut config, AppType::Codex, "new-provider") + .expect("switch provider should succeed"); + + let auth_value: serde_json::Value = + read_json_file(&cc_switch_lib::get_codex_auth_path()).expect("read auth.json"); + assert_eq!( + auth_value.get("OPENAI_API_KEY").and_then(|v| v.as_str()), + Some("fresh-key"), + "live auth.json should reflect new provider" + ); + + let config_text = + std::fs::read_to_string(cc_switch_lib::get_codex_config_path()).expect("read config.toml"); + assert!( + config_text.contains("mcp_servers.echo-server"), + "config.toml should contain synced MCP servers" + ); + + let manager = config + .get_manager(&AppType::Codex) + .expect("codex manager after switch"); + assert_eq!(manager.current, "new-provider", "current provider updated"); + + let new_provider = manager + .providers + .get("new-provider") + .expect("new provider exists"); + let new_config_text = new_provider + .settings_config + .get("config") + .and_then(|v| v.as_str()) + .unwrap_or_default(); + assert_eq!( + new_config_text, config_text, + "provider config snapshot should match live file" + ); + + let legacy = manager + .providers + .get("old-provider") + .expect("legacy provider still exists"); + let legacy_auth_value = legacy + .settings_config + .get("auth") + .and_then(|v| v.get("OPENAI_API_KEY")) + .and_then(|v| v.as_str()) + .unwrap_or(""); + assert_eq!( + legacy_auth_value, "legacy-key", + "previous provider should be backfilled with live auth" + ); +} + +#[test] +fn provider_service_switch_claude_updates_live_and_state() { + let _guard = test_mutex().lock().expect("acquire test mutex"); + reset_test_fs(); + let _home = ensure_test_home(); + + let settings_path = get_claude_settings_path(); + if let Some(parent) = settings_path.parent() { + std::fs::create_dir_all(parent).expect("create claude settings dir"); + } + let legacy_live = json!({ + "env": { + "ANTHROPIC_API_KEY": "legacy-key" + }, + "workspace": { + "path": "/tmp/workspace" + } + }); + std::fs::write( + &settings_path, + serde_json::to_string_pretty(&legacy_live).expect("serialize legacy live"), + ) + .expect("seed claude live config"); + + let mut config = MultiAppConfig::default(); + { + let manager = config + .get_manager_mut(&AppType::Claude) + .expect("claude manager"); + manager.current = "old-provider".to_string(); + manager.providers.insert( + "old-provider".to_string(), + Provider::with_id( + "old-provider".to_string(), + "Legacy Claude".to_string(), + json!({ + "env": { "ANTHROPIC_API_KEY": "stale-key" } + }), + None, + ), + ); + manager.providers.insert( + "new-provider".to_string(), + Provider::with_id( + "new-provider".to_string(), + "Fresh Claude".to_string(), + json!({ + "env": { "ANTHROPIC_API_KEY": "fresh-key" }, + "workspace": { "path": "/tmp/new-workspace" } + }), + None, + ), + ); + } + + ProviderService::switch(&mut config, AppType::Claude, "new-provider") + .expect("switch provider should succeed"); + + let live_after: serde_json::Value = + read_json_file(&settings_path).expect("read claude live settings"); + assert_eq!( + live_after + .get("env") + .and_then(|env| env.get("ANTHROPIC_API_KEY")) + .and_then(|key| key.as_str()), + Some("fresh-key"), + "live settings.json should reflect new provider auth" + ); + + let manager = config + .get_manager(&AppType::Claude) + .expect("claude manager after switch"); + assert_eq!(manager.current, "new-provider", "current provider updated"); + + let legacy_provider = manager + .providers + .get("old-provider") + .expect("legacy provider still exists"); + assert_eq!( + legacy_provider.settings_config, legacy_live, + "previous provider should receive backfilled live config" + ); +} + +#[test] +fn provider_service_switch_missing_provider_returns_error() { + let mut config = MultiAppConfig::default(); + + let err = ProviderService::switch(&mut config, AppType::Claude, "missing") + .expect_err("switching missing provider should fail"); + match err { + AppError::ProviderNotFound(id) => assert_eq!(id, "missing"), + other => panic!("expected ProviderNotFound, got {other:?}"), + } +} + +#[test] +fn provider_service_switch_codex_missing_auth_returns_error() { + let mut config = MultiAppConfig::default(); + { + let manager = config + .get_manager_mut(&AppType::Codex) + .expect("codex manager"); + manager.providers.insert( + "invalid".to_string(), + Provider::with_id( + "invalid".to_string(), + "Broken Codex".to_string(), + json!({ + "config": "[mcp_servers.test]\ncommand = \"noop\"" + }), + None, + ), + ); + } + + let err = ProviderService::switch(&mut config, AppType::Codex, "invalid") + .expect_err("switching should fail without auth"); + match err { + AppError::Config(msg) => assert!( + msg.contains("auth"), + "expected auth related message, got {msg}" + ), + other => panic!("expected config error, got {other:?}"), + } +} + +#[test] +fn provider_service_delete_codex_removes_provider_and_files() { + let _guard = test_mutex().lock().expect("acquire test mutex"); + reset_test_fs(); + let home = ensure_test_home(); + + let mut config = MultiAppConfig::default(); + { + let manager = config + .get_manager_mut(&AppType::Codex) + .expect("codex manager"); + manager.current = "keep".to_string(); + manager.providers.insert( + "keep".to_string(), + Provider::with_id( + "keep".to_string(), + "Keep".to_string(), + json!({ + "auth": {"OPENAI_API_KEY": "keep-key"}, + "config": "" + }), + None, + ), + ); + manager.providers.insert( + "to-delete".to_string(), + Provider::with_id( + "to-delete".to_string(), + "DeleteCodex".to_string(), + json!({ + "auth": {"OPENAI_API_KEY": "delete-key"}, + "config": "" + }), + None, + ), + ); + } + + let sanitized = sanitize_provider_name("DeleteCodex"); + let codex_dir = home.join(".codex"); + std::fs::create_dir_all(&codex_dir).expect("create codex dir"); + let auth_path = codex_dir.join(format!("auth-{}.json", sanitized)); + let cfg_path = codex_dir.join(format!("config-{}.toml", sanitized)); + std::fs::write(&auth_path, "{}").expect("seed auth file"); + std::fs::write(&cfg_path, "base_url = \"https://example\"").expect("seed config file"); + + ProviderService::delete(&mut config, AppType::Codex, "to-delete") + .expect("delete provider should succeed"); + + let manager = config.get_manager(&AppType::Codex).expect("codex manager"); + assert!( + !manager.providers.contains_key("to-delete"), + "provider entry should be removed" + ); + assert!( + !auth_path.exists() && !cfg_path.exists(), + "provider-specific files should be deleted" + ); +} + +#[test] +fn provider_service_delete_claude_removes_provider_files() { + let _guard = test_mutex().lock().expect("acquire test mutex"); + reset_test_fs(); + let home = ensure_test_home(); + + let mut config = MultiAppConfig::default(); + { + let manager = config + .get_manager_mut(&AppType::Claude) + .expect("claude manager"); + manager.current = "keep".to_string(); + manager.providers.insert( + "keep".to_string(), + Provider::with_id( + "keep".to_string(), + "Keep".to_string(), + json!({ + "env": { "ANTHROPIC_API_KEY": "keep-key" } + }), + None, + ), + ); + manager.providers.insert( + "delete".to_string(), + Provider::with_id( + "delete".to_string(), + "DeleteClaude".to_string(), + json!({ + "env": { "ANTHROPIC_API_KEY": "delete-key" } + }), + None, + ), + ); + } + + let sanitized = sanitize_provider_name("DeleteClaude"); + let claude_dir = home.join(".claude"); + std::fs::create_dir_all(&claude_dir).expect("create claude dir"); + let by_name = claude_dir.join(format!("settings-{}.json", sanitized)); + let by_id = claude_dir.join("settings-delete.json"); + std::fs::write(&by_name, "{}").expect("seed settings by name"); + std::fs::write(&by_id, "{}").expect("seed settings by id"); + + ProviderService::delete(&mut config, AppType::Claude, "delete") + .expect("delete claude provider"); + + let manager = config + .get_manager(&AppType::Claude) + .expect("claude manager"); + assert!( + !manager.providers.contains_key("delete"), + "claude provider should be removed" + ); + assert!( + !by_name.exists() && !by_id.exists(), + "provider config files should be deleted" + ); +} + +#[test] +fn provider_service_delete_current_provider_returns_error() { + let mut config = MultiAppConfig::default(); + { + let manager = config + .get_manager_mut(&AppType::Claude) + .expect("claude manager"); + manager.current = "keep".to_string(); + manager.providers.insert( + "keep".to_string(), + Provider::with_id( + "keep".to_string(), + "Keep".to_string(), + json!({ + "env": { "ANTHROPIC_API_KEY": "keep-key" } + }), + None, + ), + ); + } + + let err = ProviderService::delete(&mut config, AppType::Claude, "keep") + .expect_err("deleting current provider should fail"); + match err { + AppError::Config(msg) => assert!( + msg.contains("不能删除当前正在使用的供应商"), + "unexpected message: {msg}" + ), + other => panic!("expected Config error, got {other:?}"), + } +}