feat: 添加部分简单 API,相应修改程序入口的初始化流程 (#251)

This commit is contained in:
ᴀᴍᴛᴏᴀᴇʀ
2025-02-17 16:58:51 +08:00
committed by GitHub
parent 7251802202
commit 1467c262a1
20 changed files with 648 additions and 133 deletions

View File

@@ -0,0 +1,21 @@
use axum::extract::Request;
use axum::http::HeaderMap;
use axum::middleware::Next;
use axum::response::Response;
use reqwest::StatusCode;
use crate::config::CONFIG;
pub async fn auth(headers: HeaderMap, request: Request, next: Next) -> Result<Response, StatusCode> {
if request.uri().path().starts_with("/api") && get_token(&headers) != CONFIG.auth_token {
return Err(StatusCode::UNAUTHORIZED);
}
Ok(next.run(request).await)
}
fn get_token(headers: &HeaderMap) -> Option<String> {
headers
.get("Authorization")
.and_then(|v| v.to_str().ok())
.map(Into::into)
}

View File

@@ -0,0 +1,24 @@
use anyhow::Error;
use axum::response::IntoResponse;
use reqwest::StatusCode;
pub struct ApiError(Error);
impl IntoResponse for ApiError {
fn into_response(self) -> axum::response::Response {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Internal Server Error: {}", self.0),
)
.into_response()
}
}
impl<E> From<E> for ApiError
where
E: Into<anyhow::Error>,
{
fn from(value: E) -> Self {
Self(value.into())
}
}

View File

@@ -0,0 +1,107 @@
use std::collections::HashMap;
use std::sync::Arc;
use anyhow::{anyhow, Result};
use axum::extract::{Extension, Path, Query};
use axum::Json;
use bili_sync_entity::*;
use bili_sync_migration::Expr;
use sea_orm::{ColumnTrait, DatabaseConnection, EntityTrait, PaginatorTrait, QueryFilter, QueryOrder, QuerySelect};
use crate::api::error::ApiError;
use crate::api::payload::{PageInfo, VideoDetail, VideoInfo, VideoList, VideoListModel, VideoListModelItem};
/// 列出所有视频列表
pub async fn get_video_list_models(
Extension(db): Extension<Arc<DatabaseConnection>>,
) -> Result<Json<VideoListModel>, ApiError> {
Ok(Json(VideoListModel {
collection: collection::Entity::find()
.select_only()
.columns([collection::Column::Id, collection::Column::Name])
.into_model::<VideoListModelItem>()
.all(db.as_ref())
.await?,
favorite: favorite::Entity::find()
.select_only()
.columns([favorite::Column::Id, favorite::Column::Name])
.into_model::<VideoListModelItem>()
.all(db.as_ref())
.await?,
submission: submission::Entity::find()
.select_only()
.column(submission::Column::Id)
.column_as(submission::Column::UpperName, "name")
.into_model::<VideoListModelItem>()
.all(db.as_ref())
.await?,
watch_later: watch_later::Entity::find()
.select_only()
.column(watch_later::Column::Id)
.column_as(Expr::value("稍后再看"), "name")
.into_model::<VideoListModelItem>()
.all(db.as_ref())
.await?,
}))
}
/// 列出所有视频的基本信息(支持根据视频列表筛选,支持分页)
pub async fn list_videos(
Extension(db): Extension<Arc<DatabaseConnection>>,
Query(params): Query<HashMap<String, String>>,
) -> Result<Json<VideoList>, ApiError> {
let mut query = video::Entity::find();
for (query_key, filter_column) in [
("collection", video::Column::CollectionId),
("favorite", video::Column::FavoriteId),
("submission", video::Column::SubmissionId),
("watch_later", video::Column::WatchLaterId),
] {
if let Some(value) = params.get(query_key) {
query = query.filter(filter_column.eq(value));
break;
}
}
if let Some(query_word) = params.get("q") {
query = query.filter(video::Column::Name.contains(query_word));
}
let total_count = query.clone().count(db.as_ref()).await?;
let (page, page_size) = if let (Some(page), Some(page_size)) = (params.get("page"), params.get("page_size")) {
(page.parse::<u64>()?, page_size.parse::<u64>()?)
} else {
(1, 10)
};
Ok(Json(VideoList {
videos: query
.order_by_desc(video::Column::Id)
.into_partial_model::<VideoInfo>()
.paginate(db.as_ref(), page_size)
.fetch_page(page)
.await?,
total_count,
}))
}
/// 根据 id 获取视频详细信息,包括关联的所有 page
pub async fn get_video(
Path(id): Path<i32>,
Extension(db): Extension<Arc<DatabaseConnection>>,
) -> Result<Json<VideoDetail>, ApiError> {
let video_info = video::Entity::find_by_id(id)
.into_partial_model::<VideoInfo>()
.one(db.as_ref())
.await?;
let Some(video_info) = video_info else {
return Err(anyhow!("视频不存在").into());
};
let pages = page::Entity::find()
.filter(page::Column::VideoId.eq(id))
.order_by_asc(page::Column::Pid)
.into_partial_model::<PageInfo>()
.all(db.as_ref())
.await?;
Ok(Json(VideoDetail {
video: video_info,
pages,
}))
}

View File

@@ -0,0 +1,4 @@
pub mod auth;
pub mod error;
pub mod handler;
pub mod payload;

View File

@@ -0,0 +1,45 @@
use bili_sync_entity::*;
use sea_orm::{DerivePartialModel, FromQueryResult};
use serde::Serialize;
#[derive(FromQueryResult, Serialize)]
pub struct VideoListModelItem {
id: i32,
name: String,
}
#[derive(Serialize)]
pub struct VideoListModel {
pub collection: Vec<VideoListModelItem>,
pub favorite: Vec<VideoListModelItem>,
pub submission: Vec<VideoListModelItem>,
pub watch_later: Vec<VideoListModelItem>,
}
#[derive(DerivePartialModel, FromQueryResult, Serialize)]
#[sea_orm(entity = "video::Entity")]
pub struct VideoInfo {
id: i32,
name: String,
upper_name: String,
}
#[derive(Serialize)]
pub struct VideoList {
pub videos: Vec<VideoInfo>,
pub total_count: u64,
}
#[derive(DerivePartialModel, FromQueryResult, Serialize)]
#[sea_orm(entity = "page::Entity")]
pub struct PageInfo {
id: i32,
pid: i32,
name: String,
}
#[derive(Serialize)]
pub struct VideoDetail {
pub video: VideoInfo,
pub pages: Vec<PageInfo>,
}

View File

@@ -48,11 +48,10 @@ fn load_config() -> Config {
panic!("加载配置文件失败,错误为: {err}");
}
warn!("配置文件不存在,使用默认配置...");
let default_config = Config::default();
default_config.save().expect("保存默认配置时遇到错误");
info!("已将默认配置写入 {}", CONFIG_DIR.join("config.toml").display());
default_config
Config::default()
});
info!("配置文件加载完毕,覆盖刷新原有配置");
config.save().expect("保存默认配置时遇到错误");
info!("检查配置文件..");
config.check();
info!("配置文件检查通过");

View File

@@ -5,6 +5,7 @@ use std::sync::Arc;
use anyhow::Result;
use arc_swap::ArcSwapOption;
use rand::seq::SliceRandom;
use serde::{Deserialize, Serialize};
mod clap;
@@ -20,8 +21,27 @@ fn default_time_format() -> String {
"%Y-%m-%d".to_string()
}
/// 默认的 auth_token 实现,生成随机 16 位字符串
fn default_auth_token() -> Option<String> {
let byte_choices = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@#$%^&*()_+-=";
let mut rng = rand::thread_rng();
Some(
(0..16)
.map(|_| *(byte_choices.choose(&mut rng).expect("choose byte failed")) as char)
.collect(),
)
}
fn default_bind_address() -> String {
"0.0.0.0:12345".to_string()
}
#[derive(Serialize, Deserialize)]
pub struct Config {
#[serde(default = "default_auth_token")]
pub auth_token: Option<String>,
#[serde(default = "default_bind_address")]
pub bind_address: String,
pub credential: ArcSwapOption<Credential>,
pub filter_option: FilterOption,
#[serde(default)]
@@ -52,6 +72,8 @@ pub struct Config {
impl Default for Config {
fn default() -> Self {
Self {
auth_token: default_auth_token(),
bind_address: default_bind_address(),
credential: ArcSwapOption::from(Some(Arc::new(Credential::default()))),
filter_option: FilterOption::default(),
danmaku_option: DanmakuOption::default(),

View File

@@ -8,7 +8,7 @@ fn database_url() -> String {
format!("sqlite://{}?mode=rwc", CONFIG_DIR.join("data.sqlite").to_string_lossy())
}
pub async fn database_connection() -> Result<DatabaseConnection> {
async fn database_connection() -> Result<DatabaseConnection> {
let mut option = ConnectOptions::new(database_url());
option
.max_connections(100)
@@ -17,9 +17,15 @@ pub async fn database_connection() -> Result<DatabaseConnection> {
Ok(Database::connect(option).await?)
}
pub async fn migrate_database() -> Result<()> {
async fn migrate_database() -> Result<()> {
// 注意此处使用内部构造的 DatabaseConnection而不是通过 database_connection() 获取
// 这是因为使用多个连接的 Connection 会导致奇怪的迁移顺序问题,而使用默认的连接选项不会
let connection = Database::connect(database_url()).await?;
Ok(Migrator::up(&connection, None).await?)
}
/// 进行数据库迁移并获取数据库连接,供外部使用
pub async fn setup_database() -> DatabaseConnection {
migrate_database().await.expect("数据库迁移失败");
database_connection().await.expect("获取数据库连接失败")
}

View File

@@ -2,36 +2,61 @@
extern crate tracing;
mod adapter;
mod api;
mod bilibili;
mod config;
mod database;
mod downloader;
mod error;
mod task;
mod utils;
mod workflow;
use std::io;
use std::path::PathBuf;
use std::fmt::Debug;
use std::future::Future;
use std::sync::Arc;
use once_cell::sync::Lazy;
use sea_orm::DatabaseConnection;
use tokio::{signal, time};
use task::{http_server, video_downloader};
use tokio_util::sync::CancellationToken;
use tokio_util::task::TaskTracker;
use crate::adapter::Args;
use crate::bilibili::BiliClient;
use crate::config::{ARGS, CONFIG};
use crate::database::{database_connection, migrate_database};
use crate::database::setup_database;
use crate::utils::init_logger;
use crate::workflow::process_video_list;
use crate::utils::signal::terminate;
#[tokio::main]
async fn main() {
init();
let connection = setup_database().await;
let bili_client = BiliClient::new();
let params = collect_task_params();
let task = spawn_periodic_task(bili_client, params, connection);
handle_shutdown(task).await;
let connection = Arc::new(setup_database().await);
let token = CancellationToken::new();
let tracker = TaskTracker::new();
spawn_task("HTTP 服务", http_server(connection.clone()), &tracker, token.clone());
spawn_task("定时下载", video_downloader(connection), &tracker, token.clone());
tracker.close();
handle_shutdown(tracker, token).await
}
fn spawn_task(
task_name: &'static str,
task: impl Future<Output = impl Debug> + Send + 'static,
tracker: &TaskTracker,
token: CancellationToken,
) {
tracker.spawn(async move {
tokio::select! {
res = task => {
error!("「{}」异常结束,返回结果为:「{:?}」,取消其它仍在执行的任务..", task_name, res);
token.cancel();
},
_ = token.cancelled() => {
info!("「{}」接收到取消信号,终止运行..", task_name);
}
}
});
}
/// 初始化日志系统,加载命令行参数和配置文件
@@ -41,100 +66,16 @@ fn init() {
Lazy::force(&CONFIG);
}
/// 迁移数据库并获取数据库连接
async fn setup_database() -> DatabaseConnection {
migrate_database().await.expect("数据库迁移失败");
database_connection().await.expect("获取数据库连接失败")
}
/// 收集任务执行所需的参数(下载类型和保存路径)
fn collect_task_params() -> Vec<(Args<'static>, &'static PathBuf)> {
let mut params = Vec::new();
CONFIG
.favorite_list
.iter()
.for_each(|(fid, path)| params.push((Args::Favorite { fid }, path)));
CONFIG
.collection_list
.iter()
.for_each(|(collection_item, path)| params.push((Args::Collection { collection_item }, path)));
if CONFIG.watch_later.enabled {
params.push((Args::WatchLater, &CONFIG.watch_later.path));
}
CONFIG
.submission_list
.iter()
.for_each(|(upper_id, path)| params.push((Args::Submission { upper_id }, path)));
params
}
/// 启动周期下载的任务
fn spawn_periodic_task(
bili_client: BiliClient,
params: Vec<(Args<'static>, &'static PathBuf)>,
connection: DatabaseConnection,
) -> tokio::task::JoinHandle<()> {
let mut anchor = chrono::Local::now().date_naive();
tokio::spawn(async move {
loop {
'inner: {
match bili_client.wbi_img().await.map(|wbi_img| wbi_img.into()) {
Ok(Some(mixin_key)) => bilibili::set_global_mixin_key(mixin_key),
Ok(_) => {
error!("解析 mixin key 失败,等待下一轮执行");
break 'inner;
}
Err(e) => {
error!("获取 mixin key 遇到错误:{e},等待下一轮执行");
break 'inner;
}
};
if anchor != chrono::Local::now().date_naive() {
if let Err(e) = bili_client.check_refresh().await {
error!("检查刷新 Credential 遇到错误:{e},等待下一轮执行");
break 'inner;
}
anchor = chrono::Local::now().date_naive();
}
for (args, path) in &params {
if let Err(e) = process_video_list(*args, &bili_client, path, &connection).await {
error!("处理过程遇到错误:{e}");
}
}
info!("本轮任务执行完毕,等待下一轮执行");
}
time::sleep(time::Duration::from_secs(CONFIG.interval)).await;
async fn handle_shutdown(tracker: TaskTracker, token: CancellationToken) {
tokio::select! {
_ = tracker.wait() => {
error!("所有任务均已终止,程序退出")
}
})
}
/// 处理终止信号
async fn handle_shutdown(task: tokio::task::JoinHandle<()>) {
let _ = terminate().await;
info!("接收到终止信号,正在终止任务..");
task.abort();
match task.await {
Err(e) if !e.is_cancelled() => error!("任务终止时遇到错误:{}", e),
_ => {
info!("任务成功终止,退出程序..");
_ = terminate() => {
info!("接收到终止信号,正在终止任务..");
token.cancel();
tracker.wait().await;
info!("所有任务均已终止,程序退出");
}
}
}
#[cfg(target_family = "windows")]
async fn terminate() -> io::Result<()> {
signal::ctrl_c().await
}
/// ctrl + c 发送的是 SIGINT 信号docker stop 发送的是 SIGTERM 信号,都需要处理
#[cfg(target_family = "unix")]
async fn terminate() -> io::Result<()> {
use tokio::select;
let mut term = signal::unix::signal(signal::unix::SignalKind::terminate())?;
let mut int = signal::unix::signal(signal::unix::SignalKind::interrupt())?;
select! {
_ = term.recv() => Ok(()),
_ = int.recv() => Ok(()),
}
}

View File

@@ -0,0 +1,51 @@
use std::sync::Arc;
use anyhow::{Context, Result};
use axum::extract::Request;
use axum::http::{header, Uri};
use axum::response::IntoResponse;
use axum::routing::get;
use axum::{middleware, Extension, Router, ServiceExt};
use reqwest::StatusCode;
use rust_embed::Embed;
use sea_orm::DatabaseConnection;
use tower::Layer;
use tower_http::normalize_path::NormalizePathLayer;
use crate::api::auth;
use crate::api::handler::{get_video, get_video_list_models, list_videos};
use crate::config::CONFIG;
#[derive(Embed)]
#[folder = "../../web/build"]
struct Asset;
pub async fn http_server(database_connection: Arc<DatabaseConnection>) -> Result<()> {
let app = Router::new()
.route("/api/videos", get(list_videos))
.route("/api/videos/{video_id}", get(get_video))
.route("/api/video-list-models", get(get_video_list_models))
.fallback_service(get(frontend_files))
.layer(Extension(database_connection))
.layer(middleware::from_fn(auth::auth));
let app = NormalizePathLayer::trim_trailing_slash().layer(app);
let listener = tokio::net::TcpListener::bind(&CONFIG.bind_address)
.await
.context("bind address failed")?;
info!("开始监听 http 服务: http://{}", CONFIG.bind_address);
Ok(axum::serve(listener, ServiceExt::<Request>::into_make_service(app)).await?)
}
async fn frontend_files(uri: Uri) -> impl IntoResponse {
let mut path = uri.path().trim_start_matches('/');
if path.is_empty() {
path = "index.html";
}
match Asset::get(path) {
Some(content) => {
let mime = mime_guess::from_path(path).first_or_octet_stream();
([(header::CONTENT_TYPE, mime.as_ref())], content.data).into_response()
}
None => (StatusCode::NOT_FOUND, "404 Not Found").into_response(),
}
}

View File

@@ -0,0 +1,5 @@
mod http_server;
mod video_downloader;
pub use http_server::http_server;
pub use video_downloader::video_downloader;

View File

@@ -0,0 +1,67 @@
use std::path::PathBuf;
use std::sync::Arc;
use sea_orm::DatabaseConnection;
use tokio::time;
use crate::adapter::Args;
use crate::bilibili::{self, BiliClient};
use crate::config::CONFIG;
use crate::workflow::process_video_list;
/// 启动周期下载视频的任务
pub async fn video_downloader(connection: Arc<DatabaseConnection>) {
let mut anchor = chrono::Local::now().date_naive();
let bili_client = BiliClient::new();
let params = collect_task_params();
loop {
'inner: {
match bili_client.wbi_img().await.map(|wbi_img| wbi_img.into()) {
Ok(Some(mixin_key)) => bilibili::set_global_mixin_key(mixin_key),
Ok(_) => {
error!("解析 mixin key 失败,等待下一轮执行");
break 'inner;
}
Err(e) => {
error!("获取 mixin key 遇到错误:{e},等待下一轮执行");
break 'inner;
}
};
if anchor != chrono::Local::now().date_naive() {
if let Err(e) = bili_client.check_refresh().await {
error!("检查刷新 Credential 遇到错误:{e},等待下一轮执行");
break 'inner;
}
anchor = chrono::Local::now().date_naive();
}
for (args, path) in &params {
if let Err(e) = process_video_list(*args, &bili_client, path, &connection).await {
error!("处理过程遇到错误:{e}");
}
}
info!("本轮任务执行完毕,等待下一轮执行");
}
time::sleep(time::Duration::from_secs(CONFIG.interval)).await;
}
}
/// 构造下载视频任务执行所需的参数(下载类型和保存路径)
fn collect_task_params() -> Vec<(Args<'static>, &'static PathBuf)> {
let mut params = Vec::new();
CONFIG
.favorite_list
.iter()
.for_each(|(fid, path)| params.push((Args::Favorite { fid }, path)));
CONFIG
.collection_list
.iter()
.for_each(|(collection_item, path)| params.push((Args::Collection { collection_item }, path)));
if CONFIG.watch_later.enabled {
params.push((Args::WatchLater, &CONFIG.watch_later.path));
}
CONFIG
.submission_list
.iter()
.for_each(|(upper_id, path)| params.push((Args::Submission { upper_id }, path)));
params
}

View File

@@ -3,6 +3,7 @@ pub mod filenamify;
pub mod format_arg;
pub mod model;
pub mod nfo;
pub mod signal;
pub mod status;
use tracing_subscriber::util::SubscriberInitExt;

View File

@@ -0,0 +1,21 @@
use std::io;
use tokio::signal;
#[cfg(target_family = "windows")]
pub async fn terminate() -> io::Result<()> {
signal::ctrl_c().await
}
/// ctrl + c 发送的是 SIGINT 信号docker stop 发送的是 SIGTERM 信号,都需要处理
#[cfg(target_family = "unix")]
pub async fn terminate() -> io::Result<()> {
use tokio::select;
let mut term = signal::unix::signal(signal::unix::SignalKind::terminate())?;
let mut int = signal::unix::signal(signal::unix::SignalKind::interrupt())?;
select! {
_ = term.recv() => Ok(()),
_ = int.recv() => Ok(()),
}
}

View File

@@ -1,7 +1,7 @@
use anyhow::Result;
static STATUS_MAX_RETRY: u32 = 0b100;
static STATUS_OK: u32 = 0b111;
pub(super) static STATUS_MAX_RETRY: u32 = 0b100;
pub(super) static STATUS_OK: u32 = 0b111;
pub static STATUS_COMPLETED: u32 = 1 << 31;
/// 用来表示下载的状态,不想写太多列了,所以仅使用一个 u32 表示。