From 66079f3adc6a2615f571ee4ebba5f442b343f5ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E1=B4=80=E1=B4=8D=E1=B4=9B=E1=B4=8F=E1=B4=80=E1=B4=87?= =?UTF-8?q?=CA=80?= Date: Wed, 6 Aug 2025 17:20:06 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20sqlite=20=E5=BC=80=E5=90=AF=20Wal?= =?UTF-8?q?=EF=BC=8C=E7=A7=BB=E9=99=A4=E4=B8=8D=E5=BF=85=E8=A6=81=E7=9A=84?= =?UTF-8?q?=20Arc=EF=BC=8C=E5=A6=A5=E5=96=84=E9=87=8A=E6=94=BE=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=BA=93=20(#421)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/bili_sync/src/api/routes/config/mod.rs | 4 +- .../bili_sync/src/api/routes/dashboard/mod.rs | 14 +++--- crates/bili_sync/src/api/routes/me/mod.rs | 12 ++--- .../src/api/routes/video_sources/mod.rs | 44 +++++++++---------- crates/bili_sync/src/api/routes/videos/mod.rs | 37 +++++++--------- crates/bili_sync/src/database.rs | 9 ++-- crates/bili_sync/src/main.rs | 21 +++++---- crates/bili_sync/src/task/http_server.rs | 2 +- crates/bili_sync/src/task/video_downloader.rs | 2 +- 9 files changed, 72 insertions(+), 73 deletions(-) diff --git a/crates/bili_sync/src/api/routes/config/mod.rs b/crates/bili_sync/src/api/routes/config/mod.rs index b859fb0..abb1d35 100644 --- a/crates/bili_sync/src/api/routes/config/mod.rs +++ b/crates/bili_sync/src/api/routes/config/mod.rs @@ -22,7 +22,7 @@ pub async fn get_config() -> Result>, ApiError> { /// 更新全局配置 pub async fn update_config( - Extension(db): Extension>, + Extension(db): Extension, ValidatedJson(config): ValidatedJson, ) -> Result>, ApiError> { let Some(_lock) = TASK_STATUS_NOTIFIER.detect_running() else { @@ -30,7 +30,7 @@ pub async fn update_config( return Err(InnerApiError::BadRequest("下载任务正在运行,无法修改配置".to_string()).into()); }; config.check()?; - let new_config = VersionedConfig::get().update(config, db.as_ref()).await?; + let new_config = VersionedConfig::get().update(config, &db).await?; drop(_lock); Ok(ApiResponse::ok(new_config)) } diff --git a/crates/bili_sync/src/api/routes/dashboard/mod.rs b/crates/bili_sync/src/api/routes/dashboard/mod.rs index 374fe64..747d32f 100644 --- a/crates/bili_sync/src/api/routes/dashboard/mod.rs +++ b/crates/bili_sync/src/api/routes/dashboard/mod.rs @@ -1,5 +1,3 @@ -use std::sync::Arc; - use axum::routing::get; use axum::{Extension, Router}; use bili_sync_entity::*; @@ -14,21 +12,21 @@ pub(super) fn router() -> Router { } async fn get_dashboard( - Extension(db): Extension>, + Extension(db): Extension, ) -> Result, ApiError> { let (enabled_favorites, enabled_collections, enabled_submissions, enabled_watch_later, videos_by_day) = tokio::try_join!( favorite::Entity::find() .filter(favorite::Column::Enabled.eq(true)) - .count(db.as_ref()), + .count(&db), collection::Entity::find() .filter(collection::Column::Enabled.eq(true)) - .count(db.as_ref()), + .count(&db), submission::Entity::find() .filter(submission::Column::Enabled.eq(true)) - .count(db.as_ref()), + .count(&db), watch_later::Entity::find() .filter(watch_later::Column::Enabled.eq(true)) - .count(db.as_ref()), + .count(&db), DayCountPair::find_by_statement(Statement::from_string( db.get_database_backend(), // 用 SeaORM 太复杂了,直接写个裸 SQL @@ -55,7 +53,7 @@ ORDER BY dates.day; " )) - .all(db.as_ref()), + .all(&db), )?; return Ok(ApiResponse::ok(DashBoardResponse { enabled_favorites, diff --git a/crates/bili_sync/src/api/routes/me/mod.rs b/crates/bili_sync/src/api/routes/me/mod.rs index 6beb9a5..e762081 100644 --- a/crates/bili_sync/src/api/routes/me/mod.rs +++ b/crates/bili_sync/src/api/routes/me/mod.rs @@ -25,7 +25,7 @@ pub(super) fn router() -> Router { /// 获取当前用户创建的收藏夹 pub async fn get_created_favorites( - Extension(db): Extension>, + Extension(db): Extension, Extension(bili_client): Extension>, ) -> Result, ApiError> { let me = Me::new(bili_client.as_ref()); @@ -40,7 +40,7 @@ pub async fn get_created_favorites( .column(favorite::Column::FId) .filter(favorite::Column::FId.is_in(bili_fids)) .into_tuple() - .all(db.as_ref()) + .all(&db) .await?; let subscribed_set: HashSet = subscribed_fids.into_iter().collect(); @@ -64,7 +64,7 @@ pub async fn get_created_favorites( /// 获取当前用户收藏的合集 pub async fn get_followed_collections( - Extension(db): Extension>, + Extension(db): Extension, Extension(bili_client): Extension>, Query(params): Query, ) -> Result, ApiError> { @@ -80,7 +80,7 @@ pub async fn get_followed_collections( .column(collection::Column::SId) .filter(collection::Column::SId.is_in(bili_sids)) .into_tuple() - .all(db.as_ref()) + .all(&db) .await?; let subscribed_set: HashSet = subscribed_ids.into_iter().collect(); @@ -106,7 +106,7 @@ pub async fn get_followed_collections( /// 获取当前用户关注的 UP 主 pub async fn get_followed_uppers( - Extension(db): Extension>, + Extension(db): Extension, Extension(bili_client): Extension>, Query(params): Query, ) -> Result, ApiError> { @@ -121,7 +121,7 @@ pub async fn get_followed_uppers( .column(submission::Column::UpperId) .filter(submission::Column::UpperId.is_in(bili_uid)) .into_tuple() - .all(db.as_ref()) + .all(&db) .await?; let subscribed_set: HashSet = subscribed_ids.into_iter().collect(); diff --git a/crates/bili_sync/src/api/routes/video_sources/mod.rs b/crates/bili_sync/src/api/routes/video_sources/mod.rs index b2f18b8..86195d5 100644 --- a/crates/bili_sync/src/api/routes/video_sources/mod.rs +++ b/crates/bili_sync/src/api/routes/video_sources/mod.rs @@ -30,31 +30,31 @@ pub(super) fn router() -> Router { /// 列出所有视频来源 pub async fn get_video_sources( - Extension(db): Extension>, + Extension(db): Extension, ) -> Result, ApiError> { let (collection, favorite, submission, mut watch_later) = tokio::try_join!( collection::Entity::find() .select_only() .columns([collection::Column::Id, collection::Column::Name]) .into_model::() - .all(db.as_ref()), + .all(&db), favorite::Entity::find() .select_only() .columns([favorite::Column::Id, favorite::Column::Name]) .into_model::() - .all(db.as_ref()), + .all(&db), submission::Entity::find() .select_only() .column(submission::Column::Id) .column_as(submission::Column::UpperName, "name") .into_model::() - .all(db.as_ref()), + .all(&db), watch_later::Entity::find() .select_only() .column(watch_later::Column::Id) .column_as(Expr::value("稍后再看"), "name") .into_model::() - .all(db.as_ref()) + .all(&db) )?; // watch_later 是一个特殊的视频来源,如果不存在则添加一个默认项 if watch_later.is_empty() { @@ -73,7 +73,7 @@ pub async fn get_video_sources( /// 获取视频来源详情 pub async fn get_video_sources_details( - Extension(db): Extension>, + Extension(db): Extension, ) -> Result, ApiError> { let (collections, favorites, submissions, mut watch_later) = tokio::try_join!( collection::Entity::find() @@ -85,7 +85,7 @@ pub async fn get_video_sources_details( collection::Column::Enabled ]) .into_model::() - .all(db.as_ref()), + .all(&db), favorite::Entity::find() .select_only() .columns([ @@ -95,21 +95,21 @@ pub async fn get_video_sources_details( favorite::Column::Enabled ]) .into_model::() - .all(db.as_ref()), + .all(&db), submission::Entity::find() .select_only() .column(submission::Column::Id) .column_as(submission::Column::UpperName, "name") .columns([submission::Column::Path, submission::Column::Enabled]) .into_model::() - .all(db.as_ref()), + .all(&db), watch_later::Entity::find() .select_only() .column(watch_later::Column::Id) .column_as(Expr::value("稍后再看"), "name") .columns([watch_later::Column::Path, watch_later::Column::Enabled]) .into_model::() - .all(db.as_ref()) + .all(&db) )?; if watch_later.is_empty() { watch_later.push(VideoSourceDetail { @@ -130,29 +130,29 @@ pub async fn get_video_sources_details( /// 更新视频来源 pub async fn update_video_source( Path((source_type, id)): Path<(String, i32)>, - Extension(db): Extension>, + Extension(db): Extension, ValidatedJson(request): ValidatedJson, ) -> Result, ApiError> { let active_model = match source_type.as_str() { - "collections" => collection::Entity::find_by_id(id).one(db.as_ref()).await?.map(|model| { + "collections" => collection::Entity::find_by_id(id).one(&db).await?.map(|model| { let mut active_model: collection::ActiveModel = model.into(); active_model.path = Set(request.path); active_model.enabled = Set(request.enabled); _ActiveModel::Collection(active_model) }), - "favorites" => favorite::Entity::find_by_id(id).one(db.as_ref()).await?.map(|model| { + "favorites" => favorite::Entity::find_by_id(id).one(&db).await?.map(|model| { let mut active_model: favorite::ActiveModel = model.into(); active_model.path = Set(request.path); active_model.enabled = Set(request.enabled); _ActiveModel::Favorite(active_model) }), - "submissions" => submission::Entity::find_by_id(id).one(db.as_ref()).await?.map(|model| { + "submissions" => submission::Entity::find_by_id(id).one(&db).await?.map(|model| { let mut active_model: submission::ActiveModel = model.into(); active_model.path = Set(request.path); active_model.enabled = Set(request.enabled); _ActiveModel::Submission(active_model) }), - "watch_later" => match watch_later::Entity::find_by_id(id).one(db.as_ref()).await? { + "watch_later" => match watch_later::Entity::find_by_id(id).one(&db).await? { // 稍后再看需要做特殊处理,get 时如果稍后再看不存在返回的是 id 为 1 的假记录 // 因此此处可能是更新也可能是插入,做个额外的处理 Some(model) => { @@ -180,13 +180,13 @@ pub async fn update_video_source( let Some(active_model) = active_model else { return Err(InnerApiError::NotFound(id).into()); }; - active_model.save(db.as_ref()).await?; + active_model.save(&db).await?; Ok(ApiResponse::ok(true)) } /// 新增收藏夹订阅 pub async fn insert_favorite( - Extension(db): Extension>, + Extension(db): Extension, Extension(bili_client): Extension>, ValidatedJson(request): ValidatedJson, ) -> Result, ApiError> { @@ -199,14 +199,14 @@ pub async fn insert_favorite( enabled: Set(true), ..Default::default() }) - .exec(db.as_ref()) + .exec(&db) .await?; Ok(ApiResponse::ok(true)) } /// 新增合集/列表订阅 pub async fn insert_collection( - Extension(db): Extension>, + Extension(db): Extension, Extension(bili_client): Extension>, ValidatedJson(request): ValidatedJson, ) -> Result, ApiError> { @@ -228,7 +228,7 @@ pub async fn insert_collection( enabled: Set(true), ..Default::default() }) - .exec(db.as_ref()) + .exec(&db) .await?; Ok(ApiResponse::ok(true)) @@ -236,7 +236,7 @@ pub async fn insert_collection( /// 新增投稿订阅 pub async fn insert_submission( - Extension(db): Extension>, + Extension(db): Extension, Extension(bili_client): Extension>, ValidatedJson(request): ValidatedJson, ) -> Result, ApiError> { @@ -249,7 +249,7 @@ pub async fn insert_submission( enabled: Set(true), ..Default::default() }) - .exec(db.as_ref()) + .exec(&db) .await?; Ok(ApiResponse::ok(true)) } diff --git a/crates/bili_sync/src/api/routes/videos/mod.rs b/crates/bili_sync/src/api/routes/videos/mod.rs index ac4f558..a6b1db5 100644 --- a/crates/bili_sync/src/api/routes/videos/mod.rs +++ b/crates/bili_sync/src/api/routes/videos/mod.rs @@ -1,5 +1,4 @@ use std::collections::HashSet; -use std::sync::Arc; use anyhow::Result; use axum::extract::{Extension, Path, Query}; @@ -31,7 +30,7 @@ pub(super) fn router() -> Router { /// 列出视频的基本信息,支持根据视频来源筛选、名称查找和分页 pub async fn get_videos( - Extension(db): Extension>, + Extension(db): Extension, Query(params): Query, ) -> Result, ApiError> { let mut query = video::Entity::find(); @@ -48,7 +47,7 @@ pub async fn get_videos( if let Some(query_word) = params.query { query = query.filter(video::Column::Name.contains(query_word)); } - let total_count = query.clone().count(db.as_ref()).await?; + let total_count = query.clone().count(&db).await?; let (page, page_size) = if let (Some(page), Some(page_size)) = (params.page, params.page_size) { (page, page_size) } else { @@ -58,7 +57,7 @@ pub async fn get_videos( videos: query .order_by_desc(video::Column::Id) .into_partial_model::() - .paginate(db.as_ref(), page_size) + .paginate(&db, page_size) .fetch_page(page) .await?, total_count, @@ -67,17 +66,15 @@ pub async fn get_videos( pub async fn get_video( Path(id): Path, - Extension(db): Extension>, + Extension(db): Extension, ) -> Result, ApiError> { let (video_info, pages_info) = tokio::try_join!( - video::Entity::find_by_id(id) - .into_partial_model::() - .one(db.as_ref()), + video::Entity::find_by_id(id).into_partial_model::().one(&db), page::Entity::find() .filter(page::Column::VideoId.eq(id)) .order_by_asc(page::Column::Cid) .into_partial_model::() - .all(db.as_ref()) + .all(&db) )?; let Some(video_info) = video_info else { return Err(InnerApiError::NotFound(id).into()); @@ -90,18 +87,16 @@ pub async fn get_video( pub async fn reset_video( Path(id): Path, - Extension(db): Extension>, + Extension(db): Extension, Json(request): Json, ) -> Result, ApiError> { let (video_info, pages_info) = tokio::try_join!( - video::Entity::find_by_id(id) - .into_partial_model::() - .one(db.as_ref()), + video::Entity::find_by_id(id).into_partial_model::().one(&db), page::Entity::find() .filter(page::Column::VideoId.eq(id)) .order_by_asc(page::Column::Cid) .into_partial_model::() - .all(db.as_ref()) + .all(&db) )?; let Some(mut video_info) = video_info else { return Err(InnerApiError::NotFound(id).into()); @@ -150,13 +145,13 @@ pub async fn reset_video( } pub async fn reset_all_videos( - Extension(db): Extension>, + Extension(db): Extension, Json(request): Json, ) -> Result, ApiError> { // 先查询所有视频和页面数据 let (all_videos, all_pages) = tokio::try_join!( - video::Entity::find().into_partial_model::().all(db.as_ref()), - page::Entity::find().into_partial_model::().all(db.as_ref()) + video::Entity::find().into_partial_model::().all(&db), + page::Entity::find().into_partial_model::().all(&db) )?; let resetted_pages_info = all_pages .into_iter() @@ -210,18 +205,16 @@ pub async fn reset_all_videos( pub async fn update_video_status( Path(id): Path, - Extension(db): Extension>, + Extension(db): Extension, ValidatedJson(request): ValidatedJson, ) -> Result, ApiError> { let (video_info, mut pages_info) = tokio::try_join!( - video::Entity::find_by_id(id) - .into_partial_model::() - .one(db.as_ref()), + video::Entity::find_by_id(id).into_partial_model::().one(&db), page::Entity::find() .filter(page::Column::VideoId.eq(id)) .order_by_asc(page::Column::Cid) .into_partial_model::() - .all(db.as_ref()) + .all(&db) )?; let Some(mut video_info) = video_info else { return Err(InnerApiError::NotFound(id).into()); diff --git a/crates/bili_sync/src/database.rs b/crates/bili_sync/src/database.rs index f99603a..1704ec9 100644 --- a/crates/bili_sync/src/database.rs +++ b/crates/bili_sync/src/database.rs @@ -2,7 +2,7 @@ use std::time::Duration; use anyhow::{Context, Result}; use bili_sync_migration::{Migrator, MigratorTrait}; -use sea_orm::sqlx::sqlite::SqliteConnectOptions; +use sea_orm::sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode, SqliteSynchronous}; use sea_orm::sqlx::{ConnectOptions as SqlxConnectOptions, Sqlite}; use sea_orm::{ConnectOptions, Database, DatabaseConnection, SqlxSqliteConnector}; @@ -15,7 +15,7 @@ fn database_url() -> String { async fn database_connection() -> Result { let mut option = ConnectOptions::new(database_url()); option - .max_connections(100) + .max_connections(50) .min_connections(5) .acquire_timeout(Duration::from_secs(90)); let connect_option = option @@ -23,7 +23,10 @@ async fn database_connection() -> Result { .parse::() .context("Failed to parse database URL")? .disable_statement_logging() - .busy_timeout(Duration::from_secs(90)); + .busy_timeout(Duration::from_secs(90)) + .journal_mode(SqliteJournalMode::Wal) + .synchronous(SqliteSynchronous::Normal) + .optimize_on_close(true, None); Ok(SqlxSqliteConnector::from_sqlx_sqlite_pool( option .sqlx_pool_options::() diff --git a/crates/bili_sync/src/main.rs b/crates/bili_sync/src/main.rs index f30194b..919c5ed 100644 --- a/crates/bili_sync/src/main.rs +++ b/crates/bili_sync/src/main.rs @@ -47,14 +47,14 @@ async fn main() { if !cfg!(debug_assertions) { spawn_task( "定时下载", - video_downloader(connection, bili_client), + video_downloader(connection.clone(), bili_client), &tracker, token.clone(), ); } tracker.close(); - handle_shutdown(tracker, token).await + handle_shutdown(connection, tracker, token).await } fn spawn_task( @@ -77,7 +77,7 @@ fn spawn_task( } /// 初始化日志系统、打印欢迎信息,初始化数据库连接和全局配置 -async fn init() -> (Arc, LogHelper) { +async fn init() -> (DatabaseConnection, LogHelper) { let (tx, _rx) = tokio::sync::broadcast::channel(30); let log_history = Arc::new(Mutex::new(VecDeque::with_capacity(MAX_HISTORY_LOGS + 1))); let log_writer = LogHelper::new(tx, log_history.clone()); @@ -85,7 +85,7 @@ async fn init() -> (Arc, LogHelper) { init_logger(&ARGS.log_level, Some(log_writer.clone())); info!("欢迎使用 Bili-Sync,当前程序版本:{}", config::version()); info!("项目地址:https://github.com/amtoaer/bili-sync"); - let connection = Arc::new(setup_database().await.expect("数据库初始化失败")); + let connection = setup_database().await.expect("数据库初始化失败"); info!("数据库初始化完成"); VersionedConfig::init(&connection).await.expect("配置初始化失败"); info!("配置初始化完成"); @@ -93,16 +93,21 @@ async fn init() -> (Arc, LogHelper) { (connection, log_writer) } -async fn handle_shutdown(tracker: TaskTracker, token: CancellationToken) { +async fn handle_shutdown(connection: DatabaseConnection, tracker: TaskTracker, token: CancellationToken) { tokio::select! { _ = tracker.wait() => { - error!("所有任务均已终止,程序退出") + error!("所有任务均已终止..") } _ = terminate() => { - info!("接收到终止信号,正在终止任务.."); + info!("接收到终止信号,开始终止任务.."); token.cancel(); tracker.wait().await; - info!("所有任务均已终止,程序退出"); + info!("所有任务均已终止.."); } } + info!("正在关闭数据库连接.."); + match connection.close().await { + Ok(()) => info!("数据库连接已关闭,程序结束"), + Err(e) => error!("关闭数据库连接时遇到错误:{:#},程序异常结束", e), + } } diff --git a/crates/bili_sync/src/task/http_server.rs b/crates/bili_sync/src/task/http_server.rs index fde9db1..307b6ee 100644 --- a/crates/bili_sync/src/task/http_server.rs +++ b/crates/bili_sync/src/task/http_server.rs @@ -21,7 +21,7 @@ use crate::config::VersionedConfig; struct Asset; pub async fn http_server( - database_connection: Arc, + database_connection: DatabaseConnection, bili_client: Arc, log_writer: LogHelper, ) -> Result<()> { diff --git a/crates/bili_sync/src/task/video_downloader.rs b/crates/bili_sync/src/task/video_downloader.rs index cd502fc..db7fe39 100644 --- a/crates/bili_sync/src/task/video_downloader.rs +++ b/crates/bili_sync/src/task/video_downloader.rs @@ -11,7 +11,7 @@ use crate::utils::task_notifier::TASK_STATUS_NOTIFIER; use crate::workflow::process_video_source; /// 启动周期下载视频的任务 -pub async fn video_downloader(connection: Arc, bili_client: Arc) { +pub async fn video_downloader(connection: DatabaseConnection, bili_client: Arc) { let mut anchor = chrono::Local::now().date_naive(); loop { info!("开始执行本轮视频下载任务..");