feat: 支持前端编辑、提交 Config (#370)
This commit is contained in:
@@ -13,7 +13,7 @@ pub async fn auth(headers: HeaderMap, request: Request, next: Next) -> Result<Re
|
||||
if request.uri().path().starts_with("/api/")
|
||||
&& get_token(&headers).is_none_or(|token| token != VersionedConfig::get().load().auth_token)
|
||||
{
|
||||
return Ok(ApiResponse::unauthorized(()).into_response());
|
||||
return Ok(ApiResponse::<()>::unauthorized("auth token does not match").into_response());
|
||||
}
|
||||
Ok(next.run(request).await)
|
||||
}
|
||||
|
||||
@@ -34,6 +34,8 @@ use crate::api::response::{
|
||||
};
|
||||
use crate::api::wrapper::{ApiError, ApiResponse, ValidatedJson};
|
||||
use crate::bilibili::{BiliClient, Collection, CollectionItem, FavoriteList, Me, Submission};
|
||||
use crate::config::{Config, VersionedConfig};
|
||||
use crate::task::DOWNLOADER_TASK_RUNNING;
|
||||
use crate::utils::status::{PageStatus, VideoStatus};
|
||||
|
||||
#[derive(OpenApi)]
|
||||
@@ -66,6 +68,8 @@ pub fn api_router() -> Router {
|
||||
.route("/api/me/favorites", get(get_created_favorites))
|
||||
.route("/api/me/collections", get(get_followed_collections))
|
||||
.route("/api/me/uppers", get(get_followed_uppers))
|
||||
.route("/api/config", get(get_config))
|
||||
.route("/api/config", put(update_config))
|
||||
.route("/image-proxy", get(image_proxy))
|
||||
}
|
||||
|
||||
@@ -783,6 +787,39 @@ pub async fn update_video_source(
|
||||
Ok(ApiResponse::ok(true))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/config",
|
||||
responses(
|
||||
(status = 200, body = ApiResponse<Config>),
|
||||
)
|
||||
)]
|
||||
pub async fn get_config() -> Result<ApiResponse<Arc<Config>>, ApiError> {
|
||||
Ok(ApiResponse::ok(VersionedConfig::get().load_full()))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
put,
|
||||
path = "/api/config",
|
||||
request_body = Config,
|
||||
responses(
|
||||
(status = 200, body = ApiResponse<Config>),
|
||||
)
|
||||
)]
|
||||
pub async fn update_config(
|
||||
Extension(db): Extension<Arc<DatabaseConnection>>,
|
||||
ValidatedJson(config): ValidatedJson<Config>,
|
||||
) -> Result<ApiResponse<Arc<Config>>, ApiError> {
|
||||
let Ok(_lock) = DOWNLOADER_TASK_RUNNING.try_lock() else {
|
||||
// 简单避免一下可能的不一致现象
|
||||
return Err(InnerApiError::BadRequest("下载任务正在运行,无法修改配置".to_string()).into());
|
||||
};
|
||||
config.check()?;
|
||||
let new_config = VersionedConfig::get().update(config, db.as_ref()).await?;
|
||||
drop(_lock);
|
||||
Ok(ApiResponse::ok(new_config))
|
||||
}
|
||||
|
||||
/// B 站的图片会检查 referer,需要做个转发伪造一下,否则直接返回 403
|
||||
pub async fn image_proxy(
|
||||
Extension(bili_client): Extension<Arc<BiliClient>>,
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use std::borrow::Cow;
|
||||
|
||||
use anyhow::Error;
|
||||
use axum::Json;
|
||||
use axum::extract::rejection::JsonRejection;
|
||||
@@ -14,28 +16,51 @@ use crate::api::error::InnerApiError;
|
||||
#[derive(ToSchema, Serialize)]
|
||||
pub struct ApiResponse<T: Serialize> {
|
||||
status_code: u16,
|
||||
data: T,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
data: Option<T>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
message: Option<Cow<'static, str>>,
|
||||
}
|
||||
|
||||
impl<T: Serialize> ApiResponse<T> {
|
||||
pub fn ok(data: T) -> Self {
|
||||
Self { status_code: 200, data }
|
||||
Self {
|
||||
status_code: 200,
|
||||
data: Some(data),
|
||||
message: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn bad_request(data: T) -> Self {
|
||||
Self { status_code: 400, data }
|
||||
pub fn bad_request(message: impl Into<Cow<'static, str>>) -> Self {
|
||||
Self {
|
||||
status_code: 400,
|
||||
data: None,
|
||||
message: Some(message.into()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn unauthorized(data: T) -> Self {
|
||||
Self { status_code: 401, data }
|
||||
pub fn unauthorized(message: impl Into<Cow<'static, str>>) -> Self {
|
||||
Self {
|
||||
status_code: 401,
|
||||
data: None,
|
||||
message: Some(message.into()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn not_found(data: T) -> Self {
|
||||
Self { status_code: 404, data }
|
||||
pub fn not_found(message: impl Into<Cow<'static, str>>) -> Self {
|
||||
Self {
|
||||
status_code: 404,
|
||||
data: None,
|
||||
message: Some(message.into()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn internal_server_error(data: T) -> Self {
|
||||
Self { status_code: 500, data }
|
||||
pub fn internal_server_error(message: impl Into<Cow<'static, str>>) -> Self {
|
||||
Self {
|
||||
status_code: 500,
|
||||
data: None,
|
||||
message: Some(message.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -64,13 +89,13 @@ impl IntoResponse for ApiError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
if let Some(inner_error) = self.0.downcast_ref::<InnerApiError>() {
|
||||
match inner_error {
|
||||
InnerApiError::NotFound(_) => return ApiResponse::not_found(self.0.to_string()).into_response(),
|
||||
InnerApiError::NotFound(_) => return ApiResponse::<()>::not_found(self.0.to_string()).into_response(),
|
||||
InnerApiError::BadRequest(_) => {
|
||||
return ApiResponse::bad_request(self.0.to_string()).into_response();
|
||||
return ApiResponse::<()>::bad_request(self.0.to_string()).into_response();
|
||||
}
|
||||
}
|
||||
}
|
||||
ApiResponse::internal_server_error(self.0.to_string()).into_response()
|
||||
ApiResponse::<()>::internal_server_error(self.0.to_string()).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use anyhow::{Context, Result, bail};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use utoipa::ToSchema;
|
||||
|
||||
use crate::bilibili::error::BiliError;
|
||||
use crate::config::VersionedConfig;
|
||||
@@ -8,7 +9,7 @@ pub struct PageAnalyzer {
|
||||
info: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, strum::FromRepr, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
|
||||
#[derive(Debug, strum::FromRepr, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, ToSchema, Clone)]
|
||||
pub enum VideoQuality {
|
||||
Quality360p = 16,
|
||||
Quality480p = 32,
|
||||
@@ -22,7 +23,7 @@ pub enum VideoQuality {
|
||||
Quality8k = 127,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, strum::FromRepr, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Copy, strum::FromRepr, PartialEq, Eq, Serialize, Deserialize, ToSchema)]
|
||||
pub enum AudioQuality {
|
||||
Quality64k = 30216,
|
||||
Quality132k = 30232,
|
||||
@@ -54,7 +55,18 @@ impl AudioQuality {
|
||||
}
|
||||
|
||||
#[allow(clippy::upper_case_acronyms)]
|
||||
#[derive(Debug, strum::EnumString, strum::Display, strum::AsRefStr, PartialEq, PartialOrd, Serialize, Deserialize)]
|
||||
#[derive(
|
||||
Debug,
|
||||
strum::EnumString,
|
||||
strum::Display,
|
||||
strum::AsRefStr,
|
||||
PartialEq,
|
||||
PartialOrd,
|
||||
Serialize,
|
||||
Deserialize,
|
||||
ToSchema,
|
||||
Clone,
|
||||
)]
|
||||
pub enum VideoCodecs {
|
||||
#[strum(serialize = "hev")]
|
||||
HEV,
|
||||
@@ -79,7 +91,7 @@ impl TryFrom<u64> for VideoCodecs {
|
||||
}
|
||||
|
||||
// 视频流的筛选偏好
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[derive(Serialize, Deserialize, ToSchema, Clone)]
|
||||
pub struct FilterOption {
|
||||
pub video_max_quality: VideoQuality,
|
||||
pub video_min_quality: VideoQuality,
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Result;
|
||||
@@ -93,25 +92,25 @@ impl BiliClient {
|
||||
if let Some(limiter) = self.limiter.load().as_ref() {
|
||||
limiter.acquire_one().await;
|
||||
}
|
||||
let credential = VersionedConfig::get().load().credential.load();
|
||||
self.client.request(method, url, Some(credential.as_ref()))
|
||||
let credential = &VersionedConfig::get().load().credential;
|
||||
self.client.request(method, url, Some(credential))
|
||||
}
|
||||
|
||||
pub async fn check_refresh(&self, connection: &DatabaseConnection) -> Result<()> {
|
||||
let credential = VersionedConfig::get().load().credential.load();
|
||||
let credential = &VersionedConfig::get().load().credential;
|
||||
if !credential.need_refresh(&self.client).await? {
|
||||
return Ok(());
|
||||
}
|
||||
let new_credential = credential.refresh(&self.client).await?;
|
||||
let config = VersionedConfig::get().load();
|
||||
config.credential.store(Arc::new(new_credential));
|
||||
config.save_to_database(connection).await?;
|
||||
VersionedConfig::get()
|
||||
.update_credential(new_credential, connection)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 获取 wbi img,用于生成请求签名
|
||||
pub async fn wbi_img(&self) -> Result<WbiImg> {
|
||||
let credential = VersionedConfig::get().load().credential.load();
|
||||
let credential = &VersionedConfig::get().load().credential;
|
||||
credential.wbi_img(&self.client).await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ use rsa::pkcs8::DecodePublicKey;
|
||||
use rsa::sha2::Sha256;
|
||||
use rsa::{Oaep, RsaPublicKey};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use utoipa::ToSchema;
|
||||
|
||||
use crate::bilibili::{Client, Validate};
|
||||
|
||||
@@ -19,7 +20,7 @@ const MIXIN_KEY_ENC_TAB: [usize; 64] = [
|
||||
20, 34, 44, 52,
|
||||
];
|
||||
|
||||
#[derive(Default, Debug, Clone, Serialize, Deserialize)]
|
||||
#[derive(Default, Debug, Clone, Serialize, Deserialize, ToSchema)]
|
||||
pub struct Credential {
|
||||
pub sessdata: String,
|
||||
pub bili_jct: String,
|
||||
|
||||
@@ -4,13 +4,14 @@ mod lane;
|
||||
use anyhow::Result;
|
||||
use float_ord::FloatOrd;
|
||||
use lane::Lane;
|
||||
use utoipa::ToSchema;
|
||||
|
||||
use crate::bilibili::PageInfo;
|
||||
use crate::bilibili::danmaku::canvas::lane::Collision;
|
||||
use crate::bilibili::danmaku::danmu::DanmuType;
|
||||
use crate::bilibili::danmaku::{Danmu, DrawEffect, Drawable};
|
||||
|
||||
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
|
||||
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, ToSchema)]
|
||||
pub struct DanmakuOption {
|
||||
pub duration: f64,
|
||||
pub font: String,
|
||||
|
||||
@@ -73,7 +73,7 @@ impl<'a> Me<'a> {
|
||||
}
|
||||
|
||||
fn my_id() -> String {
|
||||
VersionedConfig::get().load().credential.load().dedeuserid.clone()
|
||||
VersionedConfig::get().load().credential.dedeuserid.clone()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,9 +2,10 @@ use std::path::PathBuf;
|
||||
use std::sync::LazyLock;
|
||||
|
||||
use anyhow::{Result, bail};
|
||||
use arc_swap::ArcSwap;
|
||||
use sea_orm::DatabaseConnection;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use utoipa::ToSchema;
|
||||
use validator::Validate;
|
||||
|
||||
use crate::bilibili::{Credential, DanmakuOption, FilterOption};
|
||||
use crate::config::LegacyConfig;
|
||||
@@ -15,28 +16,23 @@ use crate::utils::model::{load_db_config, save_db_config};
|
||||
pub static CONFIG_DIR: LazyLock<PathBuf> =
|
||||
LazyLock::new(|| dirs::config_dir().expect("No config path found").join("bili-sync"));
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[derive(Serialize, Deserialize, ToSchema, Validate, Clone)]
|
||||
pub struct Config {
|
||||
#[serde(default = "default_auth_token")]
|
||||
pub auth_token: String,
|
||||
#[serde(default = "default_bind_address")]
|
||||
pub bind_address: String,
|
||||
pub credential: ArcSwap<Credential>,
|
||||
pub credential: Credential,
|
||||
pub filter_option: FilterOption,
|
||||
#[serde(default)]
|
||||
pub danmaku_option: DanmakuOption,
|
||||
pub video_name: String,
|
||||
pub page_name: String,
|
||||
pub interval: u64,
|
||||
#[schema(value_type = String)]
|
||||
pub upper_path: PathBuf,
|
||||
#[serde(default)]
|
||||
pub nfo_time_type: NFOTimeType,
|
||||
#[serde(default)]
|
||||
pub concurrent_limit: ConcurrentLimit,
|
||||
#[serde(default = "default_time_format")]
|
||||
pub time_format: String,
|
||||
#[serde(default)]
|
||||
pub cdn_sorting: bool,
|
||||
pub version: u64,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
@@ -59,7 +55,7 @@ impl Config {
|
||||
if self.page_name.is_empty() {
|
||||
errors.push("未设置 page_name 模板");
|
||||
}
|
||||
let credential = self.credential.load();
|
||||
let credential = &self.credential;
|
||||
if credential.sessdata.is_empty()
|
||||
|| credential.bili_jct.is_empty()
|
||||
|| credential.buvid3.is_empty()
|
||||
@@ -97,7 +93,7 @@ impl Default for Config {
|
||||
Self {
|
||||
auth_token: default_auth_token(),
|
||||
bind_address: default_bind_address(),
|
||||
credential: ArcSwap::from_pointee(Credential::default()),
|
||||
credential: Credential::default(),
|
||||
filter_option: FilterOption::default(),
|
||||
danmaku_option: DanmakuOption::default(),
|
||||
video_name: "{{title}}".to_owned(),
|
||||
@@ -108,6 +104,7 @@ impl Default for Config {
|
||||
concurrent_limit: ConcurrentLimit::default(),
|
||||
time_format: default_time_format(),
|
||||
cdn_sorting: false,
|
||||
version: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -128,6 +125,7 @@ impl From<LegacyConfig> for Config {
|
||||
concurrent_limit: legacy.concurrent_limit,
|
||||
time_format: legacy.time_format,
|
||||
cdn_sorting: legacy.cdn_sorting,
|
||||
version: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
use std::sync::atomic::AtomicBool;
|
||||
|
||||
pub static DOWNLOADER_RUNNING: AtomicBool = AtomicBool::new(false);
|
||||
@@ -13,6 +13,7 @@ fn create_template(config: &Config) -> Result<handlebars::Handlebars<'static>> {
|
||||
let mut handlebars = handlebars::Handlebars::new();
|
||||
handlebars.register_helper("truncate", Box::new(truncate));
|
||||
handlebars.path_safe_register("video", config.video_name.to_owned())?;
|
||||
handlebars.path_safe_register("page", config.page_name.to_owned())?;
|
||||
Ok(handlebars)
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ use std::path::PathBuf;
|
||||
|
||||
use anyhow::Result;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use utoipa::ToSchema;
|
||||
|
||||
use crate::utils::filenamify::filenamify;
|
||||
|
||||
@@ -13,7 +14,7 @@ pub struct WatchLaterConfig {
|
||||
}
|
||||
|
||||
/// NFO 文件使用的时间类型
|
||||
#[derive(Serialize, Deserialize, Default)]
|
||||
#[derive(Serialize, Deserialize, Default, ToSchema, Clone)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum NFOTimeType {
|
||||
#[default]
|
||||
@@ -22,7 +23,7 @@ pub enum NFOTimeType {
|
||||
}
|
||||
|
||||
/// 并发下载相关的配置
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[derive(Serialize, Deserialize, ToSchema, Clone)]
|
||||
pub struct ConcurrentLimit {
|
||||
pub video: usize,
|
||||
pub page: usize,
|
||||
@@ -31,7 +32,7 @@ pub struct ConcurrentLimit {
|
||||
pub download: ConcurrentDownloadLimit,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[derive(Serialize, Deserialize, ToSchema, Clone)]
|
||||
pub struct ConcurrentDownloadLimit {
|
||||
pub enable: bool,
|
||||
pub concurrency: usize,
|
||||
@@ -48,7 +49,7 @@ impl Default for ConcurrentDownloadLimit {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[derive(Serialize, Deserialize, ToSchema, Clone)]
|
||||
pub struct RateLimit {
|
||||
pub limit: usize,
|
||||
pub duration: u64,
|
||||
|
||||
@@ -2,7 +2,6 @@ use std::collections::HashMap;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use anyhow::Result;
|
||||
use arc_swap::ArcSwap;
|
||||
use sea_orm::DatabaseConnection;
|
||||
use serde::de::{Deserializer, MapAccess, Visitor};
|
||||
use serde::ser::SerializeMap;
|
||||
@@ -20,7 +19,7 @@ pub struct LegacyConfig {
|
||||
pub auth_token: String,
|
||||
#[serde(default = "default_bind_address")]
|
||||
pub bind_address: String,
|
||||
pub credential: ArcSwap<Credential>,
|
||||
pub credential: Credential,
|
||||
pub filter_option: FilterOption,
|
||||
#[serde(default)]
|
||||
pub danmaku_option: DanmakuOption,
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
mod args;
|
||||
mod current;
|
||||
mod default;
|
||||
mod flag;
|
||||
mod handlebar;
|
||||
mod item;
|
||||
mod legacy;
|
||||
@@ -10,7 +9,6 @@ mod versioned_config;
|
||||
|
||||
pub use crate::config::args::{ARGS, version};
|
||||
pub use crate::config::current::{CONFIG_DIR, Config};
|
||||
pub use crate::config::flag::DOWNLOADER_RUNNING;
|
||||
pub use crate::config::handlebar::TEMPLATE;
|
||||
pub use crate::config::item::{NFOTimeType, PathSafeTemplate, RateLimit};
|
||||
pub use crate::config::legacy::LegacyConfig;
|
||||
|
||||
@@ -10,16 +10,19 @@ pub struct VersionedCache<T> {
|
||||
inner: ArcSwap<T>,
|
||||
version: AtomicU64,
|
||||
builder: fn(&Config) -> Result<T>,
|
||||
mutex: parking_lot::Mutex<()>,
|
||||
}
|
||||
|
||||
impl<T> VersionedCache<T> {
|
||||
pub fn new(builder: fn(&Config) -> Result<T>) -> Result<Self> {
|
||||
let current_config = VersionedConfig::get().load();
|
||||
let current_version = current_config.version;
|
||||
let initial_value = builder(¤t_config)?;
|
||||
Ok(Self {
|
||||
inner: ArcSwap::from_pointee(initial_value),
|
||||
version: AtomicU64::new(0),
|
||||
version: AtomicU64::new(current_version),
|
||||
builder,
|
||||
mutex: parking_lot::Mutex::new(()),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -28,21 +31,23 @@ impl<T> VersionedCache<T> {
|
||||
self.inner.load()
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn load_full(&self) -> Arc<T> {
|
||||
self.reload_if_needed();
|
||||
self.inner.load_full()
|
||||
}
|
||||
|
||||
fn reload_if_needed(&self) {
|
||||
let current_version = VersionedConfig::get().version();
|
||||
let cached_version = self.version.load(Ordering::Acquire);
|
||||
|
||||
if current_version != cached_version {
|
||||
let current_config = VersionedConfig::get().load();
|
||||
if let Ok(new_value) = (self.builder)(¤t_config) {
|
||||
self.inner.store(Arc::new(new_value));
|
||||
self.version.store(current_version, Ordering::Release);
|
||||
let current_config = VersionedConfig::get().load();
|
||||
let current_version = current_config.version;
|
||||
let version = self.version.load(Ordering::Relaxed);
|
||||
if version < current_version {
|
||||
let _lock = self.mutex.lock();
|
||||
if self.version.load(Ordering::Relaxed) >= current_version {
|
||||
return;
|
||||
}
|
||||
match (self.builder)(¤t_config) {
|
||||
Err(e) => {
|
||||
error!("Failed to rebuild versioned cache: {:?}", e);
|
||||
}
|
||||
Ok(new_value) => {
|
||||
self.inner.store(Arc::new(new_value));
|
||||
self.version.store(current_version, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,24 +1,24 @@
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
|
||||
use anyhow::{Result, anyhow, bail};
|
||||
use arc_swap::{ArcSwap, Guard};
|
||||
use sea_orm::DatabaseConnection;
|
||||
use tokio::sync::OnceCell;
|
||||
|
||||
use crate::bilibili::Credential;
|
||||
use crate::config::{CONFIG_DIR, Config, LegacyConfig};
|
||||
|
||||
pub static VERSIONED_CONFIG: OnceCell<VersionedConfig> = OnceCell::const_new();
|
||||
|
||||
pub struct VersionedConfig {
|
||||
inner: ArcSwap<Config>,
|
||||
version: AtomicU64,
|
||||
update_lock: tokio::sync::Mutex<()>,
|
||||
}
|
||||
|
||||
impl VersionedConfig {
|
||||
/// 初始化全局的 `VersionedConfig`,初始化失败或者已初始化过则返回错误
|
||||
pub async fn init(connection: &DatabaseConnection) -> Result<()> {
|
||||
let config = match Config::load_from_database(connection).await? {
|
||||
let mut config = match Config::load_from_database(connection).await? {
|
||||
Some(Ok(config)) => config,
|
||||
Some(Err(e)) => bail!("解析数据库配置失败: {}", e),
|
||||
None => {
|
||||
@@ -43,6 +43,8 @@ impl VersionedConfig {
|
||||
config
|
||||
}
|
||||
};
|
||||
// version 本身不具有实际意义,仅用于并发更新时的版本控制,在初始化时可以直接清空
|
||||
config.version = 0;
|
||||
let versioned_config = VersionedConfig::new(config);
|
||||
VERSIONED_CONFIG
|
||||
.set(versioned_config)
|
||||
@@ -67,7 +69,7 @@ impl VersionedConfig {
|
||||
pub fn new(config: Config) -> Self {
|
||||
Self {
|
||||
inner: ArcSwap::from_pointee(config),
|
||||
version: AtomicU64::new(1),
|
||||
update_lock: tokio::sync::Mutex::new(()),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -79,13 +81,40 @@ impl VersionedConfig {
|
||||
self.inner.load_full()
|
||||
}
|
||||
|
||||
pub fn version(&self) -> u64 {
|
||||
self.version.load(Ordering::Acquire)
|
||||
pub async fn update_credential(&self, new_credential: Credential, connection: &DatabaseConnection) -> Result<()> {
|
||||
// 确保更新内容与写入数据库的操作是原子性的
|
||||
let _lock = self.update_lock.lock().await;
|
||||
loop {
|
||||
let old_config = self.inner.load();
|
||||
let mut new_config = old_config.as_ref().clone();
|
||||
new_config.credential = new_credential.clone();
|
||||
new_config.version += 1;
|
||||
if Arc::ptr_eq(
|
||||
&old_config,
|
||||
&self.inner.compare_and_swap(&old_config, Arc::new(new_config)),
|
||||
) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
self.inner.load().save_to_database(connection).await
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn update(&self, new_config: Config) {
|
||||
self.inner.store(Arc::new(new_config));
|
||||
self.version.fetch_add(1, Ordering::AcqRel);
|
||||
/// 外部 API 会调用这个方法,如果更新失败直接返回错误
|
||||
pub async fn update(&self, mut new_config: Config, connection: &DatabaseConnection) -> Result<Arc<Config>> {
|
||||
let _lock = self.update_lock.lock().await;
|
||||
let old_config = self.inner.load();
|
||||
if old_config.version != new_config.version {
|
||||
bail!("配置版本不匹配,请刷新页面修改后重新提交");
|
||||
}
|
||||
new_config.version += 1;
|
||||
let new_config = Arc::new(new_config);
|
||||
if !Arc::ptr_eq(
|
||||
&old_config,
|
||||
&self.inner.compare_and_swap(&old_config, new_config.clone()),
|
||||
) {
|
||||
bail!("配置版本不匹配,请刷新页面修改后重新提交");
|
||||
}
|
||||
new_config.save_to_database(connection).await?;
|
||||
Ok(new_config)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,4 +2,4 @@ mod http_server;
|
||||
mod video_downloader;
|
||||
|
||||
pub use http_server::http_server;
|
||||
pub use video_downloader::video_downloader;
|
||||
pub use video_downloader::{DOWNLOADER_TASK_RUNNING, video_downloader};
|
||||
|
||||
@@ -4,16 +4,18 @@ use sea_orm::DatabaseConnection;
|
||||
use tokio::time;
|
||||
|
||||
use crate::bilibili::{self, BiliClient};
|
||||
use crate::config::{DOWNLOADER_RUNNING, VersionedConfig};
|
||||
use crate::config::VersionedConfig;
|
||||
use crate::utils::model::get_enabled_video_sources;
|
||||
use crate::workflow::process_video_source;
|
||||
|
||||
pub static DOWNLOADER_TASK_RUNNING: tokio::sync::Mutex<()> = tokio::sync::Mutex::const_new(());
|
||||
|
||||
/// 启动周期下载视频的任务
|
||||
pub async fn video_downloader(connection: Arc<DatabaseConnection>, bili_client: Arc<BiliClient>) {
|
||||
let mut anchor = chrono::Local::now().date_naive();
|
||||
loop {
|
||||
info!("开始执行本轮视频下载任务..");
|
||||
DOWNLOADER_RUNNING.store(true, std::sync::atomic::Ordering::Relaxed);
|
||||
let _lock = DOWNLOADER_TASK_RUNNING.lock().await;
|
||||
let config = VersionedConfig::get().load_full();
|
||||
'inner: {
|
||||
if let Err(e) = config.check() {
|
||||
@@ -53,7 +55,7 @@ pub async fn video_downloader(connection: Arc<DatabaseConnection>, bili_client:
|
||||
}
|
||||
info!("本轮任务执行完毕,等待下一轮执行");
|
||||
}
|
||||
DOWNLOADER_RUNNING.store(false, std::sync::atomic::Ordering::Relaxed);
|
||||
drop(_lock);
|
||||
time::sleep(time::Duration::from_secs(config.interval)).await;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user