feat: sqlite 开启 Wal,移除不必要的 Arc,妥善释放数据库 (#421)

This commit is contained in:
ᴀᴍᴛᴏᴀᴇʀ
2025-08-06 17:20:06 +08:00
committed by GitHub
parent 4f780faf64
commit 66079f3adc
9 changed files with 72 additions and 73 deletions

View File

@@ -22,7 +22,7 @@ pub async fn get_config() -> Result<ApiResponse<Arc<Config>>, ApiError> {
/// 更新全局配置
pub async fn update_config(
Extension(db): Extension<Arc<DatabaseConnection>>,
Extension(db): Extension<DatabaseConnection>,
ValidatedJson(config): ValidatedJson<Config>,
) -> Result<ApiResponse<Arc<Config>>, ApiError> {
let Some(_lock) = TASK_STATUS_NOTIFIER.detect_running() else {
@@ -30,7 +30,7 @@ pub async fn update_config(
return Err(InnerApiError::BadRequest("下载任务正在运行,无法修改配置".to_string()).into());
};
config.check()?;
let new_config = VersionedConfig::get().update(config, db.as_ref()).await?;
let new_config = VersionedConfig::get().update(config, &db).await?;
drop(_lock);
Ok(ApiResponse::ok(new_config))
}

View File

@@ -1,5 +1,3 @@
use std::sync::Arc;
use axum::routing::get;
use axum::{Extension, Router};
use bili_sync_entity::*;
@@ -14,21 +12,21 @@ pub(super) fn router() -> Router {
}
async fn get_dashboard(
Extension(db): Extension<Arc<DatabaseConnection>>,
Extension(db): Extension<DatabaseConnection>,
) -> Result<ApiResponse<DashBoardResponse>, ApiError> {
let (enabled_favorites, enabled_collections, enabled_submissions, enabled_watch_later, videos_by_day) = tokio::try_join!(
favorite::Entity::find()
.filter(favorite::Column::Enabled.eq(true))
.count(db.as_ref()),
.count(&db),
collection::Entity::find()
.filter(collection::Column::Enabled.eq(true))
.count(db.as_ref()),
.count(&db),
submission::Entity::find()
.filter(submission::Column::Enabled.eq(true))
.count(db.as_ref()),
.count(&db),
watch_later::Entity::find()
.filter(watch_later::Column::Enabled.eq(true))
.count(db.as_ref()),
.count(&db),
DayCountPair::find_by_statement(Statement::from_string(
db.get_database_backend(),
// 用 SeaORM 太复杂了,直接写个裸 SQL
@@ -55,7 +53,7 @@ ORDER BY
dates.day;
"
))
.all(db.as_ref()),
.all(&db),
)?;
return Ok(ApiResponse::ok(DashBoardResponse {
enabled_favorites,

View File

@@ -25,7 +25,7 @@ pub(super) fn router() -> Router {
/// 获取当前用户创建的收藏夹
pub async fn get_created_favorites(
Extension(db): Extension<Arc<DatabaseConnection>>,
Extension(db): Extension<DatabaseConnection>,
Extension(bili_client): Extension<Arc<BiliClient>>,
) -> Result<ApiResponse<FavoritesResponse>, ApiError> {
let me = Me::new(bili_client.as_ref());
@@ -40,7 +40,7 @@ pub async fn get_created_favorites(
.column(favorite::Column::FId)
.filter(favorite::Column::FId.is_in(bili_fids))
.into_tuple()
.all(db.as_ref())
.all(&db)
.await?;
let subscribed_set: HashSet<i64> = subscribed_fids.into_iter().collect();
@@ -64,7 +64,7 @@ pub async fn get_created_favorites(
/// 获取当前用户收藏的合集
pub async fn get_followed_collections(
Extension(db): Extension<Arc<DatabaseConnection>>,
Extension(db): Extension<DatabaseConnection>,
Extension(bili_client): Extension<Arc<BiliClient>>,
Query(params): Query<FollowedCollectionsRequest>,
) -> Result<ApiResponse<CollectionsResponse>, ApiError> {
@@ -80,7 +80,7 @@ pub async fn get_followed_collections(
.column(collection::Column::SId)
.filter(collection::Column::SId.is_in(bili_sids))
.into_tuple()
.all(db.as_ref())
.all(&db)
.await?;
let subscribed_set: HashSet<i64> = subscribed_ids.into_iter().collect();
@@ -106,7 +106,7 @@ pub async fn get_followed_collections(
/// 获取当前用户关注的 UP 主
pub async fn get_followed_uppers(
Extension(db): Extension<Arc<DatabaseConnection>>,
Extension(db): Extension<DatabaseConnection>,
Extension(bili_client): Extension<Arc<BiliClient>>,
Query(params): Query<FollowedUppersRequest>,
) -> Result<ApiResponse<UppersResponse>, ApiError> {
@@ -121,7 +121,7 @@ pub async fn get_followed_uppers(
.column(submission::Column::UpperId)
.filter(submission::Column::UpperId.is_in(bili_uid))
.into_tuple()
.all(db.as_ref())
.all(&db)
.await?;
let subscribed_set: HashSet<i64> = subscribed_ids.into_iter().collect();

View File

@@ -30,31 +30,31 @@ pub(super) fn router() -> Router {
/// 列出所有视频来源
pub async fn get_video_sources(
Extension(db): Extension<Arc<DatabaseConnection>>,
Extension(db): Extension<DatabaseConnection>,
) -> Result<ApiResponse<VideoSourcesResponse>, ApiError> {
let (collection, favorite, submission, mut watch_later) = tokio::try_join!(
collection::Entity::find()
.select_only()
.columns([collection::Column::Id, collection::Column::Name])
.into_model::<VideoSource>()
.all(db.as_ref()),
.all(&db),
favorite::Entity::find()
.select_only()
.columns([favorite::Column::Id, favorite::Column::Name])
.into_model::<VideoSource>()
.all(db.as_ref()),
.all(&db),
submission::Entity::find()
.select_only()
.column(submission::Column::Id)
.column_as(submission::Column::UpperName, "name")
.into_model::<VideoSource>()
.all(db.as_ref()),
.all(&db),
watch_later::Entity::find()
.select_only()
.column(watch_later::Column::Id)
.column_as(Expr::value("稍后再看"), "name")
.into_model::<VideoSource>()
.all(db.as_ref())
.all(&db)
)?;
// watch_later 是一个特殊的视频来源,如果不存在则添加一个默认项
if watch_later.is_empty() {
@@ -73,7 +73,7 @@ pub async fn get_video_sources(
/// 获取视频来源详情
pub async fn get_video_sources_details(
Extension(db): Extension<Arc<DatabaseConnection>>,
Extension(db): Extension<DatabaseConnection>,
) -> Result<ApiResponse<VideoSourcesDetailsResponse>, ApiError> {
let (collections, favorites, submissions, mut watch_later) = tokio::try_join!(
collection::Entity::find()
@@ -85,7 +85,7 @@ pub async fn get_video_sources_details(
collection::Column::Enabled
])
.into_model::<VideoSourceDetail>()
.all(db.as_ref()),
.all(&db),
favorite::Entity::find()
.select_only()
.columns([
@@ -95,21 +95,21 @@ pub async fn get_video_sources_details(
favorite::Column::Enabled
])
.into_model::<VideoSourceDetail>()
.all(db.as_ref()),
.all(&db),
submission::Entity::find()
.select_only()
.column(submission::Column::Id)
.column_as(submission::Column::UpperName, "name")
.columns([submission::Column::Path, submission::Column::Enabled])
.into_model::<VideoSourceDetail>()
.all(db.as_ref()),
.all(&db),
watch_later::Entity::find()
.select_only()
.column(watch_later::Column::Id)
.column_as(Expr::value("稍后再看"), "name")
.columns([watch_later::Column::Path, watch_later::Column::Enabled])
.into_model::<VideoSourceDetail>()
.all(db.as_ref())
.all(&db)
)?;
if watch_later.is_empty() {
watch_later.push(VideoSourceDetail {
@@ -130,29 +130,29 @@ pub async fn get_video_sources_details(
/// 更新视频来源
pub async fn update_video_source(
Path((source_type, id)): Path<(String, i32)>,
Extension(db): Extension<Arc<DatabaseConnection>>,
Extension(db): Extension<DatabaseConnection>,
ValidatedJson(request): ValidatedJson<UpdateVideoSourceRequest>,
) -> Result<ApiResponse<bool>, ApiError> {
let active_model = match source_type.as_str() {
"collections" => collection::Entity::find_by_id(id).one(db.as_ref()).await?.map(|model| {
"collections" => collection::Entity::find_by_id(id).one(&db).await?.map(|model| {
let mut active_model: collection::ActiveModel = model.into();
active_model.path = Set(request.path);
active_model.enabled = Set(request.enabled);
_ActiveModel::Collection(active_model)
}),
"favorites" => favorite::Entity::find_by_id(id).one(db.as_ref()).await?.map(|model| {
"favorites" => favorite::Entity::find_by_id(id).one(&db).await?.map(|model| {
let mut active_model: favorite::ActiveModel = model.into();
active_model.path = Set(request.path);
active_model.enabled = Set(request.enabled);
_ActiveModel::Favorite(active_model)
}),
"submissions" => submission::Entity::find_by_id(id).one(db.as_ref()).await?.map(|model| {
"submissions" => submission::Entity::find_by_id(id).one(&db).await?.map(|model| {
let mut active_model: submission::ActiveModel = model.into();
active_model.path = Set(request.path);
active_model.enabled = Set(request.enabled);
_ActiveModel::Submission(active_model)
}),
"watch_later" => match watch_later::Entity::find_by_id(id).one(db.as_ref()).await? {
"watch_later" => match watch_later::Entity::find_by_id(id).one(&db).await? {
// 稍后再看需要做特殊处理get 时如果稍后再看不存在返回的是 id 为 1 的假记录
// 因此此处可能是更新也可能是插入,做个额外的处理
Some(model) => {
@@ -180,13 +180,13 @@ pub async fn update_video_source(
let Some(active_model) = active_model else {
return Err(InnerApiError::NotFound(id).into());
};
active_model.save(db.as_ref()).await?;
active_model.save(&db).await?;
Ok(ApiResponse::ok(true))
}
/// 新增收藏夹订阅
pub async fn insert_favorite(
Extension(db): Extension<Arc<DatabaseConnection>>,
Extension(db): Extension<DatabaseConnection>,
Extension(bili_client): Extension<Arc<BiliClient>>,
ValidatedJson(request): ValidatedJson<InsertFavoriteRequest>,
) -> Result<ApiResponse<bool>, ApiError> {
@@ -199,14 +199,14 @@ pub async fn insert_favorite(
enabled: Set(true),
..Default::default()
})
.exec(db.as_ref())
.exec(&db)
.await?;
Ok(ApiResponse::ok(true))
}
/// 新增合集/列表订阅
pub async fn insert_collection(
Extension(db): Extension<Arc<DatabaseConnection>>,
Extension(db): Extension<DatabaseConnection>,
Extension(bili_client): Extension<Arc<BiliClient>>,
ValidatedJson(request): ValidatedJson<InsertCollectionRequest>,
) -> Result<ApiResponse<bool>, ApiError> {
@@ -228,7 +228,7 @@ pub async fn insert_collection(
enabled: Set(true),
..Default::default()
})
.exec(db.as_ref())
.exec(&db)
.await?;
Ok(ApiResponse::ok(true))
@@ -236,7 +236,7 @@ pub async fn insert_collection(
/// 新增投稿订阅
pub async fn insert_submission(
Extension(db): Extension<Arc<DatabaseConnection>>,
Extension(db): Extension<DatabaseConnection>,
Extension(bili_client): Extension<Arc<BiliClient>>,
ValidatedJson(request): ValidatedJson<InsertSubmissionRequest>,
) -> Result<ApiResponse<bool>, ApiError> {
@@ -249,7 +249,7 @@ pub async fn insert_submission(
enabled: Set(true),
..Default::default()
})
.exec(db.as_ref())
.exec(&db)
.await?;
Ok(ApiResponse::ok(true))
}

View File

@@ -1,5 +1,4 @@
use std::collections::HashSet;
use std::sync::Arc;
use anyhow::Result;
use axum::extract::{Extension, Path, Query};
@@ -31,7 +30,7 @@ pub(super) fn router() -> Router {
/// 列出视频的基本信息,支持根据视频来源筛选、名称查找和分页
pub async fn get_videos(
Extension(db): Extension<Arc<DatabaseConnection>>,
Extension(db): Extension<DatabaseConnection>,
Query(params): Query<VideosRequest>,
) -> Result<ApiResponse<VideosResponse>, ApiError> {
let mut query = video::Entity::find();
@@ -48,7 +47,7 @@ pub async fn get_videos(
if let Some(query_word) = params.query {
query = query.filter(video::Column::Name.contains(query_word));
}
let total_count = query.clone().count(db.as_ref()).await?;
let total_count = query.clone().count(&db).await?;
let (page, page_size) = if let (Some(page), Some(page_size)) = (params.page, params.page_size) {
(page, page_size)
} else {
@@ -58,7 +57,7 @@ pub async fn get_videos(
videos: query
.order_by_desc(video::Column::Id)
.into_partial_model::<VideoInfo>()
.paginate(db.as_ref(), page_size)
.paginate(&db, page_size)
.fetch_page(page)
.await?,
total_count,
@@ -67,17 +66,15 @@ pub async fn get_videos(
pub async fn get_video(
Path(id): Path<i32>,
Extension(db): Extension<Arc<DatabaseConnection>>,
Extension(db): Extension<DatabaseConnection>,
) -> Result<ApiResponse<VideoResponse>, ApiError> {
let (video_info, pages_info) = tokio::try_join!(
video::Entity::find_by_id(id)
.into_partial_model::<VideoInfo>()
.one(db.as_ref()),
video::Entity::find_by_id(id).into_partial_model::<VideoInfo>().one(&db),
page::Entity::find()
.filter(page::Column::VideoId.eq(id))
.order_by_asc(page::Column::Cid)
.into_partial_model::<PageInfo>()
.all(db.as_ref())
.all(&db)
)?;
let Some(video_info) = video_info else {
return Err(InnerApiError::NotFound(id).into());
@@ -90,18 +87,16 @@ pub async fn get_video(
pub async fn reset_video(
Path(id): Path<i32>,
Extension(db): Extension<Arc<DatabaseConnection>>,
Extension(db): Extension<DatabaseConnection>,
Json(request): Json<ResetRequest>,
) -> Result<ApiResponse<ResetVideoResponse>, ApiError> {
let (video_info, pages_info) = tokio::try_join!(
video::Entity::find_by_id(id)
.into_partial_model::<VideoInfo>()
.one(db.as_ref()),
video::Entity::find_by_id(id).into_partial_model::<VideoInfo>().one(&db),
page::Entity::find()
.filter(page::Column::VideoId.eq(id))
.order_by_asc(page::Column::Cid)
.into_partial_model::<PageInfo>()
.all(db.as_ref())
.all(&db)
)?;
let Some(mut video_info) = video_info else {
return Err(InnerApiError::NotFound(id).into());
@@ -150,13 +145,13 @@ pub async fn reset_video(
}
pub async fn reset_all_videos(
Extension(db): Extension<Arc<DatabaseConnection>>,
Extension(db): Extension<DatabaseConnection>,
Json(request): Json<ResetRequest>,
) -> Result<ApiResponse<ResetAllVideosResponse>, ApiError> {
// 先查询所有视频和页面数据
let (all_videos, all_pages) = tokio::try_join!(
video::Entity::find().into_partial_model::<VideoInfo>().all(db.as_ref()),
page::Entity::find().into_partial_model::<PageInfo>().all(db.as_ref())
video::Entity::find().into_partial_model::<VideoInfo>().all(&db),
page::Entity::find().into_partial_model::<PageInfo>().all(&db)
)?;
let resetted_pages_info = all_pages
.into_iter()
@@ -210,18 +205,16 @@ pub async fn reset_all_videos(
pub async fn update_video_status(
Path(id): Path<i32>,
Extension(db): Extension<Arc<DatabaseConnection>>,
Extension(db): Extension<DatabaseConnection>,
ValidatedJson(request): ValidatedJson<UpdateVideoStatusRequest>,
) -> Result<ApiResponse<UpdateVideoStatusResponse>, ApiError> {
let (video_info, mut pages_info) = tokio::try_join!(
video::Entity::find_by_id(id)
.into_partial_model::<VideoInfo>()
.one(db.as_ref()),
video::Entity::find_by_id(id).into_partial_model::<VideoInfo>().one(&db),
page::Entity::find()
.filter(page::Column::VideoId.eq(id))
.order_by_asc(page::Column::Cid)
.into_partial_model::<PageInfo>()
.all(db.as_ref())
.all(&db)
)?;
let Some(mut video_info) = video_info else {
return Err(InnerApiError::NotFound(id).into());

View File

@@ -2,7 +2,7 @@ use std::time::Duration;
use anyhow::{Context, Result};
use bili_sync_migration::{Migrator, MigratorTrait};
use sea_orm::sqlx::sqlite::SqliteConnectOptions;
use sea_orm::sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode, SqliteSynchronous};
use sea_orm::sqlx::{ConnectOptions as SqlxConnectOptions, Sqlite};
use sea_orm::{ConnectOptions, Database, DatabaseConnection, SqlxSqliteConnector};
@@ -15,7 +15,7 @@ fn database_url() -> String {
async fn database_connection() -> Result<DatabaseConnection> {
let mut option = ConnectOptions::new(database_url());
option
.max_connections(100)
.max_connections(50)
.min_connections(5)
.acquire_timeout(Duration::from_secs(90));
let connect_option = option
@@ -23,7 +23,10 @@ async fn database_connection() -> Result<DatabaseConnection> {
.parse::<SqliteConnectOptions>()
.context("Failed to parse database URL")?
.disable_statement_logging()
.busy_timeout(Duration::from_secs(90));
.busy_timeout(Duration::from_secs(90))
.journal_mode(SqliteJournalMode::Wal)
.synchronous(SqliteSynchronous::Normal)
.optimize_on_close(true, None);
Ok(SqlxSqliteConnector::from_sqlx_sqlite_pool(
option
.sqlx_pool_options::<Sqlite>()

View File

@@ -47,14 +47,14 @@ async fn main() {
if !cfg!(debug_assertions) {
spawn_task(
"定时下载",
video_downloader(connection, bili_client),
video_downloader(connection.clone(), bili_client),
&tracker,
token.clone(),
);
}
tracker.close();
handle_shutdown(tracker, token).await
handle_shutdown(connection, tracker, token).await
}
fn spawn_task(
@@ -77,7 +77,7 @@ fn spawn_task(
}
/// 初始化日志系统、打印欢迎信息,初始化数据库连接和全局配置
async fn init() -> (Arc<DatabaseConnection>, LogHelper) {
async fn init() -> (DatabaseConnection, LogHelper) {
let (tx, _rx) = tokio::sync::broadcast::channel(30);
let log_history = Arc::new(Mutex::new(VecDeque::with_capacity(MAX_HISTORY_LOGS + 1)));
let log_writer = LogHelper::new(tx, log_history.clone());
@@ -85,7 +85,7 @@ async fn init() -> (Arc<DatabaseConnection>, LogHelper) {
init_logger(&ARGS.log_level, Some(log_writer.clone()));
info!("欢迎使用 Bili-Sync当前程序版本{}", config::version());
info!("项目地址https://github.com/amtoaer/bili-sync");
let connection = Arc::new(setup_database().await.expect("数据库初始化失败"));
let connection = setup_database().await.expect("数据库初始化失败");
info!("数据库初始化完成");
VersionedConfig::init(&connection).await.expect("配置初始化失败");
info!("配置初始化完成");
@@ -93,16 +93,21 @@ async fn init() -> (Arc<DatabaseConnection>, LogHelper) {
(connection, log_writer)
}
async fn handle_shutdown(tracker: TaskTracker, token: CancellationToken) {
async fn handle_shutdown(connection: DatabaseConnection, tracker: TaskTracker, token: CancellationToken) {
tokio::select! {
_ = tracker.wait() => {
error!("所有任务均已终止,程序退出")
error!("所有任务均已终止..")
}
_ = terminate() => {
info!("接收到终止信号,正在终止任务..");
info!("接收到终止信号,开始终止任务..");
token.cancel();
tracker.wait().await;
info!("所有任务均已终止,程序退出");
info!("所有任务均已终止..");
}
}
info!("正在关闭数据库连接..");
match connection.close().await {
Ok(()) => info!("数据库连接已关闭,程序结束"),
Err(e) => error!("关闭数据库连接时遇到错误:{:#},程序异常结束", e),
}
}

View File

@@ -21,7 +21,7 @@ use crate::config::VersionedConfig;
struct Asset;
pub async fn http_server(
database_connection: Arc<DatabaseConnection>,
database_connection: DatabaseConnection,
bili_client: Arc<BiliClient>,
log_writer: LogHelper,
) -> Result<()> {

View File

@@ -11,7 +11,7 @@ use crate::utils::task_notifier::TASK_STATUS_NOTIFIER;
use crate::workflow::process_video_source;
/// 启动周期下载视频的任务
pub async fn video_downloader(connection: Arc<DatabaseConnection>, bili_client: Arc<BiliClient>) {
pub async fn video_downloader(connection: DatabaseConnection, bili_client: Arc<BiliClient>) {
let mut anchor = chrono::Local::now().date_naive();
loop {
info!("开始执行本轮视频下载任务..");