feat: 添加部分简单 API,相应修改程序入口的初始化流程 (#251)

This commit is contained in:
ᴀᴍᴛᴏᴀᴇʀ
2025-02-17 16:58:51 +08:00
committed by GitHub
parent 7251802202
commit 1467c262a1
20 changed files with 648 additions and 133 deletions

195
Cargo.lock generated
View File

@@ -336,6 +336,72 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f1fdabc7756949593fe60f30ec81974b613357de856987752631dea1e3394c80"
[[package]]
name = "axum"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d6fd624c75e18b3b4c6b9caf42b1afe24437daaee904069137d8bab077be8b8"
dependencies = [
"axum-core",
"axum-macros",
"bytes",
"form_urlencoded",
"futures-util",
"http",
"http-body",
"http-body-util",
"hyper",
"hyper-util",
"itoa",
"matchit",
"memchr",
"mime",
"percent-encoding",
"pin-project-lite",
"rustversion",
"serde",
"serde_json",
"serde_path_to_error",
"serde_urlencoded",
"sync_wrapper",
"tokio",
"tower",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
name = "axum-core"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df1362f362fd16024ae199c1970ce98f9661bf5ef94b9808fee734bc3698b733"
dependencies = [
"bytes",
"futures-util",
"http",
"http-body",
"http-body-util",
"mime",
"pin-project-lite",
"rustversion",
"sync_wrapper",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
name = "axum-macros"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "604fde5e028fea851ce1d8570bbdc034bec850d157f7569d10f347d06808c05c"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.96",
]
[[package]]
name = "backtrace"
version = "0.3.71"
@@ -385,6 +451,7 @@ dependencies = [
"arc-swap",
"assert_matches",
"async-stream",
"axum",
"bili_sync_entity",
"bili_sync_migration",
"chrono",
@@ -400,6 +467,7 @@ dependencies = [
"leaky-bucket",
"md5",
"memchr",
"mime_guess",
"once_cell",
"prost",
"quick-xml",
@@ -407,6 +475,7 @@ dependencies = [
"regex",
"reqwest",
"rsa",
"rust-embed",
"sea-orm",
"serde",
"serde_json",
@@ -416,6 +485,8 @@ dependencies = [
"tokio",
"tokio-util",
"toml",
"tower",
"tower-http",
"tracing",
"tracing-subscriber",
]
@@ -1246,6 +1317,12 @@ dependencies = [
"ahash",
]
[[package]]
name = "hashbrown"
version = "0.14.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
[[package]]
name = "hashbrown"
version = "0.15.2"
@@ -1357,6 +1434,12 @@ version = "1.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d897f394bad6a705d5f4104762e116a75639e470d80901eed05a860a95cb1904"
[[package]]
name = "httpdate"
version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9"
[[package]]
name = "hyper"
version = "1.5.2"
@@ -1370,6 +1453,7 @@ dependencies = [
"http",
"http-body",
"httparse",
"httpdate",
"itoa",
"pin-project-lite",
"smallvec",
@@ -1610,6 +1694,12 @@ dependencies = [
"regex-automata 0.1.10",
]
[[package]]
name = "matchit"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3"
[[package]]
name = "md-5"
version = "0.10.6"
@@ -1638,6 +1728,16 @@ version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
[[package]]
name = "mime_guess"
version = "2.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e"
dependencies = [
"mime",
"unicase",
]
[[package]]
name = "miniz_oxide"
version = "0.7.2"
@@ -2378,6 +2478,40 @@ dependencies = [
"zeroize",
]
[[package]]
name = "rust-embed"
version = "8.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fa66af4a4fdd5e7ebc276f115e895611a34739a9c1c01028383d612d550953c0"
dependencies = [
"rust-embed-impl",
"rust-embed-utils",
"walkdir",
]
[[package]]
name = "rust-embed-impl"
version = "8.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6125dbc8867951125eec87294137f4e9c2c96566e61bf72c45095a7c77761478"
dependencies = [
"proc-macro2",
"quote",
"rust-embed-utils",
"syn 2.0.96",
"walkdir",
]
[[package]]
name = "rust-embed-utils"
version = "8.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2e5347777e9aacb56039b0e1f28785929a8a3b709e87482e7442c72e7c12529d"
dependencies = [
"sha2",
"walkdir",
]
[[package]]
name = "rust_decimal"
version = "1.35.0"
@@ -2472,6 +2606,15 @@ version = "1.0.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e86697c916019a8588c99b5fac3cead74ec0b4b819707a682fd4d23fa0ce1ba1"
[[package]]
name = "same-file"
version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502"
dependencies = [
"winapi-util",
]
[[package]]
name = "scopeguard"
version = "1.2.0"
@@ -2675,6 +2818,16 @@ dependencies = [
"serde",
]
[[package]]
name = "serde_path_to_error"
version = "0.1.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "af99884400da37c88f5e9146b7f1fd0fbcae8f6eec4e9da38b67d05486f814a6"
dependencies = [
"itoa",
"serde",
]
[[package]]
name = "serde_spanned"
version = "0.6.7"
@@ -3282,6 +3435,8 @@ dependencies = [
"bytes",
"futures-core",
"futures-sink",
"futures-util",
"hashbrown 0.14.5",
"pin-project-lite",
"tokio",
]
@@ -3344,6 +3499,21 @@ dependencies = [
"tokio",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
name = "tower-http"
version = "0.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "403fa3b783d4b626a8ad51d766ab03cb6d2dbfc46b1c5d4448395e6628dc9697"
dependencies = [
"bitflags 2.5.0",
"bytes",
"http",
"pin-project-lite",
"tower-layer",
"tower-service",
]
[[package]]
@@ -3439,6 +3609,12 @@ version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ed646292ffc8188ef8ea4d1e0e0150fb15a5c2e12ad9b8fc191ae7a8a7f3c4b9"
[[package]]
name = "unicase"
version = "2.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539"
[[package]]
name = "unicode-bidi"
version = "0.3.15"
@@ -3516,6 +3692,16 @@ version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
[[package]]
name = "walkdir"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b"
dependencies = [
"same-file",
"winapi-util",
]
[[package]]
name = "want"
version = "0.3.1"
@@ -3661,6 +3847,15 @@ version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
[[package]]
name = "winapi-util"
version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb"
dependencies = [
"windows-sys 0.59.0",
]
[[package]]
name = "winapi-x86_64-pc-windows-gnu"
version = "0.4.0"

View File

@@ -21,6 +21,7 @@ assert_matches = "1.5"
async-std = { version = "1.13.0", features = ["attributes", "tokio1"] }
async-stream = "0.3.6"
async-trait = "0.1.85"
axum = { version = "0.8.1", features = ["macros"] }
chrono = { version = "0.4.39", features = ["serde"] }
clap = { version = "4.5.26", features = ["env"] }
cookie = "0.18.1"
@@ -34,25 +35,27 @@ hex = "0.4.3"
leaky-bucket = "1.1.2"
md5 = "0.7.0"
memchr = "2.7.4"
mime_guess = "=2.0.5"
once_cell = "1.20.2"
prost = "0.13.4"
quick-xml = { version = "0.37.2", features = ["async-tokio"] }
rand = "0.8.5"
regex = "1.11.1"
reqwest = { version = "0.12.12", features = [
"charset",
"cookies",
"gzip",
"http2",
"json",
"rustls-tls",
"stream",
"charset",
"cookies",
"gzip",
"http2",
"json",
"rustls-tls",
"stream",
], default-features = false }
rsa = { version = "0.9.7", features = ["sha2"] }
rust-embed = "8.5.0"
sea-orm = { version = "1.1.4", features = [
"macros",
"runtime-tokio-rustls",
"sqlx-sqlite",
"macros",
"runtime-tokio-rustls",
"sqlx-sqlite",
] }
sea-orm-migration = { version = "1.1.4", features = [] }
serde = { version = "1.0.217", features = ["derive"] }
@@ -61,8 +64,10 @@ serde_urlencoded = "0.7.1"
strum = { version = "0.26.3", features = ["derive"] }
thiserror = "2.0.11"
tokio = { version = "1.43.0", features = ["full"] }
tokio-util = { version = "0.7.13", features = ["io"] }
tokio-util = { version = "0.7.13", features = ["io", "rt"] }
toml = "0.8.19"
tower = "0.5.2"
tower-http = { version = "0.6.2", features = ["normalize-path"] }
tracing = "0.1.41"
tracing-subscriber = { version = "0.3.19", features = ["chrono"] }
@@ -73,8 +78,8 @@ tag-prefix = ""
pre-release-commit-message = "chore: 发布 bili-sync {{version}}"
publish = false
pre-release-replacements = [
{ file = "../../docs/.vitepress/config.mts", search = "\"v[0-9\\.]+\"", replace = "\"v{{version}}\"", exactly = 1 },
{ file = "../../docs/introduction.md", search = " v[0-9\\.]+", replace = " v{{version}}", exactly = 1 },
{ file = "../../docs/.vitepress/config.mts", search = "\"v[0-9\\.]+\"", replace = "\"v{{version}}\"", exactly = 1 },
{ file = "../../docs/introduction.md", search = " v[0-9\\.]+", replace = " v{{version}}", exactly = 1 },
]
[profile.release]

View File

@@ -9,13 +9,9 @@ build-docker: build
docker build . -t bili-sync-rs-local --build-arg="TARGETPLATFORM=linux/amd64"
just clean
copy-config:
rm -rf ~/.config/bili-sync
cp -r ~/.config/nas/bili-sync-rs ~/.config/bili-sync
sed -i -e 's/\/Bilibilis/\/Test_Bilibilis/g' -e 's/.config\/nas/.config\/test_nas/g' ~/.config/bili-sync/config.toml
run:
cd ./web && bun run build && cd ..
cargo run
debug: copy-config
debug:
just run

View File

@@ -12,6 +12,7 @@ readme = "../../README.md"
anyhow = { workspace = true }
arc-swap = { workspace = true }
async-stream = { workspace = true }
axum = { workspace = true }
bili_sync_entity = { workspace = true }
bili_sync_migration = { workspace = true }
chrono = { workspace = true }
@@ -27,6 +28,7 @@ hex = { workspace = true }
leaky-bucket = { workspace = true }
md5 = { workspace = true }
memchr = { workspace = true }
mime_guess = { workspace = true }
once_cell = { workspace = true }
prost = { workspace = true }
quick-xml = { workspace = true }
@@ -34,6 +36,7 @@ rand = { workspace = true }
regex = { workspace = true }
reqwest = { workspace = true }
rsa = { workspace = true }
rust-embed = { workspace = true }
sea-orm = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
@@ -43,6 +46,8 @@ thiserror = { workspace = true }
tokio = { workspace = true }
tokio-util = { workspace = true }
toml = { workspace = true }
tower = { workspace = true }
tower-http = { workspace = true }
tracing = { workspace = true }
tracing-subscriber = { workspace = true }

View File

@@ -0,0 +1,21 @@
use axum::extract::Request;
use axum::http::HeaderMap;
use axum::middleware::Next;
use axum::response::Response;
use reqwest::StatusCode;
use crate::config::CONFIG;
pub async fn auth(headers: HeaderMap, request: Request, next: Next) -> Result<Response, StatusCode> {
if request.uri().path().starts_with("/api") && get_token(&headers) != CONFIG.auth_token {
return Err(StatusCode::UNAUTHORIZED);
}
Ok(next.run(request).await)
}
fn get_token(headers: &HeaderMap) -> Option<String> {
headers
.get("Authorization")
.and_then(|v| v.to_str().ok())
.map(Into::into)
}

View File

@@ -0,0 +1,24 @@
use anyhow::Error;
use axum::response::IntoResponse;
use reqwest::StatusCode;
pub struct ApiError(Error);
impl IntoResponse for ApiError {
fn into_response(self) -> axum::response::Response {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Internal Server Error: {}", self.0),
)
.into_response()
}
}
impl<E> From<E> for ApiError
where
E: Into<anyhow::Error>,
{
fn from(value: E) -> Self {
Self(value.into())
}
}

View File

@@ -0,0 +1,107 @@
use std::collections::HashMap;
use std::sync::Arc;
use anyhow::{anyhow, Result};
use axum::extract::{Extension, Path, Query};
use axum::Json;
use bili_sync_entity::*;
use bili_sync_migration::Expr;
use sea_orm::{ColumnTrait, DatabaseConnection, EntityTrait, PaginatorTrait, QueryFilter, QueryOrder, QuerySelect};
use crate::api::error::ApiError;
use crate::api::payload::{PageInfo, VideoDetail, VideoInfo, VideoList, VideoListModel, VideoListModelItem};
/// 列出所有视频列表
pub async fn get_video_list_models(
Extension(db): Extension<Arc<DatabaseConnection>>,
) -> Result<Json<VideoListModel>, ApiError> {
Ok(Json(VideoListModel {
collection: collection::Entity::find()
.select_only()
.columns([collection::Column::Id, collection::Column::Name])
.into_model::<VideoListModelItem>()
.all(db.as_ref())
.await?,
favorite: favorite::Entity::find()
.select_only()
.columns([favorite::Column::Id, favorite::Column::Name])
.into_model::<VideoListModelItem>()
.all(db.as_ref())
.await?,
submission: submission::Entity::find()
.select_only()
.column(submission::Column::Id)
.column_as(submission::Column::UpperName, "name")
.into_model::<VideoListModelItem>()
.all(db.as_ref())
.await?,
watch_later: watch_later::Entity::find()
.select_only()
.column(watch_later::Column::Id)
.column_as(Expr::value("稍后再看"), "name")
.into_model::<VideoListModelItem>()
.all(db.as_ref())
.await?,
}))
}
/// 列出所有视频的基本信息(支持根据视频列表筛选,支持分页)
pub async fn list_videos(
Extension(db): Extension<Arc<DatabaseConnection>>,
Query(params): Query<HashMap<String, String>>,
) -> Result<Json<VideoList>, ApiError> {
let mut query = video::Entity::find();
for (query_key, filter_column) in [
("collection", video::Column::CollectionId),
("favorite", video::Column::FavoriteId),
("submission", video::Column::SubmissionId),
("watch_later", video::Column::WatchLaterId),
] {
if let Some(value) = params.get(query_key) {
query = query.filter(filter_column.eq(value));
break;
}
}
if let Some(query_word) = params.get("q") {
query = query.filter(video::Column::Name.contains(query_word));
}
let total_count = query.clone().count(db.as_ref()).await?;
let (page, page_size) = if let (Some(page), Some(page_size)) = (params.get("page"), params.get("page_size")) {
(page.parse::<u64>()?, page_size.parse::<u64>()?)
} else {
(1, 10)
};
Ok(Json(VideoList {
videos: query
.order_by_desc(video::Column::Id)
.into_partial_model::<VideoInfo>()
.paginate(db.as_ref(), page_size)
.fetch_page(page)
.await?,
total_count,
}))
}
/// 根据 id 获取视频详细信息,包括关联的所有 page
pub async fn get_video(
Path(id): Path<i32>,
Extension(db): Extension<Arc<DatabaseConnection>>,
) -> Result<Json<VideoDetail>, ApiError> {
let video_info = video::Entity::find_by_id(id)
.into_partial_model::<VideoInfo>()
.one(db.as_ref())
.await?;
let Some(video_info) = video_info else {
return Err(anyhow!("视频不存在").into());
};
let pages = page::Entity::find()
.filter(page::Column::VideoId.eq(id))
.order_by_asc(page::Column::Pid)
.into_partial_model::<PageInfo>()
.all(db.as_ref())
.await?;
Ok(Json(VideoDetail {
video: video_info,
pages,
}))
}

View File

@@ -0,0 +1,4 @@
pub mod auth;
pub mod error;
pub mod handler;
pub mod payload;

View File

@@ -0,0 +1,45 @@
use bili_sync_entity::*;
use sea_orm::{DerivePartialModel, FromQueryResult};
use serde::Serialize;
#[derive(FromQueryResult, Serialize)]
pub struct VideoListModelItem {
id: i32,
name: String,
}
#[derive(Serialize)]
pub struct VideoListModel {
pub collection: Vec<VideoListModelItem>,
pub favorite: Vec<VideoListModelItem>,
pub submission: Vec<VideoListModelItem>,
pub watch_later: Vec<VideoListModelItem>,
}
#[derive(DerivePartialModel, FromQueryResult, Serialize)]
#[sea_orm(entity = "video::Entity")]
pub struct VideoInfo {
id: i32,
name: String,
upper_name: String,
}
#[derive(Serialize)]
pub struct VideoList {
pub videos: Vec<VideoInfo>,
pub total_count: u64,
}
#[derive(DerivePartialModel, FromQueryResult, Serialize)]
#[sea_orm(entity = "page::Entity")]
pub struct PageInfo {
id: i32,
pid: i32,
name: String,
}
#[derive(Serialize)]
pub struct VideoDetail {
pub video: VideoInfo,
pub pages: Vec<PageInfo>,
}

View File

@@ -48,11 +48,10 @@ fn load_config() -> Config {
panic!("加载配置文件失败,错误为: {err}");
}
warn!("配置文件不存在,使用默认配置...");
let default_config = Config::default();
default_config.save().expect("保存默认配置时遇到错误");
info!("已将默认配置写入 {}", CONFIG_DIR.join("config.toml").display());
default_config
Config::default()
});
info!("配置文件加载完毕,覆盖刷新原有配置");
config.save().expect("保存默认配置时遇到错误");
info!("检查配置文件..");
config.check();
info!("配置文件检查通过");

View File

@@ -5,6 +5,7 @@ use std::sync::Arc;
use anyhow::Result;
use arc_swap::ArcSwapOption;
use rand::seq::SliceRandom;
use serde::{Deserialize, Serialize};
mod clap;
@@ -20,8 +21,27 @@ fn default_time_format() -> String {
"%Y-%m-%d".to_string()
}
/// 默认的 auth_token 实现,生成随机 16 位字符串
fn default_auth_token() -> Option<String> {
let byte_choices = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@#$%^&*()_+-=";
let mut rng = rand::thread_rng();
Some(
(0..16)
.map(|_| *(byte_choices.choose(&mut rng).expect("choose byte failed")) as char)
.collect(),
)
}
fn default_bind_address() -> String {
"0.0.0.0:12345".to_string()
}
#[derive(Serialize, Deserialize)]
pub struct Config {
#[serde(default = "default_auth_token")]
pub auth_token: Option<String>,
#[serde(default = "default_bind_address")]
pub bind_address: String,
pub credential: ArcSwapOption<Credential>,
pub filter_option: FilterOption,
#[serde(default)]
@@ -52,6 +72,8 @@ pub struct Config {
impl Default for Config {
fn default() -> Self {
Self {
auth_token: default_auth_token(),
bind_address: default_bind_address(),
credential: ArcSwapOption::from(Some(Arc::new(Credential::default()))),
filter_option: FilterOption::default(),
danmaku_option: DanmakuOption::default(),

View File

@@ -8,7 +8,7 @@ fn database_url() -> String {
format!("sqlite://{}?mode=rwc", CONFIG_DIR.join("data.sqlite").to_string_lossy())
}
pub async fn database_connection() -> Result<DatabaseConnection> {
async fn database_connection() -> Result<DatabaseConnection> {
let mut option = ConnectOptions::new(database_url());
option
.max_connections(100)
@@ -17,9 +17,15 @@ pub async fn database_connection() -> Result<DatabaseConnection> {
Ok(Database::connect(option).await?)
}
pub async fn migrate_database() -> Result<()> {
async fn migrate_database() -> Result<()> {
// 注意此处使用内部构造的 DatabaseConnection而不是通过 database_connection() 获取
// 这是因为使用多个连接的 Connection 会导致奇怪的迁移顺序问题,而使用默认的连接选项不会
let connection = Database::connect(database_url()).await?;
Ok(Migrator::up(&connection, None).await?)
}
/// 进行数据库迁移并获取数据库连接,供外部使用
pub async fn setup_database() -> DatabaseConnection {
migrate_database().await.expect("数据库迁移失败");
database_connection().await.expect("获取数据库连接失败")
}

View File

@@ -2,36 +2,61 @@
extern crate tracing;
mod adapter;
mod api;
mod bilibili;
mod config;
mod database;
mod downloader;
mod error;
mod task;
mod utils;
mod workflow;
use std::io;
use std::path::PathBuf;
use std::fmt::Debug;
use std::future::Future;
use std::sync::Arc;
use once_cell::sync::Lazy;
use sea_orm::DatabaseConnection;
use tokio::{signal, time};
use task::{http_server, video_downloader};
use tokio_util::sync::CancellationToken;
use tokio_util::task::TaskTracker;
use crate::adapter::Args;
use crate::bilibili::BiliClient;
use crate::config::{ARGS, CONFIG};
use crate::database::{database_connection, migrate_database};
use crate::database::setup_database;
use crate::utils::init_logger;
use crate::workflow::process_video_list;
use crate::utils::signal::terminate;
#[tokio::main]
async fn main() {
init();
let connection = setup_database().await;
let bili_client = BiliClient::new();
let params = collect_task_params();
let task = spawn_periodic_task(bili_client, params, connection);
handle_shutdown(task).await;
let connection = Arc::new(setup_database().await);
let token = CancellationToken::new();
let tracker = TaskTracker::new();
spawn_task("HTTP 服务", http_server(connection.clone()), &tracker, token.clone());
spawn_task("定时下载", video_downloader(connection), &tracker, token.clone());
tracker.close();
handle_shutdown(tracker, token).await
}
fn spawn_task(
task_name: &'static str,
task: impl Future<Output = impl Debug> + Send + 'static,
tracker: &TaskTracker,
token: CancellationToken,
) {
tracker.spawn(async move {
tokio::select! {
res = task => {
error!("「{}」异常结束,返回结果为:「{:?}」,取消其它仍在执行的任务..", task_name, res);
token.cancel();
},
_ = token.cancelled() => {
info!("「{}」接收到取消信号,终止运行..", task_name);
}
}
});
}
/// 初始化日志系统,加载命令行参数和配置文件
@@ -41,100 +66,16 @@ fn init() {
Lazy::force(&CONFIG);
}
/// 迁移数据库并获取数据库连接
async fn setup_database() -> DatabaseConnection {
migrate_database().await.expect("数据库迁移失败");
database_connection().await.expect("获取数据库连接失败")
}
/// 收集任务执行所需的参数(下载类型和保存路径)
fn collect_task_params() -> Vec<(Args<'static>, &'static PathBuf)> {
let mut params = Vec::new();
CONFIG
.favorite_list
.iter()
.for_each(|(fid, path)| params.push((Args::Favorite { fid }, path)));
CONFIG
.collection_list
.iter()
.for_each(|(collection_item, path)| params.push((Args::Collection { collection_item }, path)));
if CONFIG.watch_later.enabled {
params.push((Args::WatchLater, &CONFIG.watch_later.path));
}
CONFIG
.submission_list
.iter()
.for_each(|(upper_id, path)| params.push((Args::Submission { upper_id }, path)));
params
}
/// 启动周期下载的任务
fn spawn_periodic_task(
bili_client: BiliClient,
params: Vec<(Args<'static>, &'static PathBuf)>,
connection: DatabaseConnection,
) -> tokio::task::JoinHandle<()> {
let mut anchor = chrono::Local::now().date_naive();
tokio::spawn(async move {
loop {
'inner: {
match bili_client.wbi_img().await.map(|wbi_img| wbi_img.into()) {
Ok(Some(mixin_key)) => bilibili::set_global_mixin_key(mixin_key),
Ok(_) => {
error!("解析 mixin key 失败,等待下一轮执行");
break 'inner;
}
Err(e) => {
error!("获取 mixin key 遇到错误:{e},等待下一轮执行");
break 'inner;
}
};
if anchor != chrono::Local::now().date_naive() {
if let Err(e) = bili_client.check_refresh().await {
error!("检查刷新 Credential 遇到错误:{e},等待下一轮执行");
break 'inner;
}
anchor = chrono::Local::now().date_naive();
}
for (args, path) in &params {
if let Err(e) = process_video_list(*args, &bili_client, path, &connection).await {
error!("处理过程遇到错误:{e}");
}
}
info!("本轮任务执行完毕,等待下一轮执行");
}
time::sleep(time::Duration::from_secs(CONFIG.interval)).await;
async fn handle_shutdown(tracker: TaskTracker, token: CancellationToken) {
tokio::select! {
_ = tracker.wait() => {
error!("所有任务均已终止,程序退出")
}
})
}
/// 处理终止信号
async fn handle_shutdown(task: tokio::task::JoinHandle<()>) {
let _ = terminate().await;
info!("接收到终止信号,正在终止任务..");
task.abort();
match task.await {
Err(e) if !e.is_cancelled() => error!("任务终止时遇到错误:{}", e),
_ => {
info!("任务成功终止,退出程序..");
_ = terminate() => {
info!("接收到终止信号,正在终止任务..");
token.cancel();
tracker.wait().await;
info!("所有任务均已终止,程序退出");
}
}
}
#[cfg(target_family = "windows")]
async fn terminate() -> io::Result<()> {
signal::ctrl_c().await
}
/// ctrl + c 发送的是 SIGINT 信号docker stop 发送的是 SIGTERM 信号,都需要处理
#[cfg(target_family = "unix")]
async fn terminate() -> io::Result<()> {
use tokio::select;
let mut term = signal::unix::signal(signal::unix::SignalKind::terminate())?;
let mut int = signal::unix::signal(signal::unix::SignalKind::interrupt())?;
select! {
_ = term.recv() => Ok(()),
_ = int.recv() => Ok(()),
}
}

View File

@@ -0,0 +1,51 @@
use std::sync::Arc;
use anyhow::{Context, Result};
use axum::extract::Request;
use axum::http::{header, Uri};
use axum::response::IntoResponse;
use axum::routing::get;
use axum::{middleware, Extension, Router, ServiceExt};
use reqwest::StatusCode;
use rust_embed::Embed;
use sea_orm::DatabaseConnection;
use tower::Layer;
use tower_http::normalize_path::NormalizePathLayer;
use crate::api::auth;
use crate::api::handler::{get_video, get_video_list_models, list_videos};
use crate::config::CONFIG;
#[derive(Embed)]
#[folder = "../../web/build"]
struct Asset;
pub async fn http_server(database_connection: Arc<DatabaseConnection>) -> Result<()> {
let app = Router::new()
.route("/api/videos", get(list_videos))
.route("/api/videos/{video_id}", get(get_video))
.route("/api/video-list-models", get(get_video_list_models))
.fallback_service(get(frontend_files))
.layer(Extension(database_connection))
.layer(middleware::from_fn(auth::auth));
let app = NormalizePathLayer::trim_trailing_slash().layer(app);
let listener = tokio::net::TcpListener::bind(&CONFIG.bind_address)
.await
.context("bind address failed")?;
info!("开始监听 http 服务: http://{}", CONFIG.bind_address);
Ok(axum::serve(listener, ServiceExt::<Request>::into_make_service(app)).await?)
}
async fn frontend_files(uri: Uri) -> impl IntoResponse {
let mut path = uri.path().trim_start_matches('/');
if path.is_empty() {
path = "index.html";
}
match Asset::get(path) {
Some(content) => {
let mime = mime_guess::from_path(path).first_or_octet_stream();
([(header::CONTENT_TYPE, mime.as_ref())], content.data).into_response()
}
None => (StatusCode::NOT_FOUND, "404 Not Found").into_response(),
}
}

View File

@@ -0,0 +1,5 @@
mod http_server;
mod video_downloader;
pub use http_server::http_server;
pub use video_downloader::video_downloader;

View File

@@ -0,0 +1,67 @@
use std::path::PathBuf;
use std::sync::Arc;
use sea_orm::DatabaseConnection;
use tokio::time;
use crate::adapter::Args;
use crate::bilibili::{self, BiliClient};
use crate::config::CONFIG;
use crate::workflow::process_video_list;
/// 启动周期下载视频的任务
pub async fn video_downloader(connection: Arc<DatabaseConnection>) {
let mut anchor = chrono::Local::now().date_naive();
let bili_client = BiliClient::new();
let params = collect_task_params();
loop {
'inner: {
match bili_client.wbi_img().await.map(|wbi_img| wbi_img.into()) {
Ok(Some(mixin_key)) => bilibili::set_global_mixin_key(mixin_key),
Ok(_) => {
error!("解析 mixin key 失败,等待下一轮执行");
break 'inner;
}
Err(e) => {
error!("获取 mixin key 遇到错误:{e},等待下一轮执行");
break 'inner;
}
};
if anchor != chrono::Local::now().date_naive() {
if let Err(e) = bili_client.check_refresh().await {
error!("检查刷新 Credential 遇到错误:{e},等待下一轮执行");
break 'inner;
}
anchor = chrono::Local::now().date_naive();
}
for (args, path) in &params {
if let Err(e) = process_video_list(*args, &bili_client, path, &connection).await {
error!("处理过程遇到错误:{e}");
}
}
info!("本轮任务执行完毕,等待下一轮执行");
}
time::sleep(time::Duration::from_secs(CONFIG.interval)).await;
}
}
/// 构造下载视频任务执行所需的参数(下载类型和保存路径)
fn collect_task_params() -> Vec<(Args<'static>, &'static PathBuf)> {
let mut params = Vec::new();
CONFIG
.favorite_list
.iter()
.for_each(|(fid, path)| params.push((Args::Favorite { fid }, path)));
CONFIG
.collection_list
.iter()
.for_each(|(collection_item, path)| params.push((Args::Collection { collection_item }, path)));
if CONFIG.watch_later.enabled {
params.push((Args::WatchLater, &CONFIG.watch_later.path));
}
CONFIG
.submission_list
.iter()
.for_each(|(upper_id, path)| params.push((Args::Submission { upper_id }, path)));
params
}

View File

@@ -3,6 +3,7 @@ pub mod filenamify;
pub mod format_arg;
pub mod model;
pub mod nfo;
pub mod signal;
pub mod status;
use tracing_subscriber::util::SubscriberInitExt;

View File

@@ -0,0 +1,21 @@
use std::io;
use tokio::signal;
#[cfg(target_family = "windows")]
pub async fn terminate() -> io::Result<()> {
signal::ctrl_c().await
}
/// ctrl + c 发送的是 SIGINT 信号docker stop 发送的是 SIGTERM 信号,都需要处理
#[cfg(target_family = "unix")]
pub async fn terminate() -> io::Result<()> {
use tokio::select;
let mut term = signal::unix::signal(signal::unix::SignalKind::terminate())?;
let mut int = signal::unix::signal(signal::unix::SignalKind::interrupt())?;
select! {
_ = term.recv() => Ok(()),
_ = int.recv() => Ok(()),
}
}

View File

@@ -1,7 +1,7 @@
use anyhow::Result;
static STATUS_MAX_RETRY: u32 = 0b100;
static STATUS_OK: u32 = 0b111;
pub(super) static STATUS_MAX_RETRY: u32 = 0b100;
pub(super) static STATUS_OK: u32 = 0b111;
pub static STATUS_COMPLETED: u32 = 1 << 31;
/// 用来表示下载的状态,不想写太多列了,所以仅使用一个 u32 表示。

0
web/build/.gitkeep Normal file
View File