diff --git a/crates/bili_sync/src/api/routes/mod.rs b/crates/bili_sync/src/api/routes/mod.rs index dd555b9..c3157e1 100644 --- a/crates/bili_sync/src/api/routes/mod.rs +++ b/crates/bili_sync/src/api/routes/mod.rs @@ -39,22 +39,29 @@ pub fn router() -> Router { ) } -/// 中间件:验证请求头中的 Authorization 是否与配置中的 auth_token 匹配 -pub async fn auth(headers: HeaderMap, request: Request, next: Next) -> Result { +/// 中间件:使用 auth token 对请求进行身份验证 +pub async fn auth(mut headers: HeaderMap, request: Request, next: Next) -> Result { let config = VersionedConfig::get().load(); let token = config.auth_token.as_str(); if headers .get("Authorization") .and_then(|v| v.to_str().ok()) .is_some_and(|s| s == token) - || headers - .get("Sec-WebSocket-Protocol") - .and_then(|v| v.to_str().ok()) - .and_then(|s| BASE64_URL_SAFE_NO_PAD.decode(s).ok()) - .is_some_and(|s| s == token.as_bytes()) { return Ok(next.run(request).await); } + if let Some(protocol) = headers.remove("Sec-WebSocket-Protocol") { + if protocol + .to_str() + .ok() + .and_then(|s| BASE64_URL_SAFE_NO_PAD.decode(s).ok()) + .is_some_and(|s| s == token.as_bytes()) + { + let mut resp = next.run(request).await; + resp.headers_mut().insert("Sec-WebSocket-Protocol", protocol); + return Ok(resp); + } + } Ok(ApiResponse::<()>::unauthorized("auth token does not match").into_response()) }