feat: 重构视频下载任务的触发逻辑,由简单的 tokio::sleep 迁移至调度器调度 (#529)

This commit is contained in:
ᴀᴍᴛᴏᴀᴇʀ
2025-11-09 01:11:42 +08:00
committed by GitHub
parent c69a88f1da
commit 170bd14fe3
20 changed files with 544 additions and 152 deletions

View File

@@ -13,6 +13,7 @@ use crate::config::VersionedConfig;
mod config;
mod dashboard;
mod me;
mod task;
mod video_sources;
mod videos;
mod ws;
@@ -28,6 +29,7 @@ pub fn router() -> Router {
.merge(videos::router())
.merge(dashboard::router())
.merge(ws::router())
.merge(task::router())
.layer(middleware::from_fn(auth)),
)
}

View File

@@ -0,0 +1,15 @@
use anyhow::Result;
use axum::Router;
use axum::routing::post;
use crate::api::wrapper::{ApiError, ApiResponse};
use crate::task::DownloadTaskManager;
pub(super) fn router() -> Router {
Router::new().route("/task/download", post(new_download_task))
}
pub async fn new_download_task() -> Result<ApiResponse<bool>, ApiError> {
DownloadTaskManager::get().oneshot().await?;
Ok(ApiResponse::ok(true))
}

View File

@@ -26,7 +26,7 @@ use tokio_util::sync::CancellationToken;
use uuid::Uuid;
use crate::api::response::SysInfo;
use crate::utils::task_notifier::{TASK_STATUS_NOTIFIER, TaskStatus};
use crate::task::{DownloadTaskManager, TaskStatus};
static WEBSOCKET_HANDLER: LazyLock<WebSocketHandler> = LazyLock::new(WebSocketHandler::new);
@@ -209,7 +209,7 @@ impl WebSocketHandler {
let cancel_token = CancellationToken::new();
tokio::spawn(
async move {
let mut stream = WatchStream::new(TASK_STATUS_NOTIFIER.subscribe()).map(ServerEvent::Tasks);
let mut stream = WatchStream::new(DownloadTaskManager::get().subscribe()).map(ServerEvent::Tasks);
while let Some(event) = stream.next().await {
if let Err(e) = tx.send(event).await {
error!("Failed to send task status: {:?}", e);

View File

@@ -2,6 +2,7 @@ use std::path::PathBuf;
use std::sync::LazyLock;
use anyhow::{Result, bail};
use croner::parser::CronParser;
use sea_orm::DatabaseConnection;
use serde::{Deserialize, Serialize};
use validator::Validate;
@@ -9,7 +10,8 @@ use validator::Validate;
use crate::bilibili::{Credential, DanmakuOption, FilterOption};
use crate::config::default::{default_auth_token, default_bind_address, default_time_format};
use crate::config::item::{
ConcurrentLimit, NFOTimeType, SkipOption, default_collection_path, default_favorite_path, default_submission_path,
ConcurrentLimit, NFOTimeType, SkipOption, Trigger, default_collection_path, default_favorite_path,
default_submission_path,
};
use crate::notifier::Notifier;
use crate::utils::model::{load_db_config, save_db_config};
@@ -36,7 +38,7 @@ pub struct Config {
pub collection_default_path: String,
#[serde(default = "default_submission_path")]
pub submission_default_path: String,
pub interval: u64,
pub interval: Trigger,
pub upper_path: PathBuf,
pub nfo_time_type: NFOTimeType,
pub concurrent_limit: ConcurrentLimit,
@@ -77,6 +79,24 @@ impl Config {
if !(self.concurrent_limit.video > 0 && self.concurrent_limit.page > 0) {
errors.push("video 和 page 允许的并发数必须大于 0");
}
match &self.interval {
Trigger::Interval(secs) => {
if *secs <= 60 {
errors.push("下载任务执行间隔时间必须大于 60 秒");
}
}
Trigger::Cron(cron) => {
if CronParser::builder()
.seconds(croner::parser::Seconds::Required)
.dom_and_dow(true)
.build()
.parse(cron)
.is_err()
{
errors.push("Cron 表达式无效,正确格式为“秒 分 时 日 月 周”");
}
}
};
if !errors.is_empty() {
bail!(
errors
@@ -105,7 +125,7 @@ impl Default for Config {
favorite_default_path: default_favorite_path(),
collection_default_path: default_collection_path(),
submission_default_path: default_submission_path(),
interval: 1200,
interval: Trigger::default(),
upper_path: CONFIG_DIR.join("upper_face"),
nfo_time_type: NFOTimeType::FavTime,
concurrent_limit: ConcurrentLimit::default(),

View File

@@ -69,6 +69,19 @@ pub struct SkipOption {
pub no_subtitle: bool,
}
#[derive(Serialize, Deserialize, Clone)]
#[serde(untagged)]
pub enum Trigger {
Interval(u64),
Cron(String),
}
impl Default for Trigger {
fn default() -> Self {
Trigger::Interval(1200)
}
}
pub trait PathSafeTemplate {
fn path_safe_register(&mut self, name: &'static str, template: impl Into<String>) -> Result<()>;
fn path_safe_render(&self, name: &'static str, data: &serde_json::Value) -> Result<String>;

View File

@@ -9,6 +9,6 @@ mod versioned_config;
pub use crate::config::args::{ARGS, version};
pub use crate::config::current::{CONFIG_DIR, Config};
pub use crate::config::handlebar::TEMPLATE;
pub use crate::config::item::{ConcurrentDownloadLimit, NFOTimeType, PathSafeTemplate, RateLimit};
pub use crate::config::item::{ConcurrentDownloadLimit, NFOTimeType, PathSafeTemplate, RateLimit, Trigger};
pub use crate::config::versioned_cache::VersionedCache;
pub use crate::config::versioned_config::VersionedConfig;

View File

@@ -1,6 +1,6 @@
use std::sync::Arc;
use anyhow::{Result, anyhow, bail};
use anyhow::{Result, bail};
use arc_swap::{ArcSwap, Guard};
use sea_orm::DatabaseConnection;
use tokio::sync::{OnceCell, watch};
@@ -19,48 +19,48 @@ pub struct VersionedConfig {
impl VersionedConfig {
/// 初始化全局的 `VersionedConfig`,初始化失败或者已初始化过则返回错误
pub async fn init(connection: &DatabaseConnection) -> Result<()> {
let mut config = match Config::load_from_database(connection).await? {
Some(Ok(config)) => config,
Some(Err(e)) => bail!("解析数据库配置失败: {}", e),
None => {
if CONFIG_DIR.join("config.toml").exists() {
// 数据库中没有配置,但旧版配置文件存在,说明是从 2.6.0 之前的版本直接升级的
bail!(
"当前版本已移除配置文件的迁移逻辑,不再支持从配置文件加载配置。\n\
如果你正在运行 2.6.0 之前的版本,请先升级至 2.6.x 或 2.7.x\n\
启动时会自动将配置文件迁移至数据库,然后再升级至最新版本。"
);
}
let config = Config::default();
warn!(
"生成 auth_token{},可使用该 token 登录 web UI该信息仅在首次运行时打印",
config.auth_token
);
config.save_to_database(connection).await?;
config
}
};
// version 本身不具有实际意义,仅用于并发更新时的版本控制,在初始化时可以直接清空
config.version = 0;
let versioned_config = VersionedConfig::new(config);
pub async fn init(connection: &DatabaseConnection) -> Result<&'static VersionedConfig> {
VERSIONED_CONFIG
.set(versioned_config)
.map_err(|e| anyhow!("VERSIONED_CONFIG has already been initialized: {}", e))?;
Ok(())
.get_or_try_init(|| async move {
let mut config = match Config::load_from_database(connection).await? {
Some(Ok(config)) => config,
Some(Err(e)) => bail!("解析数据库配置失败: {}", e),
None => {
if CONFIG_DIR.join("config.toml").exists() {
// 数据库中没有配置,但旧版配置文件存在,说明是从 2.6.0 之前的版本直接升级的
bail!(
"当前版本已移除配置文件的迁移逻辑,不再支持从配置文件加载配置。\n\
如果你正在运行 2.6.0 之前的版本,请先升级至 2.6.x 或 2.7.x\n\
启动时会自动将配置文件迁移至数据库,然后再升级至最新版本。"
);
}
let config = Config::default();
warn!(
"生成 auth_token{},可使用该 token 登录 web UI该信息仅在首次运行时打印",
config.auth_token
);
config.save_to_database(connection).await?;
config
}
};
// version 本身不具有实际意义,仅用于并发更新时的版本控制,在初始化时可以直接清空
config.version = 0;
Ok(VersionedConfig::new(config))
})
.await
}
#[cfg(test)]
/// 仅在测试环境使用,该方法会尝试从测试数据库中加载配置并写入到全局的 VERSIONED_CONFIG
pub async fn init_for_test(connection: &DatabaseConnection) -> Result<()> {
let Some(Ok(config)) = Config::load_from_database(&connection).await? else {
bail!("no config found in test database");
};
let versioned_config = VersionedConfig::new(config);
pub async fn init_for_test(connection: &DatabaseConnection) -> Result<&'static VersionedConfig> {
VERSIONED_CONFIG
.set(versioned_config)
.map_err(|e| anyhow!("VERSIONED_CONFIG has already been initialized: {}", e))?;
Ok(())
.get_or_try_init(|| async move {
let Some(Ok(config)) = Config::load_from_database(&connection).await? else {
bail!("no config found in test database");
};
Ok(VersionedConfig::new(config))
})
.await
}
#[cfg(not(test))]

View File

@@ -45,14 +45,13 @@ async fn main() {
&tracker,
token.clone(),
);
if !cfg!(debug_assertions) {
spawn_task(
"定时下载",
video_downloader(connection.clone(), bili_client),
&tracker,
token.clone(),
);
}
spawn_task(
"定时下载",
video_downloader(connection.clone(), bili_client),
&tracker,
token.clone(),
);
tracker.close();
handle_shutdown(connection, tracker, token).await

View File

@@ -2,4 +2,4 @@ mod http_server;
mod video_downloader;
pub use http_server::http_server;
pub use video_downloader::video_downloader;
pub use video_downloader::{DownloadTaskManager, TaskStatus, video_downloader};

View File

@@ -1,45 +1,283 @@
use std::future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use anyhow::{Context, Result, bail};
use chrono::NaiveDate;
use sea_orm::DatabaseConnection;
use tokio::time;
use serde::Serialize;
use tokio::sync::{OnceCell, watch};
use tokio_cron_scheduler::{Job, JobScheduler};
use crate::adapter::VideoSource;
use crate::bilibili::{self, BiliClient, BiliError};
use crate::config::{Config, TEMPLATE, VersionedConfig};
use crate::config::{Config, TEMPLATE, Trigger, VersionedConfig};
use crate::notifier::NotifierAllExt;
use crate::utils::model::get_enabled_video_sources;
use crate::utils::task_notifier::TASK_STATUS_NOTIFIER;
use crate::workflow::process_video_source;
/// 启动周期下载视频的任务
pub async fn video_downloader(connection: DatabaseConnection, bili_client: Arc<BiliClient>) {
let mut anchor = chrono::Local::now().date_naive();
loop {
let _lock = TASK_STATUS_NOTIFIER.start_running().await;
let mut config = VersionedConfig::get().snapshot();
info!("开始执行本轮视频下载任务..");
if let Err(e) = download_all_video_sources(&connection, &bili_client, &mut config, &mut anchor).await {
let error_msg = format!("本轮视频下载任务执行遇到错误:{:#}", e);
error!("{error_msg}");
let _ = config
.notifiers
.notify_all(bili_client.inner_client(), &error_msg)
.await;
} else {
info!("本轮视频下载任务执行完毕");
}
TASK_STATUS_NOTIFIER.finish_running(_lock, config.interval as i64);
time::sleep(time::Duration::from_secs(config.interval)).await;
static INSTANCE: OnceCell<DownloadTaskManager> = OnceCell::const_new();
pub struct DownloadTaskManager {
sched: Arc<JobScheduler>,
task_context: TaskContext,
}
#[derive(Serialize, Default, Clone, Copy, Debug)]
pub struct TaskStatus {
is_running: bool,
last_run: Option<chrono::DateTime<chrono::Local>>,
last_finish: Option<chrono::DateTime<chrono::Local>>,
next_run: Option<chrono::DateTime<chrono::Local>>,
}
#[derive(Clone)]
struct TaskContext {
connection: DatabaseConnection,
bili_client: Arc<BiliClient>,
running: Arc<tokio::sync::Mutex<()>>,
status_tx: watch::Sender<TaskStatus>,
status_rx: watch::Receiver<TaskStatus>,
updating: Arc<tokio::sync::Mutex<Option<uuid::Uuid>>>,
}
impl DownloadTaskManager {
pub async fn init(
connection: DatabaseConnection,
bili_client: Arc<BiliClient>,
) -> Result<&'static DownloadTaskManager> {
INSTANCE
.get_or_try_init(|| DownloadTaskManager::new(connection, bili_client))
.await
}
pub fn get() -> &'static DownloadTaskManager {
INSTANCE.get().expect("DownloadTaskManager is not initialized")
}
pub fn subscribe(&self) -> watch::Receiver<TaskStatus> {
self.task_context.status_rx.clone()
}
pub async fn oneshot(&self) -> Result<()> {
let task_context = self.task_context.clone();
let _ = self
.sched
.add(Job::new_one_shot_async(Duration::from_secs(0), move |uuid, l| {
DownloadTaskManager::download_video_task(uuid, l, task_context.clone())
})?)
.await?;
Ok(())
}
pub(self) async fn start(&self) -> Result<()> {
self.sched.start().await?;
Ok(())
}
async fn new(connection: DatabaseConnection, bili_client: Arc<BiliClient>) -> Result<Self> {
let sched = Arc::new(JobScheduler::new().await?);
let (status_tx, status_rx) = watch::channel(TaskStatus::default());
let (running, updating) = (
Arc::new(tokio::sync::Mutex::new(())),
Arc::new(tokio::sync::Mutex::new(None)),
);
// 固定每天凌晨 1 点更新凭据
let (connection_clone, bili_client_clone, running_clone) =
(connection.clone(), bili_client.clone(), running.clone());
sched
.add(Job::new_async_tz(
"0 0 1 * * *",
chrono::Local,
move |_uuid, mut _l| {
DownloadTaskManager::check_and_refresh_credential_task(
connection_clone.clone(),
bili_client_clone.clone(),
running_clone.clone(),
)
},
)?)
.await?;
let task_context = TaskContext {
connection: connection.clone(),
bili_client: bili_client.clone(),
running: running.clone(),
status_tx: status_tx.clone(),
status_rx: status_rx.clone(),
updating: updating.clone(),
};
// 根据 interval 策略分发不同触发机制的视频下载任务,并记录任务 ID
let mut rx = VersionedConfig::get().subscribe();
let initial_config = rx.borrow_and_update().clone();
let task_context_clone = task_context.clone();
let job_run = move |uuid, l| DownloadTaskManager::download_video_task(uuid, l, task_context_clone.clone());
let job = match &initial_config.interval {
Trigger::Interval(interval) => Job::new_repeated_async(Duration::from_secs(*interval), job_run)?,
Trigger::Cron(cron) => Job::new_async_tz(cron, chrono::Local, job_run)?,
};
let download_task_id = sched.add(job).await?;
*updating.lock().await = Some(download_task_id);
// 发起一个一次性的任务,更新一下下次运行的时间
let task_context_clone = task_context.clone();
sched
.add(Job::new_one_shot_async(Duration::from_secs(0), move |_uuid, mut l| {
let task_context = task_context_clone.clone();
Box::pin(async move {
let old_status = *task_context.status_rx.borrow();
let next_run = l
.next_tick_for_job(download_task_id)
.await
.ok()
.flatten()
.map(|dt| dt.with_timezone(&chrono::Local));
let _ = task_context.status_tx.send(TaskStatus { next_run, ..old_status });
})
})?)
.await?;
// 监听配置变更,动态更新视频下载任务
let task_context_clone = task_context.clone();
let sched_clone = sched.clone();
tokio::spawn(async move {
while rx.changed().await.is_ok() {
let new_config = rx.borrow().clone();
let task_context = task_context_clone.clone();
// 先把旧的视频下载任务删掉
let mut task_id_guard = task_context_clone.updating.lock().await;
if let Some(old_task_id) = *task_id_guard {
sched_clone.remove(&old_task_id).await?;
}
// 再使用新的配置创建新的视频下载任务,并添加
let job_run = move |uuid, l| DownloadTaskManager::download_video_task(uuid, l, task_context.clone());
let job = match &new_config.interval {
Trigger::Interval(interval) => Job::new_repeated_async(Duration::from_secs(*interval), job_run)?,
Trigger::Cron(cron) => Job::new_async_tz(cron, chrono::Local, job_run)?,
};
let new_task_id = sched_clone.add(job).await?;
*task_id_guard = Some(new_task_id);
// 发起一个一次性的任务,更新一下下次运行的时间
let task_context = task_context_clone.clone();
sched_clone
.add(Job::new_one_shot_async(Duration::from_secs(0), move |_uuid, mut l| {
let task_context_clone = task_context.clone();
Box::pin(async move {
let old_status = *task_context_clone.status_rx.borrow();
let next_run = l
.next_tick_for_job(new_task_id)
.await
.ok()
.flatten()
.map(|dt| dt.with_timezone(&chrono::Local));
let _ = task_context_clone.status_tx.send(TaskStatus { next_run, ..old_status });
})
})?)
.await?;
}
Result::<(), anyhow::Error>::Ok(())
});
Ok(Self { sched, task_context })
}
fn check_and_refresh_credential_task(
connection: DatabaseConnection,
bili_client: Arc<BiliClient>,
running: Arc<tokio::sync::Mutex<()>>,
) -> Pin<Box<dyn Future<Output = ()> + Send>> {
Box::pin(async move {
let _lock = running.lock().await;
let config = VersionedConfig::get().read();
info!("开始执行本轮凭据检查与刷新任务..");
match check_and_refresh_credential(connection, &bili_client, &config).await {
Ok(_) => info!("本轮凭据检查与刷新任务执行完毕"),
Err(e) => {
let error_msg = format!("本轮凭据检查与刷新任务执行遇到错误:{:#}", e);
error!("{error_msg}");
let _ = config
.notifiers
.notify_all(bili_client.inner_client(), &error_msg)
.await;
}
}
})
}
fn download_video_task(
current_task_uuid: uuid::Uuid,
mut l: JobScheduler,
cx: TaskContext,
) -> Pin<Box<dyn Future<Output = ()> + Send>> {
Box::pin(async move {
let Ok(_lock) = cx.running.try_lock() else {
warn!("上一次视频下载任务尚未结束,跳过本次执行..");
return;
};
let _ = cx.status_tx.send(TaskStatus {
is_running: true,
last_run: Some(chrono::Local::now()),
last_finish: None,
next_run: None,
});
info!("开始执行本轮视频下载任务..");
let mut config = VersionedConfig::get().snapshot();
match download_all_video_sources(&cx.connection, &cx.bili_client, &mut config).await {
Ok(_) => info!("本轮视频下载任务执行完毕"),
Err(e) => {
let error_msg = format!("本轮视频下载任务执行遇到错误:{:#}", e);
error!("{error_msg}");
let _ = config
.notifiers
.notify_all(cx.bili_client.inner_client(), &error_msg)
.await;
}
}
// 注意此处尽量从 updating 中读取 uuid因为当前任务可能是不存在 next_tick 的 oneshot 任务
let task_uuid = (*cx.updating.lock().await).unwrap_or(current_task_uuid);
let next_run = l
.next_tick_for_job(task_uuid)
.await
.ok()
.flatten()
.map(|dt| dt.with_timezone(&chrono::Local));
let last_status = *cx.status_rx.borrow();
let _ = cx.status_tx.send(TaskStatus {
is_running: false,
last_run: last_status.last_run,
last_finish: Some(chrono::Local::now()),
next_run,
});
})
}
}
/// 启动周期下载视频的任务
pub async fn video_downloader(connection: DatabaseConnection, bili_client: Arc<BiliClient>) -> Result<()> {
let task_manager = DownloadTaskManager::init(connection, bili_client).await?;
let _ = task_manager.start().await;
future::pending::<()>().await;
Ok(())
}
async fn check_and_refresh_credential(
connection: DatabaseConnection,
bili_client: &BiliClient,
config: &Config,
) -> Result<()> {
if let Some(new_credential) = bili_client
.check_refresh(&config.credential)
.await
.context("检查刷新 Credential 失败")?
{
VersionedConfig::get()
.update_credential(new_credential, &connection)
.await
.context("更新 Credential 失败")?;
}
Ok(())
}
async fn download_all_video_sources(
connection: &DatabaseConnection,
bili_client: &BiliClient,
config: &mut Arc<Config>,
anchor: &mut NaiveDate,
) -> Result<()> {
config.check().context("配置检查失败")?;
let mixin_key = bili_client
@@ -49,19 +287,6 @@ async fn download_all_video_sources(
.into_mixin_key()
.context("解析 mixin key 失败")?;
bilibili::set_global_mixin_key(mixin_key);
if *anchor != chrono::Local::now().date_naive() {
if let Some(new_credential) = bili_client
.check_refresh(&config.credential)
.await
.context("检查刷新 Credential 失败")?
{
*config = VersionedConfig::get()
.update_credential(new_credential, connection)
.await
.context("更新 Credential 失败")?;
}
*anchor = chrono::Local::now().date_naive();
}
let template = TEMPLATE.snapshot();
let bili_client = bili_client.snapshot()?;
let video_sources = get_enabled_video_sources(connection)

View File

@@ -7,7 +7,6 @@ pub mod nfo;
pub mod rule;
pub mod signal;
pub mod status;
pub mod task_notifier;
pub mod validation;
use tracing_subscriber::fmt;
use tracing_subscriber::layer::SubscriberExt;

View File

@@ -1,59 +0,0 @@
use std::sync::LazyLock;
use serde::Serialize;
use tokio::sync::{MutexGuard, watch};
pub static TASK_STATUS_NOTIFIER: LazyLock<TaskStatusNotifier> = LazyLock::new(TaskStatusNotifier::new);
#[derive(Serialize, Default, Clone, Copy)]
pub struct TaskStatus {
is_running: bool,
last_run: Option<chrono::DateTime<chrono::Local>>,
last_finish: Option<chrono::DateTime<chrono::Local>>,
next_run: Option<chrono::DateTime<chrono::Local>>,
}
pub struct TaskStatusNotifier {
mutex: tokio::sync::Mutex<()>,
tx: watch::Sender<TaskStatus>,
rx: watch::Receiver<TaskStatus>,
}
impl TaskStatusNotifier {
pub fn new() -> Self {
let (tx, rx) = watch::channel(TaskStatus::default());
Self {
mutex: tokio::sync::Mutex::const_new(()),
tx,
rx,
}
}
pub async fn start_running(&'_ self) -> MutexGuard<'_, ()> {
let lock = self.mutex.lock().await;
let _ = self.tx.send(TaskStatus {
is_running: true,
last_run: Some(chrono::Local::now()),
last_finish: None,
next_run: None,
});
lock
}
pub fn finish_running(&self, _lock: MutexGuard<()>, interval: i64) {
let last_status = self.tx.borrow();
let last_run = last_status.last_run;
drop(last_status);
let now = chrono::Local::now();
let _ = self.tx.send(TaskStatus {
is_running: false,
last_run,
last_finish: Some(now),
next_run: now.checked_add_signed(chrono::Duration::seconds(interval)),
});
}
pub fn subscribe(&self) -> tokio::sync::watch::Receiver<TaskStatus> {
self.rx.clone()
}
}