diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml index d393c17..384802a 100644 --- a/src-tauri/Cargo.toml +++ b/src-tauri/Cargo.toml @@ -51,7 +51,7 @@ url = "2.5" auto-launch = "0.5" once_cell = "1.21.3" base64 = "0.22" -rusqlite = { version = "0.31", features = ["bundled"] } +rusqlite = { version = "0.31", features = ["bundled", "backup"] } indexmap = { version = "2", features = ["serde"] } [target.'cfg(any(target_os = "macos", target_os = "windows", target_os = "linux"))'.dependencies] diff --git a/src-tauri/src/database.rs b/src-tauri/src/database.rs index 78fdaef..434d1f5 100644 --- a/src-tauri/src/database.rs +++ b/src-tauri/src/database.rs @@ -4,11 +4,17 @@ use crate::error::AppError; use crate::prompt::Prompt; use crate::provider::{Provider, ProviderMeta}; use crate::services::skill::{SkillRepo, SkillState}; +use chrono::Utc; use indexmap::IndexMap; -use rusqlite::{params, Connection, Result}; +use rusqlite::backup::Backup; +use rusqlite::types::ValueRef; +use rusqlite::{params, Connection}; use serde::Serialize; use std::collections::HashMap; +use std::fs; +use std::path::{Path, PathBuf}; use std::sync::Mutex; +use tempfile::NamedTempFile; /// 安全地序列化 JSON,避免 unwrap panic fn to_json_string(value: &T) -> Result { @@ -25,6 +31,8 @@ macro_rules! lock_conn { }; } +const DB_BACKUP_RETAIN: usize = 10; + pub struct Database { // 使用 Mutex 包装 Connection 以支持在多线程环境(如 Tauri State)中共享 // rusqlite::Connection 本身不是 Sync 的 @@ -57,7 +65,10 @@ impl Database { fn create_tables(&self) -> Result<(), AppError> { let conn = lock_conn!(self.conn); + Self::create_tables_on_conn(&conn) + } + fn create_tables_on_conn(conn: &Connection) -> Result<(), AppError> { // 1. Providers 表 conn.execute( "CREATE TABLE IF NOT EXISTS providers ( @@ -166,6 +177,308 @@ impl Database { Ok(()) } + /// 创建内存快照以避免长时间持有数据库锁 + fn snapshot_to_memory(&self) -> Result { + let conn = lock_conn!(self.conn); + let mut snapshot = + Connection::open_in_memory().map_err(|e| AppError::Database(e.to_string()))?; + + { + let backup = + Backup::new(&conn, &mut snapshot).map_err(|e| AppError::Database(e.to_string()))?; + backup + .step(-1) + .map_err(|e| AppError::Database(e.to_string()))?; + } + + Ok(snapshot) + } + + /// 导出为 SQLite 兼容的 SQL 文本 + pub fn export_sql(&self, target_path: &Path) -> Result<(), AppError> { + let snapshot = self.snapshot_to_memory()?; + let dump = Self::dump_sql(&snapshot)?; + + if let Some(parent) = target_path.parent() { + fs::create_dir_all(parent).map_err(|e| AppError::io(parent, e))?; + } + + crate::config::atomic_write(target_path, dump.as_bytes()) + } + + /// 从 SQL 文件导入,返回生成的备份 ID(若无备份则为空字符串) + pub fn import_sql(&self, source_path: &Path) -> Result { + if !source_path.exists() { + return Err(AppError::InvalidInput(format!( + "SQL 文件不存在: {}", + source_path.display() + ))); + } + + let sql_raw = fs::read_to_string(source_path).map_err(|e| AppError::io(source_path, e))?; + let sql_content = Self::sanitize_import_sql(&sql_raw); + + // 导入前备份现有数据库 + let backup_path = self.backup_database_file()?; + + // 在临时数据库执行导入,确保失败不会污染主库 + let temp_file = NamedTempFile::new().map_err(|e| AppError::IoContext { + context: "创建临时数据库文件失败".to_string(), + source: e, + })?; + let temp_path = temp_file.path().to_path_buf(); + let temp_conn = + Connection::open(&temp_path).map_err(|e| AppError::Database(e.to_string()))?; + + temp_conn + .execute_batch(&sql_content) + .map_err(|e| AppError::Database(format!("执行 SQL 导入失败: {e}")))?; + + // 补齐缺失表/索引并进行基础校验 + Self::create_tables_on_conn(&temp_conn)?; + Self::validate_basic_state(&temp_conn)?; + + // 使用 Backup 将临时库原子写回主库 + { + let mut main_conn = lock_conn!(self.conn); + let backup = Backup::new(&temp_conn, &mut *main_conn) + .map_err(|e| AppError::Database(e.to_string()))?; + backup + .step(-1) + .map_err(|e| AppError::Database(e.to_string()))?; + } + + let backup_id = backup_path + .and_then(|p| p.file_stem().map(|s| s.to_string_lossy().to_string())) + .unwrap_or_default(); + + Ok(backup_id) + } + + /// 移除 SQLite 保留对象相关语句(如 sqlite_sequence),避免导入报错 + fn sanitize_import_sql(sql: &str) -> String { + let mut cleaned = String::new(); + let lower_keyword = "sqlite_sequence"; + + for stmt in sql.split(';') { + let trimmed = stmt.trim(); + if trimmed.is_empty() { + continue; + } + + if trimmed.to_ascii_lowercase().contains(lower_keyword) { + continue; + } + + cleaned.push_str(trimmed); + cleaned.push_str(";\n"); + } + + cleaned + } + + /// 生成一致性快照备份,返回备份文件路径(不存在主库时返回 None) + fn backup_database_file(&self) -> Result, AppError> { + let db_path = get_app_config_dir().join("cc-switch.db"); + if !db_path.exists() { + return Ok(None); + } + + let backup_dir = db_path + .parent() + .ok_or_else(|| AppError::Config("无效的数据库路径".to_string()))? + .join("backups"); + + fs::create_dir_all(&backup_dir).map_err(|e| AppError::io(&backup_dir, e))?; + + let backup_id = format!("db_backup_{}", Utc::now().format("%Y%m%d_%H%M%S")); + let backup_path = backup_dir.join(format!("{backup_id}.db")); + + { + let conn = lock_conn!(self.conn); + let mut dest_conn = + Connection::open(&backup_path).map_err(|e| AppError::Database(e.to_string()))?; + let backup = Backup::new(&conn, &mut dest_conn) + .map_err(|e| AppError::Database(e.to_string()))?; + backup + .step(-1) + .map_err(|e| AppError::Database(e.to_string()))?; + } + + Self::cleanup_db_backups(&backup_dir)?; + Ok(Some(backup_path)) + } + + fn cleanup_db_backups(dir: &Path) -> Result<(), AppError> { + let entries = match fs::read_dir(dir) { + Ok(iter) => iter + .filter_map(|entry| entry.ok()) + .filter(|entry| { + entry + .path() + .extension() + .map(|ext| ext == "db") + .unwrap_or(false) + }) + .collect::>(), + Err(_) => return Ok(()), + }; + + if entries.len() <= DB_BACKUP_RETAIN { + return Ok(()); + } + + let remove_count = entries.len().saturating_sub(DB_BACKUP_RETAIN); + let mut sorted = entries; + sorted.sort_by_key(|entry| entry.metadata().and_then(|m| m.modified()).ok()); + + for entry in sorted.into_iter().take(remove_count) { + if let Err(err) = fs::remove_file(entry.path()) { + log::warn!("删除旧数据库备份失败 {}: {}", entry.path().display(), err); + } + } + Ok(()) + } + + fn validate_basic_state(conn: &Connection) -> Result<(), AppError> { + let provider_count: i64 = conn + .query_row("SELECT COUNT(*) FROM providers", [], |row| row.get(0)) + .map_err(|e| AppError::Database(e.to_string()))?; + let mcp_count: i64 = conn + .query_row("SELECT COUNT(*) FROM mcp_servers", [], |row| row.get(0)) + .map_err(|e| AppError::Database(e.to_string()))?; + + if provider_count == 0 && mcp_count == 0 { + return Err(AppError::Config( + "导入的 SQL 未包含有效的供应商或 MCP 数据".to_string(), + )); + } + Ok(()) + } + + fn dump_sql(conn: &Connection) -> Result { + let mut output = String::new(); + let timestamp = Utc::now().format("%Y-%m-%d %H:%M:%S").to_string(); + let user_version: i64 = conn + .query_row("PRAGMA user_version;", [], |row| row.get(0)) + .unwrap_or(0); + + output.push_str(&format!( + "-- CC Switch SQLite 导出\n-- 生成时间: {timestamp}\n-- user_version: {user_version}\n" + )); + output.push_str("PRAGMA foreign_keys=OFF;\n"); + output.push_str(&format!("PRAGMA user_version={user_version};\n")); + output.push_str("BEGIN TRANSACTION;\n"); + + // 导出 schema + let mut stmt = conn + .prepare( + "SELECT type, name, tbl_name, sql + FROM sqlite_master + WHERE sql NOT NULL AND type IN ('table','index','trigger','view') + ORDER BY type='table' DESC, name", + ) + .map_err(|e| AppError::Database(e.to_string()))?; + + let mut tables = Vec::new(); + let mut rows = stmt + .query([]) + .map_err(|e| AppError::Database(e.to_string()))?; + while let Some(row) = rows.next().map_err(|e| AppError::Database(e.to_string()))? { + let obj_type: String = row.get(0).map_err(|e| AppError::Database(e.to_string()))?; + let name: String = row.get(1).map_err(|e| AppError::Database(e.to_string()))?; + let sql: String = row.get(3).map_err(|e| AppError::Database(e.to_string()))?; + + // 跳过 SQLite 内部对象(如 sqlite_sequence) + if name.starts_with("sqlite_") { + continue; + } + + output.push_str(&sql); + output.push_str(";\n"); + + if obj_type == "table" && !name.starts_with("sqlite_") { + tables.push(name); + } + } + + // 导出数据 + for table in tables { + let columns = Self::get_table_columns(conn, &table)?; + if columns.is_empty() { + continue; + } + + let mut stmt = conn + .prepare(&format!("SELECT * FROM \"{table}\"")) + .map_err(|e| AppError::Database(e.to_string()))?; + let mut rows = stmt + .query([]) + .map_err(|e| AppError::Database(e.to_string()))?; + + while let Some(row) = rows.next().map_err(|e| AppError::Database(e.to_string()))? { + let mut values = Vec::with_capacity(columns.len()); + for idx in 0..columns.len() { + let value = row + .get_ref(idx) + .map_err(|e| AppError::Database(e.to_string()))?; + values.push(Self::format_sql_value(value)?); + } + + let cols = columns + .iter() + .map(|c| format!("\"{c}\"")) + .collect::>() + .join(", "); + output.push_str(&format!( + "INSERT INTO \"{table}\" ({cols}) VALUES ({});\n", + values.join(", ") + )); + } + } + + output.push_str("COMMIT;\nPRAGMA foreign_keys=ON;\n"); + Ok(output) + } + + fn get_table_columns(conn: &Connection, table: &str) -> Result, AppError> { + let mut stmt = conn + .prepare(&format!("PRAGMA table_info(\"{table}\")")) + .map_err(|e| AppError::Database(e.to_string()))?; + let iter = stmt + .query_map([], |row| row.get::<_, String>(1)) + .map_err(|e| AppError::Database(e.to_string()))?; + + let mut columns = Vec::new(); + for col in iter { + columns.push(col.map_err(|e| AppError::Database(e.to_string()))?); + } + Ok(columns) + } + + fn format_sql_value(value: ValueRef<'_>) -> Result { + match value { + ValueRef::Null => Ok("NULL".to_string()), + ValueRef::Integer(i) => Ok(i.to_string()), + ValueRef::Real(f) => Ok(f.to_string()), + ValueRef::Text(t) => { + let text = std::str::from_utf8(t) + .map_err(|e| AppError::Database(format!("文本字段不是有效的 UTF-8: {e}")))?; + let escaped = text.replace('\'', "''"); + Ok(format!("'{escaped}'")) + } + ValueRef::Blob(bytes) => { + let mut s = String::from("X'"); + for b in bytes { + use std::fmt::Write; + let _ = write!(&mut s, "{b:02X}"); + } + s.push('\''); + Ok(s) + } + } + } + /// 从 MultiAppConfig 迁移数据 pub fn migrate_from_json(&self, config: &MultiAppConfig) -> Result<(), AppError> { let mut conn = lock_conn!(self.conn);