feat: 加入带有详细类型注释的 swagger 文档 (#257)

This commit is contained in:
ᴀᴍᴛᴏᴀᴇʀ
2025-02-18 01:55:54 +08:00
committed by GitHub
parent 1467c262a1
commit c995b3bf72
10 changed files with 293 additions and 102 deletions

View File

@@ -3,11 +3,13 @@ use axum::http::HeaderMap;
use axum::middleware::Next;
use axum::response::Response;
use reqwest::StatusCode;
use utoipa::openapi::security::{ApiKey, ApiKeyValue, SecurityScheme};
use utoipa::Modify;
use crate::config::CONFIG;
pub async fn auth(headers: HeaderMap, request: Request, next: Next) -> Result<Response, StatusCode> {
if request.uri().path().starts_with("/api") && get_token(&headers) != CONFIG.auth_token {
if request.uri().path().starts_with("/api/") && get_token(&headers) != CONFIG.auth_token {
return Err(StatusCode::UNAUTHORIZED);
}
Ok(next.run(request).await)
@@ -19,3 +21,19 @@ fn get_token(headers: &HeaderMap) -> Option<String> {
.and_then(|v| v.to_str().ok())
.map(Into::into)
}
pub struct OpenAPIAuth;
impl Modify for OpenAPIAuth {
fn modify(&self, openapi: &mut utoipa::openapi::OpenApi) {
if let Some(schema) = openapi.components.as_mut() {
schema.add_security_scheme(
"Token",
SecurityScheme::ApiKey(ApiKey::Header(ApiKeyValue::with_description(
"Authorization",
"与配置文件中的 auth_token 相同",
))),
);
}
}
}

View File

@@ -1,4 +1,3 @@
use std::collections::HashMap;
use std::sync::Arc;
use anyhow::{anyhow, Result};
@@ -7,71 +6,100 @@ use axum::Json;
use bili_sync_entity::*;
use bili_sync_migration::Expr;
use sea_orm::{ColumnTrait, DatabaseConnection, EntityTrait, PaginatorTrait, QueryFilter, QueryOrder, QuerySelect};
use utoipa::OpenApi;
use crate::api::auth::OpenAPIAuth;
use crate::api::error::ApiError;
use crate::api::payload::{PageInfo, VideoDetail, VideoInfo, VideoList, VideoListModel, VideoListModelItem};
use crate::api::request::VideosRequest;
use crate::api::response::{PageInfo, VideoInfo, VideoResponse, VideoSource, VideoSourcesResponse, VideosResponse};
/// 列出所有视频列表
pub async fn get_video_list_models(
#[derive(OpenApi)]
#[openapi(
paths(get_video_sources, get_videos, get_video),
modifiers(&OpenAPIAuth),
security(
("Token" = []),
)
)]
pub struct ApiDoc;
/// 列出所有视频来源
#[utoipa::path(
get,
path = "/api/video-sources",
responses(
(status = 200, body = VideoSourcesResponse),
)
)]
pub async fn get_video_sources(
Extension(db): Extension<Arc<DatabaseConnection>>,
) -> Result<Json<VideoListModel>, ApiError> {
Ok(Json(VideoListModel {
) -> Result<Json<VideoSourcesResponse>, ApiError> {
Ok(Json(VideoSourcesResponse {
collection: collection::Entity::find()
.select_only()
.columns([collection::Column::Id, collection::Column::Name])
.into_model::<VideoListModelItem>()
.into_model::<VideoSource>()
.all(db.as_ref())
.await?,
favorite: favorite::Entity::find()
.select_only()
.columns([favorite::Column::Id, favorite::Column::Name])
.into_model::<VideoListModelItem>()
.into_model::<VideoSource>()
.all(db.as_ref())
.await?,
submission: submission::Entity::find()
.select_only()
.column(submission::Column::Id)
.column_as(submission::Column::UpperName, "name")
.into_model::<VideoListModelItem>()
.into_model::<VideoSource>()
.all(db.as_ref())
.await?,
watch_later: watch_later::Entity::find()
.select_only()
.column(watch_later::Column::Id)
.column_as(Expr::value("稍后再看"), "name")
.into_model::<VideoListModelItem>()
.into_model::<VideoSource>()
.all(db.as_ref())
.await?,
}))
}
/// 列出所有视频的基本信息支持根据视频列表筛选,支持分页
pub async fn list_videos(
/// 列出视频的基本信息支持根据视频来源筛选、名称查找和分页
#[utoipa::path(
get,
path = "/api/videos",
params(
VideosRequest,
),
responses(
(status = 200, body = VideosResponse),
)
)]
pub async fn get_videos(
Extension(db): Extension<Arc<DatabaseConnection>>,
Query(params): Query<HashMap<String, String>>,
) -> Result<Json<VideoList>, ApiError> {
Query(params): Query<VideosRequest>,
) -> Result<Json<VideosResponse>, ApiError> {
let mut query = video::Entity::find();
for (query_key, filter_column) in [
("collection", video::Column::CollectionId),
("favorite", video::Column::FavoriteId),
("submission", video::Column::SubmissionId),
("watch_later", video::Column::WatchLaterId),
for (field, column) in [
(params.collection, video::Column::CollectionId),
(params.favorite, video::Column::FavoriteId),
(params.submission, video::Column::SubmissionId),
(params.watch_later, video::Column::WatchLaterId),
] {
if let Some(value) = params.get(query_key) {
query = query.filter(filter_column.eq(value));
break;
if let Some(id) = field {
query = query.filter(column.eq(id));
}
}
if let Some(query_word) = params.get("q") {
if let Some(query_word) = params.query {
query = query.filter(video::Column::Name.contains(query_word));
}
let total_count = query.clone().count(db.as_ref()).await?;
let (page, page_size) = if let (Some(page), Some(page_size)) = (params.get("page"), params.get("page_size")) {
(page.parse::<u64>()?, page_size.parse::<u64>()?)
let (page, page_size) = if let (Some(page), Some(page_size)) = (params.page, params.page_size) {
(page, page_size)
} else {
(1, 10)
};
Ok(Json(VideoList {
Ok(Json(VideosResponse {
videos: query
.order_by_desc(video::Column::Id)
.into_partial_model::<VideoInfo>()
@@ -82,11 +110,18 @@ pub async fn list_videos(
}))
}
/// 根据 id 获取视频详细信息,包括关联的所有 page
/// 获取视频详细信息,包括关联的所有 page
#[utoipa::path(
get,
path = "/api/videos/{id}",
responses(
(status = 200, body = VideoResponse),
)
)]
pub async fn get_video(
Path(id): Path<i32>,
Extension(db): Extension<Arc<DatabaseConnection>>,
) -> Result<Json<VideoDetail>, ApiError> {
) -> Result<Json<VideoResponse>, ApiError> {
let video_info = video::Entity::find_by_id(id)
.into_partial_model::<VideoInfo>()
.one(db.as_ref())
@@ -100,7 +135,7 @@ pub async fn get_video(
.into_partial_model::<PageInfo>()
.all(db.as_ref())
.await?;
Ok(Json(VideoDetail {
Ok(Json(VideoResponse {
video: video_info,
pages,
}))

View File

@@ -1,4 +1,6 @@
pub mod auth;
pub mod error;
pub mod handler;
pub mod payload;
mod request;
mod response;

View File

@@ -1,45 +0,0 @@
use bili_sync_entity::*;
use sea_orm::{DerivePartialModel, FromQueryResult};
use serde::Serialize;
#[derive(FromQueryResult, Serialize)]
pub struct VideoListModelItem {
id: i32,
name: String,
}
#[derive(Serialize)]
pub struct VideoListModel {
pub collection: Vec<VideoListModelItem>,
pub favorite: Vec<VideoListModelItem>,
pub submission: Vec<VideoListModelItem>,
pub watch_later: Vec<VideoListModelItem>,
}
#[derive(DerivePartialModel, FromQueryResult, Serialize)]
#[sea_orm(entity = "video::Entity")]
pub struct VideoInfo {
id: i32,
name: String,
upper_name: String,
}
#[derive(Serialize)]
pub struct VideoList {
pub videos: Vec<VideoInfo>,
pub total_count: u64,
}
#[derive(DerivePartialModel, FromQueryResult, Serialize)]
#[sea_orm(entity = "page::Entity")]
pub struct PageInfo {
id: i32,
pid: i32,
name: String,
}
#[derive(Serialize)]
pub struct VideoDetail {
pub video: VideoInfo,
pub pages: Vec<PageInfo>,
}

View File

@@ -0,0 +1,13 @@
use serde::Deserialize;
use utoipa::IntoParams;
#[derive(Deserialize, IntoParams)]
pub struct VideosRequest {
pub collection: Option<i32>,
pub favorite: Option<i32>,
pub submission: Option<i32>,
pub watch_later: Option<i32>,
pub query: Option<String>,
pub page: Option<u64>,
pub page_size: Option<u64>,
}

View File

@@ -0,0 +1,46 @@
use bili_sync_entity::*;
use sea_orm::{DerivePartialModel, FromQueryResult};
use serde::Serialize;
use utoipa::ToSchema;
#[derive(Serialize, ToSchema)]
pub struct VideoSourcesResponse {
pub collection: Vec<VideoSource>,
pub favorite: Vec<VideoSource>,
pub submission: Vec<VideoSource>,
pub watch_later: Vec<VideoSource>,
}
#[derive(Serialize, ToSchema)]
pub struct VideosResponse {
pub videos: Vec<VideoInfo>,
pub total_count: u64,
}
#[derive(Serialize, ToSchema)]
pub struct VideoResponse {
pub video: VideoInfo,
pub pages: Vec<PageInfo>,
}
#[derive(FromQueryResult, Serialize, ToSchema)]
pub struct VideoSource {
id: i32,
name: String,
}
#[derive(DerivePartialModel, FromQueryResult, Serialize, ToSchema)]
#[sea_orm(entity = "page::Entity")]
pub struct PageInfo {
id: i32,
pid: i32,
name: String,
}
#[derive(DerivePartialModel, FromQueryResult, Serialize, ToSchema)]
#[sea_orm(entity = "video::Entity")]
pub struct VideoInfo {
id: i32,
name: String,
upper_name: String,
}

View File

@@ -9,11 +9,11 @@ use axum::{middleware, Extension, Router, ServiceExt};
use reqwest::StatusCode;
use rust_embed::Embed;
use sea_orm::DatabaseConnection;
use tower::Layer;
use tower_http::normalize_path::NormalizePathLayer;
use utoipa::OpenApi;
use utoipa_swagger_ui::{Config, SwaggerUi};
use crate::api::auth;
use crate::api::handler::{get_video, get_video_list_models, list_videos};
use crate::api::handler::{get_video, get_video_sources, get_videos, ApiDoc};
use crate::config::CONFIG;
#[derive(Embed)]
@@ -22,13 +22,22 @@ struct Asset;
pub async fn http_server(database_connection: Arc<DatabaseConnection>) -> Result<()> {
let app = Router::new()
.route("/api/videos", get(list_videos))
.route("/api/videos/{video_id}", get(get_video))
.route("/api/video-list-models", get(get_video_list_models))
.route("/api/video-sources", get(get_video_sources))
.route("/api/videos", get(get_videos))
.route("/api/video/{id}", get(get_video))
.merge(
SwaggerUi::new("/swagger-ui/")
.url("/api-docs/openapi.json", ApiDoc::openapi())
.config(
Config::default()
.try_it_out_enabled(true)
.persist_authorization(true)
.validator_url("none"),
),
)
.fallback_service(get(frontend_files))
.layer(Extension(database_connection))
.layer(middleware::from_fn(auth::auth));
let app = NormalizePathLayer::trim_trailing_slash().layer(app);
let listener = tokio::net::TcpListener::bind(&CONFIG.bind_address)
.await
.context("bind address failed")?;