diff --git a/crates/bili_sync/src/api/auth.rs b/crates/bili_sync/src/api/auth.rs index 66f978b..f590b32 100644 --- a/crates/bili_sync/src/api/auth.rs +++ b/crates/bili_sync/src/api/auth.rs @@ -1,16 +1,17 @@ use axum::extract::Request; use axum::http::HeaderMap; use axum::middleware::Next; -use axum::response::Response; +use axum::response::{IntoResponse, Response}; use reqwest::StatusCode; use utoipa::openapi::security::{ApiKey, ApiKeyValue, SecurityScheme}; use utoipa::Modify; +use crate::api::wrapper::ApiResponse; use crate::config::CONFIG; pub async fn auth(headers: HeaderMap, request: Request, next: Next) -> Result { if request.uri().path().starts_with("/api/") && get_token(&headers) != CONFIG.auth_token { - return Err(StatusCode::UNAUTHORIZED); + return Ok(ApiResponse::unauthorized(()).into_response()); } Ok(next.run(request).await) } @@ -22,7 +23,7 @@ fn get_token(headers: &HeaderMap) -> Option { .map(Into::into) } -pub struct OpenAPIAuth; +pub(super) struct OpenAPIAuth; impl Modify for OpenAPIAuth { fn modify(&self, openapi: &mut utoipa::openapi::OpenApi) { diff --git a/crates/bili_sync/src/api/error.rs b/crates/bili_sync/src/api/error.rs index e8fcf55..adacbb5 100644 --- a/crates/bili_sync/src/api/error.rs +++ b/crates/bili_sync/src/api/error.rs @@ -1,24 +1,7 @@ -use anyhow::Error; -use axum::response::IntoResponse; -use reqwest::StatusCode; +use thiserror::Error; -pub struct ApiError(Error); - -impl IntoResponse for ApiError { - fn into_response(self) -> axum::response::Response { - ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Internal Server Error: {}", self.0), - ) - .into_response() - } -} - -impl From for ApiError -where - E: Into, -{ - fn from(value: E) -> Self { - Self(value.into()) - } +#[derive(Error, Debug)] +pub enum InnerApiError { + #[error("Primary key not found: {0}")] + NotFound(i32), } diff --git a/crates/bili_sync/src/api/handler.rs b/crates/bili_sync/src/api/handler.rs index 29607e4..e8598b1 100644 --- a/crates/bili_sync/src/api/handler.rs +++ b/crates/bili_sync/src/api/handler.rs @@ -2,20 +2,26 @@ use std::sync::Arc; use anyhow::{anyhow, Result}; use axum::extract::{Extension, Path, Query}; -use axum::Json; use bili_sync_entity::*; -use bili_sync_migration::Expr; -use sea_orm::{ColumnTrait, DatabaseConnection, EntityTrait, PaginatorTrait, QueryFilter, QueryOrder, QuerySelect}; +use bili_sync_migration::{Expr, OnConflict}; +use sea_orm::{ + ColumnTrait, DatabaseConnection, EntityTrait, PaginatorTrait, QueryFilter, QueryOrder, QuerySelect, Set, + TransactionTrait, Unchanged, +}; use utoipa::OpenApi; use crate::api::auth::OpenAPIAuth; -use crate::api::error::ApiError; +use crate::api::error::InnerApiError; use crate::api::request::VideosRequest; -use crate::api::response::{PageInfo, VideoInfo, VideoResponse, VideoSource, VideoSourcesResponse, VideosResponse}; +use crate::api::response::{ + PageInfo, ResetVideoResponse, VideoInfo, VideoResponse, VideoSource, VideoSourcesResponse, VideosResponse, +}; +use crate::api::wrapper::{ApiError, ApiResponse}; +use crate::utils::status::{PageStatus, VideoStatus}; #[derive(OpenApi)] #[openapi( - paths(get_video_sources, get_videos, get_video), + paths(get_video_sources, get_videos, get_video, reset_video), modifiers(&OpenAPIAuth), security( ("Token" = []), @@ -28,13 +34,13 @@ pub struct ApiDoc; get, path = "/api/video-sources", responses( - (status = 200, body = VideoSourcesResponse), + (status = 200, body = ApiResponse), ) )] pub async fn get_video_sources( Extension(db): Extension>, -) -> Result, ApiError> { - Ok(Json(VideoSourcesResponse { +) -> Result, ApiError> { + Ok(ApiResponse::ok(VideoSourcesResponse { collection: collection::Entity::find() .select_only() .columns([collection::Column::Id, collection::Column::Name]) @@ -72,13 +78,13 @@ pub async fn get_video_sources( VideosRequest, ), responses( - (status = 200, body = VideosResponse), + (status = 200, body = ApiResponse), ) )] pub async fn get_videos( Extension(db): Extension>, Query(params): Query, -) -> Result, ApiError> { +) -> Result, ApiError> { let mut query = video::Entity::find(); for (field, column) in [ (params.collection, video::Column::CollectionId), @@ -99,13 +105,23 @@ pub async fn get_videos( } else { (1, 10) }; - Ok(Json(VideosResponse { + Ok(ApiResponse::ok(VideosResponse { videos: query .order_by_desc(video::Column::Id) - .into_partial_model::() + .select_only() + .columns([ + video::Column::Id, + video::Column::Name, + video::Column::UpperName, + video::Column::DownloadStatus, + ]) + .into_tuple::<(i32, String, String, u32)>() .paginate(db.as_ref(), page_size) .fetch_page(page) - .await?, + .await? + .into_iter() + .map(VideoInfo::from) + .collect(), total_count, })) } @@ -115,28 +131,127 @@ pub async fn get_videos( get, path = "/api/videos/{id}", responses( - (status = 200, body = VideoResponse), + (status = 200, body = ApiResponse), ) )] pub async fn get_video( Path(id): Path, Extension(db): Extension>, -) -> Result, ApiError> { +) -> Result, ApiError> { let video_info = video::Entity::find_by_id(id) - .into_partial_model::() + .select_only() + .columns([ + video::Column::Id, + video::Column::Name, + video::Column::UpperName, + video::Column::DownloadStatus, + ]) + .into_tuple::<(i32, String, String, u32)>() .one(db.as_ref()) - .await?; + .await? + .map(VideoInfo::from); let Some(video_info) = video_info else { - return Err(anyhow!("视频不存在").into()); + return Err(InnerApiError::NotFound(id).into()); }; let pages = page::Entity::find() .filter(page::Column::VideoId.eq(id)) .order_by_asc(page::Column::Pid) - .into_partial_model::() + .select_only() + .columns([ + page::Column::Id, + page::Column::Pid, + page::Column::Name, + page::Column::DownloadStatus, + ]) + .into_tuple::<(i32, i32, String, u32)>() .all(db.as_ref()) - .await?; - Ok(Json(VideoResponse { + .await? + .into_iter() + .map(PageInfo::from) + .collect(); + Ok(ApiResponse::ok(VideoResponse { video: video_info, pages, })) } + +/// 将某个视频与其所有分页的失败状态清空为未下载状态,这样在下次下载任务中会触发重试 +#[utoipa::path( + post, + path = "/api/videos/{id}/reset", + responses( + (status = 200, body = ApiResponse ), + ) +)] +pub async fn reset_video( + Path(id): Path, + Extension(db): Extension>, +) -> Result, ApiError> { + let txn = db.begin().await?; + let video_status: Option = video::Entity::find_by_id(id) + .select_only() + .column(video::Column::DownloadStatus) + .into_tuple() + .one(&txn) + .await?; + let Some(video_status) = video_status else { + return Err(anyhow!(InnerApiError::NotFound(id)).into()); + }; + let resetted_pages_tuple: Vec<(i32, u32)> = page::Entity::find() + .filter(page::Column::VideoId.eq(id)) + .select_only() + .columns([page::Column::Id, page::Column::DownloadStatus]) + .into_tuple::<(i32, u32)>() + .all(&txn) + .await? + .into_iter() + .filter_map(|(id, page_status)| { + let mut page_status = PageStatus::from(page_status); + if page_status.reset_failed() { + Some((id, page_status.into())) + } else { + None + } + }) + .collect(); + let mut video_status = VideoStatus::from(video_status); + let mut should_update_video = video_status.reset_failed(); + if !resetted_pages_tuple.is_empty() { + // 视频状态标志的第 5 位表示是否有分 P 下载失败,如果有需要重置的分页,需要同时重置视频的该状态 + video_status.set(4, 0); + should_update_video = true; + } + if should_update_video { + video::Entity::update(video::ActiveModel { + id: Unchanged(id), + download_status: Set(video_status.into()), + ..Default::default() + }) + .exec(&txn) + .await?; + } + let resetted_pages: Vec<_> = resetted_pages_tuple + .iter() + .map(|(id, page_status)| page::ActiveModel { + id: Unchanged(*id), + download_status: Set(*page_status), + ..Default::default() + }) + .collect(); + for page_trunk in resetted_pages.chunks(50) { + page::Entity::insert_many(page_trunk.to_vec()) + .on_conflict( + OnConflict::column(page::Column::Id) + .update_column(page::Column::DownloadStatus) + .to_owned(), + ) + .exec(&txn) + .await?; + } + txn.commit().await?; + Ok(ApiResponse::ok(ResetVideoResponse { + resetted: should_update_video, + video: id, + pages: resetted_pages_tuple.into_iter().map(|(id, _)| id).collect(), + })) +} diff --git a/crates/bili_sync/src/api/mod.rs b/crates/bili_sync/src/api/mod.rs index 0bf83a6..aa0e600 100644 --- a/crates/bili_sync/src/api/mod.rs +++ b/crates/bili_sync/src/api/mod.rs @@ -1,6 +1,7 @@ pub mod auth; -pub mod error; pub mod handler; +mod error; mod request; mod response; +mod wrapper; diff --git a/crates/bili_sync/src/api/response.rs b/crates/bili_sync/src/api/response.rs index 8c34eb4..e80abd4 100644 --- a/crates/bili_sync/src/api/response.rs +++ b/crates/bili_sync/src/api/response.rs @@ -1,8 +1,9 @@ -use bili_sync_entity::*; -use sea_orm::{DerivePartialModel, FromQueryResult}; +use sea_orm::FromQueryResult; use serde::Serialize; use utoipa::ToSchema; +use crate::utils::status::{PageStatus, VideoStatus}; + #[derive(Serialize, ToSchema)] pub struct VideoSourcesResponse { pub collection: Vec, @@ -23,24 +24,53 @@ pub struct VideoResponse { pub pages: Vec, } +#[derive(Serialize, ToSchema)] +pub struct ResetVideoResponse { + pub resetted: bool, + pub video: i32, + pub pages: Vec, +} + #[derive(FromQueryResult, Serialize, ToSchema)] pub struct VideoSource { id: i32, name: String, } -#[derive(DerivePartialModel, FromQueryResult, Serialize, ToSchema)] -#[sea_orm(entity = "page::Entity")] +#[derive(Serialize, ToSchema)] pub struct PageInfo { - id: i32, - pid: i32, - name: String, + pub id: i32, + pub pid: i32, + pub name: String, + pub download_status: [u32; 5], } -#[derive(DerivePartialModel, FromQueryResult, Serialize, ToSchema)] -#[sea_orm(entity = "video::Entity")] -pub struct VideoInfo { - id: i32, - name: String, - upper_name: String, +impl From<(i32, i32, String, u32)> for PageInfo { + fn from((id, pid, name, download_status): (i32, i32, String, u32)) -> Self { + Self { + id, + pid, + name, + download_status: PageStatus::from(download_status).into(), + } + } +} + +#[derive(Serialize, ToSchema)] +pub struct VideoInfo { + pub id: i32, + pub name: String, + pub upper_name: String, + pub download_status: [u32; 5], +} + +impl From<(i32, String, String, u32)> for VideoInfo { + fn from((id, name, upper_name, download_status): (i32, String, String, u32)) -> Self { + Self { + id, + name, + upper_name, + download_status: VideoStatus::from(download_status).into(), + } + } } diff --git a/crates/bili_sync/src/api/wrapper.rs b/crates/bili_sync/src/api/wrapper.rs new file mode 100644 index 0000000..27a2766 --- /dev/null +++ b/crates/bili_sync/src/api/wrapper.rs @@ -0,0 +1,64 @@ +use anyhow::Error; +use axum::response::IntoResponse; +use axum::Json; +use reqwest::StatusCode; +use serde::Serialize; +use utoipa::ToSchema; + +use crate::api::error::InnerApiError; + +#[derive(ToSchema, Serialize)] +pub struct ApiResponse { + status_code: u16, + data: T, +} + +impl ApiResponse { + pub fn ok(data: T) -> Self { + Self { status_code: 200, data } + } + + pub fn unauthorized(data: T) -> Self { + Self { status_code: 401, data } + } + + pub fn not_found(data: T) -> Self { + Self { status_code: 404, data } + } + + pub fn internal_server_error(data: T) -> Self { + Self { status_code: 500, data } + } +} + +impl IntoResponse for ApiResponse { + fn into_response(self) -> axum::response::Response { + ( + StatusCode::from_u16(self.status_code).expect("invalid Http Status Code"), + Json(self), + ) + .into_response() + } +} + +pub struct ApiError(Error); + +impl From for ApiError +where + E: Into, +{ + fn from(value: E) -> Self { + Self(value.into()) + } +} + +impl IntoResponse for ApiError { + fn into_response(self) -> axum::response::Response { + if let Some(inner_error) = self.0.downcast_ref::() { + match inner_error { + InnerApiError::NotFound(_) => return ApiResponse::not_found(self.0.to_string()).into_response(), + } + } + ApiResponse::internal_server_error(self.0.to_string()).into_response() + } +} diff --git a/crates/bili_sync/src/task/http_server.rs b/crates/bili_sync/src/task/http_server.rs index bcf7621..a13b694 100644 --- a/crates/bili_sync/src/task/http_server.rs +++ b/crates/bili_sync/src/task/http_server.rs @@ -4,7 +4,7 @@ use anyhow::{Context, Result}; use axum::extract::Request; use axum::http::{header, Uri}; use axum::response::IntoResponse; -use axum::routing::get; +use axum::routing::{get, post}; use axum::{middleware, Extension, Router, ServiceExt}; use reqwest::StatusCode; use rust_embed::Embed; @@ -13,7 +13,7 @@ use utoipa::OpenApi; use utoipa_swagger_ui::{Config, SwaggerUi}; use crate::api::auth; -use crate::api::handler::{get_video, get_video_sources, get_videos, ApiDoc}; +use crate::api::handler::{get_video, get_video_sources, get_videos, reset_video, ApiDoc}; use crate::config::CONFIG; #[derive(Embed)] @@ -24,7 +24,8 @@ pub async fn http_server(database_connection: Arc) -> Result let app = Router::new() .route("/api/video-sources", get(get_video_sources)) .route("/api/videos", get(get_videos)) - .route("/api/video/{id}", get(get_video)) + .route("/api/videos/{id}", get(get_video)) + .route("/api/videos/{id}/reset", post(reset_video)) .merge( SwaggerUi::new("/swagger-ui/") .url("/api-docs/openapi.json", ApiDoc::openapi()) diff --git a/crates/bili_sync/src/utils/status.rs b/crates/bili_sync/src/utils/status.rs index 8d19331..2efcf93 100644 --- a/crates/bili_sync/src/utils/status.rs +++ b/crates/bili_sync/src/utils/status.rs @@ -28,6 +28,33 @@ impl Status { result } + /// 重置所有失败的状态,将状态设置为 0b000,返回值表示是否有状态被重置 + pub fn reset_failed(&mut self) -> bool { + let mut resetted = false; + for i in 0..N { + let status = self.get_status(i); + if !(status < STATUS_MAX_RETRY || status == STATUS_OK) { + self.set_status(i, 0); + resetted = true; + } + } + if resetted { + self.set_completed(false); + } + resetted + } + + /// 覆盖某个子任务的状态 + pub fn set(&mut self, offset: usize, status: u32) { + assert!(status < 0b1000, "status should be less than 0b1000"); + self.set_status(offset, status); + if self.should_run().into_iter().all(|x| !x) { + self.set_completed(true); + } else { + self.set_completed(false); + } + } + /// 根据任务结果更新状态,任务结果是一个 Result 数组,需要与子任务一一对应 /// 如果所有子任务都已经完成,那么打上最高位的完成标记 pub fn update_status(&mut self, result: &[Result<()>]) { @@ -114,6 +141,7 @@ impl From<[u32; N]> for Status { fn from(status: [u32; N]) -> Self { let mut result = Status::::default(); for (i, item) in status.iter().enumerate() { + assert!(*item < 0b1000, "status should be less than 0b1000"); result.set_status(i, *item); } if result.should_run().iter().all(|x| !x) { @@ -165,4 +193,19 @@ mod test { assert_eq!(<[u32; 3]>::from(status), *after); } } + + #[test] + fn test_status_reset_failed() { + let mut status = Status::<3>::from([3, 4, 7]); + assert!(status.reset_failed()); + assert!(!status.get_completed()); + assert_eq!(<[u32; 3]>::from(status), [3, 0, 7]); + } + + #[test] + fn test_status_set() { + let mut status = Status::<5>::from([3, 4, 7, 2, 3]); + status.set(4, 0); + assert_eq!(<[u32; 5]>::from(status), [3, 4, 7, 2, 0]); + } }