diff --git a/crates/bili_sync/src/api/routes/task/mod.rs b/crates/bili_sync/src/api/routes/task/mod.rs index a513116..3cd3705 100644 --- a/crates/bili_sync/src/api/routes/task/mod.rs +++ b/crates/bili_sync/src/api/routes/task/mod.rs @@ -10,6 +10,6 @@ pub(super) fn router() -> Router { } pub async fn new_download_task() -> Result, ApiError> { - DownloadTaskManager::get().oneshot().await?; + DownloadTaskManager::get().download_once().await?; Ok(ApiResponse::ok(true)) } diff --git a/crates/bili_sync/src/config/current.rs b/crates/bili_sync/src/config/current.rs index adbcf35..226d8f3 100644 --- a/crates/bili_sync/src/config/current.rs +++ b/crates/bili_sync/src/config/current.rs @@ -1,5 +1,5 @@ use std::path::PathBuf; -use std::sync::LazyLock; +use std::sync::{Arc, LazyLock}; use anyhow::{Result, bail}; use croner::parser::CronParser; @@ -31,7 +31,7 @@ pub struct Config { pub video_name: String, pub page_name: String, #[serde(default)] - pub notifiers: Option>, + pub notifiers: Option>>, #[serde(default = "default_favorite_path")] pub favorite_default_path: String, #[serde(default = "default_collection_path")] diff --git a/crates/bili_sync/src/notifier/mod.rs b/crates/bili_sync/src/notifier/mod.rs index 0dd1578..8e58ee7 100644 --- a/crates/bili_sync/src/notifier/mod.rs +++ b/crates/bili_sync/src/notifier/mod.rs @@ -18,12 +18,9 @@ pub trait NotifierAllExt { async fn notify_all(&self, client: &reqwest::Client, message: &str) -> Result<()>; } -impl NotifierAllExt for Option> { +impl NotifierAllExt for Vec { async fn notify_all(&self, client: &reqwest::Client, message: &str) -> Result<()> { - let Some(notifiers) = self else { - return Ok(()); - }; - future::join_all(notifiers.iter().map(|notifier| notifier.notify(client, message))).await; + future::join_all(self.iter().map(|notifier| notifier.notify(client, message))).await; Ok(()) } } diff --git a/crates/bili_sync/src/task/video_downloader.rs b/crates/bili_sync/src/task/video_downloader.rs index 8e21dd7..de71a7e 100644 --- a/crates/bili_sync/src/task/video_downloader.rs +++ b/crates/bili_sync/src/task/video_downloader.rs @@ -1,4 +1,3 @@ -use std::future; use std::pin::Pin; use std::sync::Arc; use std::time::Duration; @@ -12,15 +11,22 @@ use tokio_cron_scheduler::{Job, JobScheduler}; use crate::adapter::VideoSource; use crate::bilibili::{self, BiliClient, BiliError}; use crate::config::{Config, TEMPLATE, Trigger, VersionedConfig}; -use crate::notifier::NotifierAllExt; use crate::utils::model::get_enabled_video_sources; +use crate::utils::notify::error_and_notify; use crate::workflow::process_video_source; static INSTANCE: OnceCell = OnceCell::const_new(); +/// 启动周期下载视频的任务 +pub async fn video_downloader(connection: DatabaseConnection, bili_client: Arc) -> Result<()> { + let task_manager = DownloadTaskManager::init(connection, bili_client).await?; + task_manager.start().await +} + pub struct DownloadTaskManager { - sched: Arc, - task_context: TaskContext, + sched: Arc>, + cx: Arc, + shutdown_rx: watch::Receiver>, } #[derive(Serialize, Default, Clone, Copy, Debug)] @@ -31,17 +37,17 @@ pub struct TaskStatus { next_run: Option>, } -#[derive(Clone)] struct TaskContext { connection: DatabaseConnection, bili_client: Arc, - running: Arc>, + running: tokio::sync::Mutex<()>, status_tx: watch::Sender, status_rx: watch::Receiver, - updating: Arc>>, + video_task_id: tokio::sync::Mutex>, // 存储当前视频下载任务的 UUID } impl DownloadTaskManager { + /// 初始化 DownloadTaskManager 单例 pub async fn init( connection: DatabaseConnection, bili_client: Arc, @@ -51,213 +57,255 @@ impl DownloadTaskManager { .await } + /// 获取 DownloadTaskManager 单例,未初始化时直接 panic pub fn get() -> &'static DownloadTaskManager { INSTANCE.get().expect("DownloadTaskManager is not initialized") } + /// 订阅下载任务的状态更新 pub fn subscribe(&self) -> watch::Receiver { - self.task_context.status_rx.clone() + self.cx.status_rx.clone() } - pub async fn oneshot(&self) -> Result<()> { - let task_context = self.task_context.clone(); + /// 手动执行一次下载任务 + pub async fn download_once(&self) -> Result<()> { let _ = self .sched - .add(Job::new_one_shot_async(Duration::from_secs(0), move |uuid, l| { - DownloadTaskManager::download_video_task(uuid, l, task_context.clone()) - })?) + .lock() + .await + .add(Job::new_one_shot_async( + Duration::from_secs(0), + DownloadTaskManager::download_video_task(self.cx.clone()), + )?) .await?; Ok(()) } - pub(self) async fn start(&self) -> Result<()> { - self.sched.start().await?; + /// 启动任务调度器 + async fn start(&self) -> Result<()> { + self.sched.lock().await.start().await?; + let mut shutdown_rx = self.shutdown_rx.clone(); + shutdown_rx.changed().await?; + self.sched.lock().await.shutdown().await.context("任务调度器关闭失败")?; + if let Err(e) = &*shutdown_rx.borrow() { + bail!("{:#}", e); + } Ok(()) } + /// 私有的调度器构造函数 async fn new(connection: DatabaseConnection, bili_client: Arc) -> Result { - let sched = Arc::new(JobScheduler::new().await?); + let sched = Arc::new(tokio::sync::Mutex::new(JobScheduler::new().await?)); let (status_tx, status_rx) = watch::channel(TaskStatus::default()); - let (running, updating) = ( - Arc::new(tokio::sync::Mutex::new(())), - Arc::new(tokio::sync::Mutex::new(None)), - ); - // 固定每天凌晨 1 点更新凭据 - let (connection_clone, bili_client_clone, running_clone) = - (connection.clone(), bili_client.clone(), running.clone()); + let (running, video_task_id) = (tokio::sync::Mutex::new(()), tokio::sync::Mutex::new(None)); + let cx = Arc::new(TaskContext { + connection, + bili_client, + running, + status_tx, + status_rx, + video_task_id, + }); + // 读取初始配置 + let mut rx = VersionedConfig::get().subscribe(); + let initial_config = rx.borrow_and_update().clone(); + // 初始化凭据检查与刷新任务,该任务必须成功,否则直接退出 sched + .lock() + .await .add(Job::new_async_tz( "0 0 1 * * *", chrono::Local, - move |_uuid, mut _l| { - DownloadTaskManager::check_and_refresh_credential_task( - connection_clone.clone(), - bili_client_clone.clone(), - running_clone.clone(), - ) - }, + DownloadTaskManager::check_and_refresh_credential_task(cx.clone()), )?) .await?; - let task_context = TaskContext { - connection: connection.clone(), - bili_client: bili_client.clone(), - running: running.clone(), - status_tx: status_tx.clone(), - status_rx: status_rx.clone(), - updating: updating.clone(), - }; - // 根据 interval 策略分发不同触发机制的视频下载任务,并记录任务 ID - let mut rx = VersionedConfig::get().subscribe(); - let initial_config = rx.borrow_and_update().clone(); - let task_context_clone = task_context.clone(); - let job_run = move |uuid, l| DownloadTaskManager::download_video_task(uuid, l, task_context_clone.clone()); - let job = match &initial_config.interval { - Trigger::Interval(interval) => Job::new_repeated_async(Duration::from_secs(*interval), job_run)?, - Trigger::Cron(cron) => Job::new_async_tz(cron, chrono::Local, job_run)?, - }; - let download_task_id = sched.add(job).await?; - *updating.lock().await = Some(download_task_id); - // 发起一个一次性的任务,更新一下下次运行的时间 - let task_context_clone = task_context.clone(); - sched - .add(Job::new_one_shot_async(Duration::from_secs(0), move |_uuid, mut l| { - let task_context = task_context_clone.clone(); - Box::pin(async move { - let old_status = *task_context.status_rx.borrow(); - let next_run = l - .next_tick_for_job(download_task_id) - .await - .ok() - .flatten() - .map(|dt| dt.with_timezone(&chrono::Local)); - let _ = task_context.status_tx.send(TaskStatus { next_run, ..old_status }); - }) - })?) - .await?; - // 监听配置变更,动态更新视频下载任务 - let task_context_clone = task_context.clone(); - let sched_clone = sched.clone(); - tokio::spawn(async move { - while rx.changed().await.is_ok() { - let new_config = rx.borrow().clone(); - let task_context = task_context_clone.clone(); - // 先把旧的视频下载任务删掉 - let mut task_id_guard = task_context_clone.updating.lock().await; - if let Some(old_task_id) = *task_id_guard { - sched_clone.remove(&old_task_id).await?; - } - // 再使用新的配置创建新的视频下载任务,并添加 - let job_run = move |uuid, l| DownloadTaskManager::download_video_task(uuid, l, task_context.clone()); - let job = match &new_config.interval { - Trigger::Interval(interval) => Job::new_repeated_async(Duration::from_secs(*interval), job_run)?, - Trigger::Cron(cron) => Job::new_async_tz(cron, chrono::Local, job_run)?, - }; - let new_task_id = sched_clone.add(job).await?; - *task_id_guard = Some(new_task_id); - // 发起一个一次性的任务,更新一下下次运行的时间 - let task_context = task_context_clone.clone(); - sched_clone - .add(Job::new_one_shot_async(Duration::from_secs(0), move |_uuid, mut l| { - let task_context_clone = task_context.clone(); - Box::pin(async move { - let old_status = *task_context_clone.status_rx.borrow(); - let next_run = l - .next_tick_for_job(new_task_id) - .await - .ok() - .flatten() - .map(|dt| dt.with_timezone(&chrono::Local)); - let _ = task_context_clone.status_tx.send(TaskStatus { next_run, ..old_status }); - }) - })?) - .await?; + // 初始化并添加视频下载任务,将任务 ID 保存到 TaskManager 中 + let video_task_id = async { + let job_run = DownloadTaskManager::download_video_task(cx.clone()); + let job = match &initial_config.interval { + Trigger::Interval(interval) => Job::new_repeated_async(Duration::from_secs(*interval), job_run)?, + Trigger::Cron(cron) => Job::new_async_tz(cron, chrono::Local, job_run)?, + }; + Result::<_, anyhow::Error>::Ok(sched.lock().await.add(job).await?) + } + .await; + let video_task_id = match video_task_id { + Ok(id) => Some(id), + Err(err) => { + error_and_notify( + &initial_config, + &cx.bili_client, + format!("初始化视频下载任务失败:{:#}", err), + ); + None } - Result::<(), anyhow::Error>::Ok(()) + }; + *cx.video_task_id.lock().await = video_task_id; + // 发起一个一次性的任务,更新一下下次运行的时间 + if let Some(video_task_id) = video_task_id { + sched + .lock() + .await + .add(Job::new_one_shot_async( + Duration::from_secs(0), + DownloadTaskManager::refresh_next_run(video_task_id, cx.clone()), + )?) + .await?; + } + // 发起一个新任务,用来监听配置变更,动态更新视频下载任务 + let cx_clone = cx.clone(); + let sched_clone = sched.clone(); + let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(Ok(())); + tokio::spawn(async move { + let update_task_result = async { + while rx.changed().await.is_ok() { + let new_config = rx.borrow().clone(); + let cx = cx_clone.clone(); + let mut video_task_id = cx.video_task_id.lock().await; + if let Some(old_video_task_id) = *video_task_id { + // 这里必须成功,不然后面会重复添加任务 + sched_clone + .lock() + .await + .remove(&old_video_task_id) + .await + .context("移除旧的视频下载任务失败")?; + } + let new_video_task_id = async { + let job_run = DownloadTaskManager::download_video_task(cx.clone()); + let job = match &new_config.interval { + Trigger::Interval(interval) => { + Job::new_repeated_async(Duration::from_secs(*interval), job_run)? + } + Trigger::Cron(cron) => Job::new_async_tz(cron, chrono::Local, job_run)?, + }; + Result::<_, anyhow::Error>::Ok(sched_clone.lock().await.add(job).await?) + } + .await; + let new_video_task_id = match new_video_task_id { + Ok(id) => Some(id), + Err(err) => { + error_and_notify( + &initial_config, + &cx.bili_client, + format!("重载视频下载任务失败:{:#}", err), + ); + None + } + }; + *video_task_id = new_video_task_id; + if let Some(video_task_id) = new_video_task_id { + sched_clone + .lock() + .await + .add(Job::new_one_shot_async( + Duration::from_secs(0), + DownloadTaskManager::refresh_next_run(video_task_id, cx.clone()), + )?) + .await?; + } + } + Result::<(), anyhow::Error>::Ok(()) + } + .await; + // 如果执行正常,上面应该是永远不会退出的 + let _ = shutdown_tx.send(update_task_result); }); - Ok(Self { sched, task_context }) + Ok(Self { sched, cx, shutdown_rx }) } fn check_and_refresh_credential_task( - connection: DatabaseConnection, - bili_client: Arc, - running: Arc>, - ) -> Pin + Send>> { - Box::pin(async move { - let _lock = running.lock().await; - let config = VersionedConfig::get().read(); - info!("开始执行本轮凭据检查与刷新任务.."); - match check_and_refresh_credential(connection, &bili_client, &config).await { - Ok(_) => info!("本轮凭据检查与刷新任务执行完毕"), - Err(e) => { - let error_msg = format!("本轮凭据检查与刷新任务执行遇到错误:{:#}", e); - error!("{error_msg}"); - let _ = config - .notifiers - .notify_all(bili_client.inner_client(), &error_msg) - .await; + cx: Arc, + ) -> impl FnMut(uuid::Uuid, JobScheduler) -> Pin + Send>> { + move |_uuid, _l| { + let cx = cx.clone(); + Box::pin(async move { + let _lock = cx.running.lock().await; + let config = VersionedConfig::get().read(); + info!("开始执行本轮凭据检查与刷新任务.."); + match check_and_refresh_credential(&cx.connection, &cx.bili_client, &config).await { + Ok(_) => info!("本轮凭据检查与刷新任务执行完毕"), + Err(e) => { + error_and_notify( + &config, + &cx.bili_client, + format!("本轮凭据检查与刷新任务执行遇到错误:{:#}", e), + ); + } } - } - }) + }) + } + } + + fn refresh_next_run( + video_task_id: uuid::Uuid, + cx: Arc, + ) -> impl FnMut(uuid::Uuid, JobScheduler) -> Pin + Send>> { + move |_uuid, mut l| { + let cx = cx.clone(); + Box::pin(async move { + let old_status = *cx.status_rx.borrow(); + let next_run = l + .next_tick_for_job(video_task_id) + .await + .ok() + .flatten() + .map(|dt| dt.with_timezone(&chrono::Local)); + let _ = cx.status_tx.send(TaskStatus { next_run, ..old_status }); + }) + } } fn download_video_task( - current_task_uuid: uuid::Uuid, - mut l: JobScheduler, - cx: TaskContext, - ) -> Pin + Send>> { - Box::pin(async move { - let Ok(_lock) = cx.running.try_lock() else { - warn!("上一次视频下载任务尚未结束,跳过本次执行.."); - return; - }; - let _ = cx.status_tx.send(TaskStatus { - is_running: true, - last_run: Some(chrono::Local::now()), - last_finish: None, - next_run: None, - }); - info!("开始执行本轮视频下载任务.."); - let mut config = VersionedConfig::get().snapshot(); - match download_all_video_sources(&cx.connection, &cx.bili_client, &mut config).await { - Ok(_) => info!("本轮视频下载任务执行完毕"), - Err(e) => { - let error_msg = format!("本轮视频下载任务执行遇到错误:{:#}", e); - error!("{error_msg}"); - let _ = config - .notifiers - .notify_all(cx.bili_client.inner_client(), &error_msg) - .await; + cx: Arc, + ) -> impl FnMut(uuid::Uuid, JobScheduler) -> Pin + Send>> { + move |uuid, mut l| { + let cx = cx.clone(); + Box::pin(async move { + let Ok(_lock) = cx.running.try_lock() else { + warn!("上一次视频下载任务尚未结束,跳过本次执行.."); + return; + }; + let _ = cx.status_tx.send(TaskStatus { + is_running: true, + last_run: Some(chrono::Local::now()), + last_finish: None, + next_run: None, + }); + info!("开始执行本轮视频下载任务.."); + let mut config = VersionedConfig::get().snapshot(); + match download_video(&cx.connection, &cx.bili_client, &mut config).await { + Ok(_) => info!("本轮视频下载任务执行完毕"), + Err(e) => { + error_and_notify( + &config, + &cx.bili_client, + format!("本轮视频下载任务执行遇到错误:{:#}", e), + ); + } } - } - // 注意此处尽量从 updating 中读取 uuid,因为当前任务可能是不存在 next_tick 的 oneshot 任务 - let task_uuid = (*cx.updating.lock().await).unwrap_or(current_task_uuid); - let next_run = l - .next_tick_for_job(task_uuid) - .await - .ok() - .flatten() - .map(|dt| dt.with_timezone(&chrono::Local)); - let last_status = *cx.status_rx.borrow(); - let _ = cx.status_tx.send(TaskStatus { - is_running: false, - last_run: last_status.last_run, - last_finish: Some(chrono::Local::now()), - next_run, - }); - }) + // 注意此处尽量从 updating 中读取 uuid,因为当前任务可能是不存在 next_tick 的 oneshot 任务 + let task_uuid = (*cx.video_task_id.lock().await).unwrap_or(uuid); + let next_run = l + .next_tick_for_job(task_uuid) + .await + .ok() + .flatten() + .map(|dt| dt.with_timezone(&chrono::Local)); + let last_status = *cx.status_rx.borrow(); + let _ = cx.status_tx.send(TaskStatus { + is_running: false, + last_run: last_status.last_run, + last_finish: Some(chrono::Local::now()), + next_run, + }); + }) + } } } -/// 启动周期下载视频的任务 -pub async fn video_downloader(connection: DatabaseConnection, bili_client: Arc) -> Result<()> { - let task_manager = DownloadTaskManager::init(connection, bili_client).await?; - let _ = task_manager.start().await; - future::pending::<()>().await; - Ok(()) -} - async fn check_and_refresh_credential( - connection: DatabaseConnection, + connection: &DatabaseConnection, bili_client: &BiliClient, config: &Config, ) -> Result<()> { @@ -267,14 +315,14 @@ async fn check_and_refresh_credential( .context("检查刷新 Credential 失败")? { VersionedConfig::get() - .update_credential(new_credential, &connection) + .update_credential(new_credential, connection) .await .context("更新 Credential 失败")?; } Ok(()) } -async fn download_all_video_sources( +async fn download_video( connection: &DatabaseConnection, bili_client: &BiliClient, config: &mut Arc, @@ -298,12 +346,11 @@ async fn download_all_video_sources( for video_source in video_sources { let display_name = video_source.display_name(); if let Err(e) = process_video_source(video_source, &bili_client, connection, &template, config).await { - let error_msg = format!("处理 {} 时遇到错误:{:#},跳过该视频源", display_name, e); - error!("{error_msg}"); - let _ = config - .notifiers - .notify_all(bili_client.inner_client(), &error_msg) - .await; + error_and_notify( + config, + &bili_client, + format!("处理 {} 时遇到错误:{:#},跳过该视频源", display_name, e), + ); if let Ok(e) = e.downcast::() && e.is_risk_control_related() { diff --git a/crates/bili_sync/src/utils/mod.rs b/crates/bili_sync/src/utils/mod.rs index e60542f..a308d29 100644 --- a/crates/bili_sync/src/utils/mod.rs +++ b/crates/bili_sync/src/utils/mod.rs @@ -4,6 +4,7 @@ pub mod filenamify; pub mod format_arg; pub mod model; pub mod nfo; +pub mod notify; pub mod rule; pub mod signal; pub mod status; diff --git a/crates/bili_sync/src/utils/notify.rs b/crates/bili_sync/src/utils/notify.rs new file mode 100644 index 0000000..128ef91 --- /dev/null +++ b/crates/bili_sync/src/utils/notify.rs @@ -0,0 +1,13 @@ +use crate::bilibili::BiliClient; +use crate::config::Config; +use crate::notifier::NotifierAllExt; + +pub fn error_and_notify(config: &Config, bili_client: &BiliClient, msg: String) { + error!("{msg}"); + if let Some(notifiers) = &config.notifiers + && !notifiers.is_empty() + { + let (notifiers, inner_client) = (notifiers.clone(), bili_client.inner_client().clone()); + tokio::spawn(async move { notifiers.notify_all(&inner_client, msg.as_str()).await }); + } +}