feat(backend): add database SQL export/import with backup

- Enable rusqlite backup feature for SQL dump support
- Implement export_sql to generate SQLite-compatible SQL dumps
- Implement import_sql with automatic backup before import
- Add snapshot_to_memory to avoid long-held database locks
- Add backup rotation to retain latest 10 backups
- Support atomic import with rollback on failure
This commit is contained in:
YoVinchen
2025-11-24 00:45:34 +08:00
parent 7aa381cbb7
commit 6443dc897d
2 changed files with 315 additions and 2 deletions

View File

@@ -51,7 +51,7 @@ url = "2.5"
auto-launch = "0.5"
once_cell = "1.21.3"
base64 = "0.22"
rusqlite = { version = "0.31", features = ["bundled"] }
rusqlite = { version = "0.31", features = ["bundled", "backup"] }
indexmap = { version = "2", features = ["serde"] }
[target.'cfg(any(target_os = "macos", target_os = "windows", target_os = "linux"))'.dependencies]

View File

@@ -4,11 +4,17 @@ use crate::error::AppError;
use crate::prompt::Prompt;
use crate::provider::{Provider, ProviderMeta};
use crate::services::skill::{SkillRepo, SkillState};
use chrono::Utc;
use indexmap::IndexMap;
use rusqlite::{params, Connection, Result};
use rusqlite::backup::Backup;
use rusqlite::types::ValueRef;
use rusqlite::{params, Connection};
use serde::Serialize;
use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};
use std::sync::Mutex;
use tempfile::NamedTempFile;
/// 安全地序列化 JSON避免 unwrap panic
fn to_json_string<T: Serialize>(value: &T) -> Result<String, AppError> {
@@ -25,6 +31,8 @@ macro_rules! lock_conn {
};
}
const DB_BACKUP_RETAIN: usize = 10;
pub struct Database {
// 使用 Mutex 包装 Connection 以支持在多线程环境(如 Tauri State中共享
// rusqlite::Connection 本身不是 Sync 的
@@ -57,7 +65,10 @@ impl Database {
fn create_tables(&self) -> Result<(), AppError> {
let conn = lock_conn!(self.conn);
Self::create_tables_on_conn(&conn)
}
fn create_tables_on_conn(conn: &Connection) -> Result<(), AppError> {
// 1. Providers 表
conn.execute(
"CREATE TABLE IF NOT EXISTS providers (
@@ -166,6 +177,308 @@ impl Database {
Ok(())
}
/// 创建内存快照以避免长时间持有数据库锁
fn snapshot_to_memory(&self) -> Result<Connection, AppError> {
let conn = lock_conn!(self.conn);
let mut snapshot =
Connection::open_in_memory().map_err(|e| AppError::Database(e.to_string()))?;
{
let backup =
Backup::new(&conn, &mut snapshot).map_err(|e| AppError::Database(e.to_string()))?;
backup
.step(-1)
.map_err(|e| AppError::Database(e.to_string()))?;
}
Ok(snapshot)
}
/// 导出为 SQLite 兼容的 SQL 文本
pub fn export_sql(&self, target_path: &Path) -> Result<(), AppError> {
let snapshot = self.snapshot_to_memory()?;
let dump = Self::dump_sql(&snapshot)?;
if let Some(parent) = target_path.parent() {
fs::create_dir_all(parent).map_err(|e| AppError::io(parent, e))?;
}
crate::config::atomic_write(target_path, dump.as_bytes())
}
/// 从 SQL 文件导入,返回生成的备份 ID若无备份则为空字符串
pub fn import_sql(&self, source_path: &Path) -> Result<String, AppError> {
if !source_path.exists() {
return Err(AppError::InvalidInput(format!(
"SQL 文件不存在: {}",
source_path.display()
)));
}
let sql_raw = fs::read_to_string(source_path).map_err(|e| AppError::io(source_path, e))?;
let sql_content = Self::sanitize_import_sql(&sql_raw);
// 导入前备份现有数据库
let backup_path = self.backup_database_file()?;
// 在临时数据库执行导入,确保失败不会污染主库
let temp_file = NamedTempFile::new().map_err(|e| AppError::IoContext {
context: "创建临时数据库文件失败".to_string(),
source: e,
})?;
let temp_path = temp_file.path().to_path_buf();
let temp_conn =
Connection::open(&temp_path).map_err(|e| AppError::Database(e.to_string()))?;
temp_conn
.execute_batch(&sql_content)
.map_err(|e| AppError::Database(format!("执行 SQL 导入失败: {e}")))?;
// 补齐缺失表/索引并进行基础校验
Self::create_tables_on_conn(&temp_conn)?;
Self::validate_basic_state(&temp_conn)?;
// 使用 Backup 将临时库原子写回主库
{
let mut main_conn = lock_conn!(self.conn);
let backup = Backup::new(&temp_conn, &mut *main_conn)
.map_err(|e| AppError::Database(e.to_string()))?;
backup
.step(-1)
.map_err(|e| AppError::Database(e.to_string()))?;
}
let backup_id = backup_path
.and_then(|p| p.file_stem().map(|s| s.to_string_lossy().to_string()))
.unwrap_or_default();
Ok(backup_id)
}
/// 移除 SQLite 保留对象相关语句(如 sqlite_sequence避免导入报错
fn sanitize_import_sql(sql: &str) -> String {
let mut cleaned = String::new();
let lower_keyword = "sqlite_sequence";
for stmt in sql.split(';') {
let trimmed = stmt.trim();
if trimmed.is_empty() {
continue;
}
if trimmed.to_ascii_lowercase().contains(lower_keyword) {
continue;
}
cleaned.push_str(trimmed);
cleaned.push_str(";\n");
}
cleaned
}
/// 生成一致性快照备份,返回备份文件路径(不存在主库时返回 None
fn backup_database_file(&self) -> Result<Option<PathBuf>, AppError> {
let db_path = get_app_config_dir().join("cc-switch.db");
if !db_path.exists() {
return Ok(None);
}
let backup_dir = db_path
.parent()
.ok_or_else(|| AppError::Config("无效的数据库路径".to_string()))?
.join("backups");
fs::create_dir_all(&backup_dir).map_err(|e| AppError::io(&backup_dir, e))?;
let backup_id = format!("db_backup_{}", Utc::now().format("%Y%m%d_%H%M%S"));
let backup_path = backup_dir.join(format!("{backup_id}.db"));
{
let conn = lock_conn!(self.conn);
let mut dest_conn =
Connection::open(&backup_path).map_err(|e| AppError::Database(e.to_string()))?;
let backup = Backup::new(&conn, &mut dest_conn)
.map_err(|e| AppError::Database(e.to_string()))?;
backup
.step(-1)
.map_err(|e| AppError::Database(e.to_string()))?;
}
Self::cleanup_db_backups(&backup_dir)?;
Ok(Some(backup_path))
}
fn cleanup_db_backups(dir: &Path) -> Result<(), AppError> {
let entries = match fs::read_dir(dir) {
Ok(iter) => iter
.filter_map(|entry| entry.ok())
.filter(|entry| {
entry
.path()
.extension()
.map(|ext| ext == "db")
.unwrap_or(false)
})
.collect::<Vec<_>>(),
Err(_) => return Ok(()),
};
if entries.len() <= DB_BACKUP_RETAIN {
return Ok(());
}
let remove_count = entries.len().saturating_sub(DB_BACKUP_RETAIN);
let mut sorted = entries;
sorted.sort_by_key(|entry| entry.metadata().and_then(|m| m.modified()).ok());
for entry in sorted.into_iter().take(remove_count) {
if let Err(err) = fs::remove_file(entry.path()) {
log::warn!("删除旧数据库备份失败 {}: {}", entry.path().display(), err);
}
}
Ok(())
}
fn validate_basic_state(conn: &Connection) -> Result<(), AppError> {
let provider_count: i64 = conn
.query_row("SELECT COUNT(*) FROM providers", [], |row| row.get(0))
.map_err(|e| AppError::Database(e.to_string()))?;
let mcp_count: i64 = conn
.query_row("SELECT COUNT(*) FROM mcp_servers", [], |row| row.get(0))
.map_err(|e| AppError::Database(e.to_string()))?;
if provider_count == 0 && mcp_count == 0 {
return Err(AppError::Config(
"导入的 SQL 未包含有效的供应商或 MCP 数据".to_string(),
));
}
Ok(())
}
fn dump_sql(conn: &Connection) -> Result<String, AppError> {
let mut output = String::new();
let timestamp = Utc::now().format("%Y-%m-%d %H:%M:%S").to_string();
let user_version: i64 = conn
.query_row("PRAGMA user_version;", [], |row| row.get(0))
.unwrap_or(0);
output.push_str(&format!(
"-- CC Switch SQLite 导出\n-- 生成时间: {timestamp}\n-- user_version: {user_version}\n"
));
output.push_str("PRAGMA foreign_keys=OFF;\n");
output.push_str(&format!("PRAGMA user_version={user_version};\n"));
output.push_str("BEGIN TRANSACTION;\n");
// 导出 schema
let mut stmt = conn
.prepare(
"SELECT type, name, tbl_name, sql
FROM sqlite_master
WHERE sql NOT NULL AND type IN ('table','index','trigger','view')
ORDER BY type='table' DESC, name",
)
.map_err(|e| AppError::Database(e.to_string()))?;
let mut tables = Vec::new();
let mut rows = stmt
.query([])
.map_err(|e| AppError::Database(e.to_string()))?;
while let Some(row) = rows.next().map_err(|e| AppError::Database(e.to_string()))? {
let obj_type: String = row.get(0).map_err(|e| AppError::Database(e.to_string()))?;
let name: String = row.get(1).map_err(|e| AppError::Database(e.to_string()))?;
let sql: String = row.get(3).map_err(|e| AppError::Database(e.to_string()))?;
// 跳过 SQLite 内部对象(如 sqlite_sequence
if name.starts_with("sqlite_") {
continue;
}
output.push_str(&sql);
output.push_str(";\n");
if obj_type == "table" && !name.starts_with("sqlite_") {
tables.push(name);
}
}
// 导出数据
for table in tables {
let columns = Self::get_table_columns(conn, &table)?;
if columns.is_empty() {
continue;
}
let mut stmt = conn
.prepare(&format!("SELECT * FROM \"{table}\""))
.map_err(|e| AppError::Database(e.to_string()))?;
let mut rows = stmt
.query([])
.map_err(|e| AppError::Database(e.to_string()))?;
while let Some(row) = rows.next().map_err(|e| AppError::Database(e.to_string()))? {
let mut values = Vec::with_capacity(columns.len());
for idx in 0..columns.len() {
let value = row
.get_ref(idx)
.map_err(|e| AppError::Database(e.to_string()))?;
values.push(Self::format_sql_value(value)?);
}
let cols = columns
.iter()
.map(|c| format!("\"{c}\""))
.collect::<Vec<_>>()
.join(", ");
output.push_str(&format!(
"INSERT INTO \"{table}\" ({cols}) VALUES ({});\n",
values.join(", ")
));
}
}
output.push_str("COMMIT;\nPRAGMA foreign_keys=ON;\n");
Ok(output)
}
fn get_table_columns(conn: &Connection, table: &str) -> Result<Vec<String>, AppError> {
let mut stmt = conn
.prepare(&format!("PRAGMA table_info(\"{table}\")"))
.map_err(|e| AppError::Database(e.to_string()))?;
let iter = stmt
.query_map([], |row| row.get::<_, String>(1))
.map_err(|e| AppError::Database(e.to_string()))?;
let mut columns = Vec::new();
for col in iter {
columns.push(col.map_err(|e| AppError::Database(e.to_string()))?);
}
Ok(columns)
}
fn format_sql_value(value: ValueRef<'_>) -> Result<String, AppError> {
match value {
ValueRef::Null => Ok("NULL".to_string()),
ValueRef::Integer(i) => Ok(i.to_string()),
ValueRef::Real(f) => Ok(f.to_string()),
ValueRef::Text(t) => {
let text = std::str::from_utf8(t)
.map_err(|e| AppError::Database(format!("文本字段不是有效的 UTF-8: {e}")))?;
let escaped = text.replace('\'', "''");
Ok(format!("'{escaped}'"))
}
ValueRef::Blob(bytes) => {
let mut s = String::from("X'");
for b in bytes {
use std::fmt::Write;
let _ = write!(&mut s, "{b:02X}");
}
s.push('\'');
Ok(s)
}
}
}
/// 从 MultiAppConfig 迁移数据
pub fn migrate_from_json(&self, config: &MultiAppConfig) -> Result<(), AppError> {
let mut conn = lock_conn!(self.conn);