feat: 支持前端编辑、提交 Config (#370)

This commit is contained in:
ᴀᴍᴛᴏᴀᴇʀ
2025-06-18 16:50:16 +08:00
committed by GitHub
parent 28971c3ff3
commit e50318870e
27 changed files with 963 additions and 311 deletions

View File

@@ -13,7 +13,7 @@ pub async fn auth(headers: HeaderMap, request: Request, next: Next) -> Result<Re
if request.uri().path().starts_with("/api/")
&& get_token(&headers).is_none_or(|token| token != VersionedConfig::get().load().auth_token)
{
return Ok(ApiResponse::unauthorized(()).into_response());
return Ok(ApiResponse::<()>::unauthorized("auth token does not match").into_response());
}
Ok(next.run(request).await)
}

View File

@@ -34,6 +34,8 @@ use crate::api::response::{
};
use crate::api::wrapper::{ApiError, ApiResponse, ValidatedJson};
use crate::bilibili::{BiliClient, Collection, CollectionItem, FavoriteList, Me, Submission};
use crate::config::{Config, VersionedConfig};
use crate::task::DOWNLOADER_TASK_RUNNING;
use crate::utils::status::{PageStatus, VideoStatus};
#[derive(OpenApi)]
@@ -66,6 +68,8 @@ pub fn api_router() -> Router {
.route("/api/me/favorites", get(get_created_favorites))
.route("/api/me/collections", get(get_followed_collections))
.route("/api/me/uppers", get(get_followed_uppers))
.route("/api/config", get(get_config))
.route("/api/config", put(update_config))
.route("/image-proxy", get(image_proxy))
}
@@ -783,6 +787,39 @@ pub async fn update_video_source(
Ok(ApiResponse::ok(true))
}
#[utoipa::path(
get,
path = "/api/config",
responses(
(status = 200, body = ApiResponse<Config>),
)
)]
pub async fn get_config() -> Result<ApiResponse<Arc<Config>>, ApiError> {
Ok(ApiResponse::ok(VersionedConfig::get().load_full()))
}
#[utoipa::path(
put,
path = "/api/config",
request_body = Config,
responses(
(status = 200, body = ApiResponse<Config>),
)
)]
pub async fn update_config(
Extension(db): Extension<Arc<DatabaseConnection>>,
ValidatedJson(config): ValidatedJson<Config>,
) -> Result<ApiResponse<Arc<Config>>, ApiError> {
let Ok(_lock) = DOWNLOADER_TASK_RUNNING.try_lock() else {
// 简单避免一下可能的不一致现象
return Err(InnerApiError::BadRequest("下载任务正在运行,无法修改配置".to_string()).into());
};
config.check()?;
let new_config = VersionedConfig::get().update(config, db.as_ref()).await?;
drop(_lock);
Ok(ApiResponse::ok(new_config))
}
/// B 站的图片会检查 referer需要做个转发伪造一下否则直接返回 403
pub async fn image_proxy(
Extension(bili_client): Extension<Arc<BiliClient>>,

View File

@@ -1,3 +1,5 @@
use std::borrow::Cow;
use anyhow::Error;
use axum::Json;
use axum::extract::rejection::JsonRejection;
@@ -14,28 +16,51 @@ use crate::api::error::InnerApiError;
#[derive(ToSchema, Serialize)]
pub struct ApiResponse<T: Serialize> {
status_code: u16,
data: T,
#[serde(skip_serializing_if = "Option::is_none")]
data: Option<T>,
#[serde(skip_serializing_if = "Option::is_none")]
message: Option<Cow<'static, str>>,
}
impl<T: Serialize> ApiResponse<T> {
pub fn ok(data: T) -> Self {
Self { status_code: 200, data }
Self {
status_code: 200,
data: Some(data),
message: None,
}
}
pub fn bad_request(data: T) -> Self {
Self { status_code: 400, data }
pub fn bad_request(message: impl Into<Cow<'static, str>>) -> Self {
Self {
status_code: 400,
data: None,
message: Some(message.into()),
}
}
pub fn unauthorized(data: T) -> Self {
Self { status_code: 401, data }
pub fn unauthorized(message: impl Into<Cow<'static, str>>) -> Self {
Self {
status_code: 401,
data: None,
message: Some(message.into()),
}
}
pub fn not_found(data: T) -> Self {
Self { status_code: 404, data }
pub fn not_found(message: impl Into<Cow<'static, str>>) -> Self {
Self {
status_code: 404,
data: None,
message: Some(message.into()),
}
}
pub fn internal_server_error(data: T) -> Self {
Self { status_code: 500, data }
pub fn internal_server_error(message: impl Into<Cow<'static, str>>) -> Self {
Self {
status_code: 500,
data: None,
message: Some(message.into()),
}
}
}
@@ -64,13 +89,13 @@ impl IntoResponse for ApiError {
fn into_response(self) -> axum::response::Response {
if let Some(inner_error) = self.0.downcast_ref::<InnerApiError>() {
match inner_error {
InnerApiError::NotFound(_) => return ApiResponse::not_found(self.0.to_string()).into_response(),
InnerApiError::NotFound(_) => return ApiResponse::<()>::not_found(self.0.to_string()).into_response(),
InnerApiError::BadRequest(_) => {
return ApiResponse::bad_request(self.0.to_string()).into_response();
return ApiResponse::<()>::bad_request(self.0.to_string()).into_response();
}
}
}
ApiResponse::internal_server_error(self.0.to_string()).into_response()
ApiResponse::<()>::internal_server_error(self.0.to_string()).into_response()
}
}

View File

@@ -1,5 +1,6 @@
use anyhow::{Context, Result, bail};
use serde::{Deserialize, Serialize};
use utoipa::ToSchema;
use crate::bilibili::error::BiliError;
use crate::config::VersionedConfig;
@@ -8,7 +9,7 @@ pub struct PageAnalyzer {
info: serde_json::Value,
}
#[derive(Debug, strum::FromRepr, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
#[derive(Debug, strum::FromRepr, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, ToSchema, Clone)]
pub enum VideoQuality {
Quality360p = 16,
Quality480p = 32,
@@ -22,7 +23,7 @@ pub enum VideoQuality {
Quality8k = 127,
}
#[derive(Debug, Clone, Copy, strum::FromRepr, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Debug, Clone, Copy, strum::FromRepr, PartialEq, Eq, Serialize, Deserialize, ToSchema)]
pub enum AudioQuality {
Quality64k = 30216,
Quality132k = 30232,
@@ -54,7 +55,18 @@ impl AudioQuality {
}
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, strum::EnumString, strum::Display, strum::AsRefStr, PartialEq, PartialOrd, Serialize, Deserialize)]
#[derive(
Debug,
strum::EnumString,
strum::Display,
strum::AsRefStr,
PartialEq,
PartialOrd,
Serialize,
Deserialize,
ToSchema,
Clone,
)]
pub enum VideoCodecs {
#[strum(serialize = "hev")]
HEV,
@@ -79,7 +91,7 @@ impl TryFrom<u64> for VideoCodecs {
}
// 视频流的筛选偏好
#[derive(Serialize, Deserialize)]
#[derive(Serialize, Deserialize, ToSchema, Clone)]
pub struct FilterOption {
pub video_max_quality: VideoQuality,
pub video_min_quality: VideoQuality,

View File

@@ -1,4 +1,3 @@
use std::sync::Arc;
use std::time::Duration;
use anyhow::Result;
@@ -93,25 +92,25 @@ impl BiliClient {
if let Some(limiter) = self.limiter.load().as_ref() {
limiter.acquire_one().await;
}
let credential = VersionedConfig::get().load().credential.load();
self.client.request(method, url, Some(credential.as_ref()))
let credential = &VersionedConfig::get().load().credential;
self.client.request(method, url, Some(credential))
}
pub async fn check_refresh(&self, connection: &DatabaseConnection) -> Result<()> {
let credential = VersionedConfig::get().load().credential.load();
let credential = &VersionedConfig::get().load().credential;
if !credential.need_refresh(&self.client).await? {
return Ok(());
}
let new_credential = credential.refresh(&self.client).await?;
let config = VersionedConfig::get().load();
config.credential.store(Arc::new(new_credential));
config.save_to_database(connection).await?;
VersionedConfig::get()
.update_credential(new_credential, connection)
.await?;
Ok(())
}
/// 获取 wbi img用于生成请求签名
pub async fn wbi_img(&self) -> Result<WbiImg> {
let credential = VersionedConfig::get().load().credential.load();
let credential = &VersionedConfig::get().load().credential;
credential.wbi_img(&self.client).await
}
}

View File

@@ -10,6 +10,7 @@ use rsa::pkcs8::DecodePublicKey;
use rsa::sha2::Sha256;
use rsa::{Oaep, RsaPublicKey};
use serde::{Deserialize, Serialize};
use utoipa::ToSchema;
use crate::bilibili::{Client, Validate};
@@ -19,7 +20,7 @@ const MIXIN_KEY_ENC_TAB: [usize; 64] = [
20, 34, 44, 52,
];
#[derive(Default, Debug, Clone, Serialize, Deserialize)]
#[derive(Default, Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct Credential {
pub sessdata: String,
pub bili_jct: String,

View File

@@ -4,13 +4,14 @@ mod lane;
use anyhow::Result;
use float_ord::FloatOrd;
use lane::Lane;
use utoipa::ToSchema;
use crate::bilibili::PageInfo;
use crate::bilibili::danmaku::canvas::lane::Collision;
use crate::bilibili::danmaku::danmu::DanmuType;
use crate::bilibili::danmaku::{Danmu, DrawEffect, Drawable};
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, ToSchema)]
pub struct DanmakuOption {
pub duration: f64,
pub font: String,

View File

@@ -73,7 +73,7 @@ impl<'a> Me<'a> {
}
fn my_id() -> String {
VersionedConfig::get().load().credential.load().dedeuserid.clone()
VersionedConfig::get().load().credential.dedeuserid.clone()
}
}

View File

@@ -2,9 +2,10 @@ use std::path::PathBuf;
use std::sync::LazyLock;
use anyhow::{Result, bail};
use arc_swap::ArcSwap;
use sea_orm::DatabaseConnection;
use serde::{Deserialize, Serialize};
use utoipa::ToSchema;
use validator::Validate;
use crate::bilibili::{Credential, DanmakuOption, FilterOption};
use crate::config::LegacyConfig;
@@ -15,28 +16,23 @@ use crate::utils::model::{load_db_config, save_db_config};
pub static CONFIG_DIR: LazyLock<PathBuf> =
LazyLock::new(|| dirs::config_dir().expect("No config path found").join("bili-sync"));
#[derive(Serialize, Deserialize)]
#[derive(Serialize, Deserialize, ToSchema, Validate, Clone)]
pub struct Config {
#[serde(default = "default_auth_token")]
pub auth_token: String,
#[serde(default = "default_bind_address")]
pub bind_address: String,
pub credential: ArcSwap<Credential>,
pub credential: Credential,
pub filter_option: FilterOption,
#[serde(default)]
pub danmaku_option: DanmakuOption,
pub video_name: String,
pub page_name: String,
pub interval: u64,
#[schema(value_type = String)]
pub upper_path: PathBuf,
#[serde(default)]
pub nfo_time_type: NFOTimeType,
#[serde(default)]
pub concurrent_limit: ConcurrentLimit,
#[serde(default = "default_time_format")]
pub time_format: String,
#[serde(default)]
pub cdn_sorting: bool,
pub version: u64,
}
impl Config {
@@ -59,7 +55,7 @@ impl Config {
if self.page_name.is_empty() {
errors.push("未设置 page_name 模板");
}
let credential = self.credential.load();
let credential = &self.credential;
if credential.sessdata.is_empty()
|| credential.bili_jct.is_empty()
|| credential.buvid3.is_empty()
@@ -97,7 +93,7 @@ impl Default for Config {
Self {
auth_token: default_auth_token(),
bind_address: default_bind_address(),
credential: ArcSwap::from_pointee(Credential::default()),
credential: Credential::default(),
filter_option: FilterOption::default(),
danmaku_option: DanmakuOption::default(),
video_name: "{{title}}".to_owned(),
@@ -108,6 +104,7 @@ impl Default for Config {
concurrent_limit: ConcurrentLimit::default(),
time_format: default_time_format(),
cdn_sorting: false,
version: 0,
}
}
}
@@ -128,6 +125,7 @@ impl From<LegacyConfig> for Config {
concurrent_limit: legacy.concurrent_limit,
time_format: legacy.time_format,
cdn_sorting: legacy.cdn_sorting,
version: 0,
}
}
}

View File

@@ -1,3 +0,0 @@
use std::sync::atomic::AtomicBool;
pub static DOWNLOADER_RUNNING: AtomicBool = AtomicBool::new(false);

View File

@@ -13,6 +13,7 @@ fn create_template(config: &Config) -> Result<handlebars::Handlebars<'static>> {
let mut handlebars = handlebars::Handlebars::new();
handlebars.register_helper("truncate", Box::new(truncate));
handlebars.path_safe_register("video", config.video_name.to_owned())?;
handlebars.path_safe_register("page", config.page_name.to_owned())?;
Ok(handlebars)
}

View File

@@ -2,6 +2,7 @@ use std::path::PathBuf;
use anyhow::Result;
use serde::{Deserialize, Serialize};
use utoipa::ToSchema;
use crate::utils::filenamify::filenamify;
@@ -13,7 +14,7 @@ pub struct WatchLaterConfig {
}
/// NFO 文件使用的时间类型
#[derive(Serialize, Deserialize, Default)]
#[derive(Serialize, Deserialize, Default, ToSchema, Clone)]
#[serde(rename_all = "lowercase")]
pub enum NFOTimeType {
#[default]
@@ -22,7 +23,7 @@ pub enum NFOTimeType {
}
/// 并发下载相关的配置
#[derive(Serialize, Deserialize)]
#[derive(Serialize, Deserialize, ToSchema, Clone)]
pub struct ConcurrentLimit {
pub video: usize,
pub page: usize,
@@ -31,7 +32,7 @@ pub struct ConcurrentLimit {
pub download: ConcurrentDownloadLimit,
}
#[derive(Serialize, Deserialize)]
#[derive(Serialize, Deserialize, ToSchema, Clone)]
pub struct ConcurrentDownloadLimit {
pub enable: bool,
pub concurrency: usize,
@@ -48,7 +49,7 @@ impl Default for ConcurrentDownloadLimit {
}
}
#[derive(Serialize, Deserialize)]
#[derive(Serialize, Deserialize, ToSchema, Clone)]
pub struct RateLimit {
pub limit: usize,
pub duration: u64,

View File

@@ -2,7 +2,6 @@ use std::collections::HashMap;
use std::path::{Path, PathBuf};
use anyhow::Result;
use arc_swap::ArcSwap;
use sea_orm::DatabaseConnection;
use serde::de::{Deserializer, MapAccess, Visitor};
use serde::ser::SerializeMap;
@@ -20,7 +19,7 @@ pub struct LegacyConfig {
pub auth_token: String,
#[serde(default = "default_bind_address")]
pub bind_address: String,
pub credential: ArcSwap<Credential>,
pub credential: Credential,
pub filter_option: FilterOption,
#[serde(default)]
pub danmaku_option: DanmakuOption,

View File

@@ -1,7 +1,6 @@
mod args;
mod current;
mod default;
mod flag;
mod handlebar;
mod item;
mod legacy;
@@ -10,7 +9,6 @@ mod versioned_config;
pub use crate::config::args::{ARGS, version};
pub use crate::config::current::{CONFIG_DIR, Config};
pub use crate::config::flag::DOWNLOADER_RUNNING;
pub use crate::config::handlebar::TEMPLATE;
pub use crate::config::item::{NFOTimeType, PathSafeTemplate, RateLimit};
pub use crate::config::legacy::LegacyConfig;

View File

@@ -10,16 +10,19 @@ pub struct VersionedCache<T> {
inner: ArcSwap<T>,
version: AtomicU64,
builder: fn(&Config) -> Result<T>,
mutex: parking_lot::Mutex<()>,
}
impl<T> VersionedCache<T> {
pub fn new(builder: fn(&Config) -> Result<T>) -> Result<Self> {
let current_config = VersionedConfig::get().load();
let current_version = current_config.version;
let initial_value = builder(&current_config)?;
Ok(Self {
inner: ArcSwap::from_pointee(initial_value),
version: AtomicU64::new(0),
version: AtomicU64::new(current_version),
builder,
mutex: parking_lot::Mutex::new(()),
})
}
@@ -28,21 +31,23 @@ impl<T> VersionedCache<T> {
self.inner.load()
}
#[allow(dead_code)]
pub fn load_full(&self) -> Arc<T> {
self.reload_if_needed();
self.inner.load_full()
}
fn reload_if_needed(&self) {
let current_version = VersionedConfig::get().version();
let cached_version = self.version.load(Ordering::Acquire);
if current_version != cached_version {
let current_config = VersionedConfig::get().load();
if let Ok(new_value) = (self.builder)(&current_config) {
self.inner.store(Arc::new(new_value));
self.version.store(current_version, Ordering::Release);
let current_config = VersionedConfig::get().load();
let current_version = current_config.version;
let version = self.version.load(Ordering::Relaxed);
if version < current_version {
let _lock = self.mutex.lock();
if self.version.load(Ordering::Relaxed) >= current_version {
return;
}
match (self.builder)(&current_config) {
Err(e) => {
error!("Failed to rebuild versioned cache: {:?}", e);
}
Ok(new_value) => {
self.inner.store(Arc::new(new_value));
self.version.store(current_version, Ordering::Relaxed);
}
}
}
}

View File

@@ -1,24 +1,24 @@
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use anyhow::{Result, anyhow, bail};
use arc_swap::{ArcSwap, Guard};
use sea_orm::DatabaseConnection;
use tokio::sync::OnceCell;
use crate::bilibili::Credential;
use crate::config::{CONFIG_DIR, Config, LegacyConfig};
pub static VERSIONED_CONFIG: OnceCell<VersionedConfig> = OnceCell::const_new();
pub struct VersionedConfig {
inner: ArcSwap<Config>,
version: AtomicU64,
update_lock: tokio::sync::Mutex<()>,
}
impl VersionedConfig {
/// 初始化全局的 `VersionedConfig`,初始化失败或者已初始化过则返回错误
pub async fn init(connection: &DatabaseConnection) -> Result<()> {
let config = match Config::load_from_database(connection).await? {
let mut config = match Config::load_from_database(connection).await? {
Some(Ok(config)) => config,
Some(Err(e)) => bail!("解析数据库配置失败: {}", e),
None => {
@@ -43,6 +43,8 @@ impl VersionedConfig {
config
}
};
// version 本身不具有实际意义,仅用于并发更新时的版本控制,在初始化时可以直接清空
config.version = 0;
let versioned_config = VersionedConfig::new(config);
VERSIONED_CONFIG
.set(versioned_config)
@@ -67,7 +69,7 @@ impl VersionedConfig {
pub fn new(config: Config) -> Self {
Self {
inner: ArcSwap::from_pointee(config),
version: AtomicU64::new(1),
update_lock: tokio::sync::Mutex::new(()),
}
}
@@ -79,13 +81,40 @@ impl VersionedConfig {
self.inner.load_full()
}
pub fn version(&self) -> u64 {
self.version.load(Ordering::Acquire)
pub async fn update_credential(&self, new_credential: Credential, connection: &DatabaseConnection) -> Result<()> {
// 确保更新内容与写入数据库的操作是原子性的
let _lock = self.update_lock.lock().await;
loop {
let old_config = self.inner.load();
let mut new_config = old_config.as_ref().clone();
new_config.credential = new_credential.clone();
new_config.version += 1;
if Arc::ptr_eq(
&old_config,
&self.inner.compare_and_swap(&old_config, Arc::new(new_config)),
) {
break;
}
}
self.inner.load().save_to_database(connection).await
}
#[allow(dead_code)]
pub fn update(&self, new_config: Config) {
self.inner.store(Arc::new(new_config));
self.version.fetch_add(1, Ordering::AcqRel);
/// 外部 API 会调用这个方法,如果更新失败直接返回错误
pub async fn update(&self, mut new_config: Config, connection: &DatabaseConnection) -> Result<Arc<Config>> {
let _lock = self.update_lock.lock().await;
let old_config = self.inner.load();
if old_config.version != new_config.version {
bail!("配置版本不匹配,请刷新页面修改后重新提交");
}
new_config.version += 1;
let new_config = Arc::new(new_config);
if !Arc::ptr_eq(
&old_config,
&self.inner.compare_and_swap(&old_config, new_config.clone()),
) {
bail!("配置版本不匹配,请刷新页面修改后重新提交");
}
new_config.save_to_database(connection).await?;
Ok(new_config)
}
}

View File

@@ -2,4 +2,4 @@ mod http_server;
mod video_downloader;
pub use http_server::http_server;
pub use video_downloader::video_downloader;
pub use video_downloader::{DOWNLOADER_TASK_RUNNING, video_downloader};

View File

@@ -4,16 +4,18 @@ use sea_orm::DatabaseConnection;
use tokio::time;
use crate::bilibili::{self, BiliClient};
use crate::config::{DOWNLOADER_RUNNING, VersionedConfig};
use crate::config::VersionedConfig;
use crate::utils::model::get_enabled_video_sources;
use crate::workflow::process_video_source;
pub static DOWNLOADER_TASK_RUNNING: tokio::sync::Mutex<()> = tokio::sync::Mutex::const_new(());
/// 启动周期下载视频的任务
pub async fn video_downloader(connection: Arc<DatabaseConnection>, bili_client: Arc<BiliClient>) {
let mut anchor = chrono::Local::now().date_naive();
loop {
info!("开始执行本轮视频下载任务..");
DOWNLOADER_RUNNING.store(true, std::sync::atomic::Ordering::Relaxed);
let _lock = DOWNLOADER_TASK_RUNNING.lock().await;
let config = VersionedConfig::get().load_full();
'inner: {
if let Err(e) = config.check() {
@@ -53,7 +55,7 @@ pub async fn video_downloader(connection: Arc<DatabaseConnection>, bili_client:
}
info!("本轮任务执行完毕,等待下一轮执行");
}
DOWNLOADER_RUNNING.store(false, std::sync::atomic::Ordering::Relaxed);
drop(_lock);
time::sleep(time::Duration::from_secs(config.interval)).await;
}
}