diff --git a/crates/bili_sync/src/downloader.rs b/crates/bili_sync/src/downloader.rs index 606c0c0..2d0dd8f 100644 --- a/crates/bili_sync/src/downloader.rs +++ b/crates/bili_sync/src/downloader.rs @@ -6,7 +6,7 @@ use std::sync::Arc; use anyhow::{Context, Result, bail, ensure}; use async_tempfile::TempFile; use futures::TryStreamExt; -use reqwest::{Method, header}; +use reqwest::{Method, StatusCode, header}; use tokio::fs::{self}; use tokio::io::{AsyncSeekExt, AsyncWriteExt}; use tokio::process::Command; @@ -30,7 +30,8 @@ impl Downloader { pub async fn fetch(&self, url: &str, path: &Path, concurrent_download: &ConcurrentDownloadLimit) -> Result<()> { let mut temp_file = TempFile::new().await?; - self.fetch_internal(url, &mut temp_file, concurrent_download).await?; + self.fetch_internal(url, &mut temp_file, false, concurrent_download) + .await?; if let Some(parent) = path.parent() { fs::create_dir_all(parent).await?; } @@ -48,7 +49,7 @@ impl Downloader { path: &Path, concurrent_download: &ConcurrentDownloadLimit, ) -> Result<()> { - let temp_file = self.multi_fetch_internal(urls, concurrent_download).await?; + let temp_file = self.multi_fetch_internal(urls, true, concurrent_download).await?; if let Some(parent) = path.parent() { fs::create_dir_all(parent).await?; } @@ -65,8 +66,8 @@ impl Downloader { concurrent_download: &ConcurrentDownloadLimit, ) -> Result<()> { let (video_temp_file, audio_temp_file) = tokio::try_join!( - self.multi_fetch_internal(video_urls, concurrent_download), - self.multi_fetch_internal(audio_urls, concurrent_download) + self.multi_fetch_internal(video_urls, true, concurrent_download), + self.multi_fetch_internal(audio_urls, true, concurrent_download) )?; let final_temp_file = TempFile::new().await?; let output = Command::new("ffmpeg") @@ -105,6 +106,7 @@ impl Downloader { async fn multi_fetch_internal( &self, urls: &[&str], + is_stream: bool, concurrent_download: &ConcurrentDownloadLimit, ) -> Result { if urls.is_empty() { @@ -112,7 +114,10 @@ impl Downloader { } let mut temp_file = TempFile::new().await?; for (idx, url) in urls.iter().enumerate() { - match self.fetch_internal(url, &mut temp_file, concurrent_download).await { + match self + .fetch_internal(url, &mut temp_file, is_stream, concurrent_download) + .await + { Ok(_) => return Ok(temp_file), Err(e) => { if idx == urls.len() - 1 { @@ -131,10 +136,11 @@ impl Downloader { &self, url: &str, file: &mut TempFile, + is_stream: bool, concurrent_download: &ConcurrentDownloadLimit, ) -> Result<()> { if concurrent_download.enable { - self.fetch_parallel(url, file, concurrent_download).await + self.fetch_parallel(url, file, is_stream, concurrent_download).await } else { self.fetch_serial(url, file).await } @@ -166,25 +172,46 @@ impl Downloader { &self, url: &str, file: &mut TempFile, + is_stream: bool, concurrent_download: &ConcurrentDownloadLimit, ) -> Result<()> { let (concurrency, threshold) = (concurrent_download.concurrency, concurrent_download.threshold); - // 有些 B 站视频 url GET 有内容但 HEAD 会返回 404,此处使用 bytes=0-0 的 GET 代替 HEAD 以获取文件大小 - let resp = self - .client - .request(Method::GET, url, None) - .header(header::RANGE, "bytes=0-0") - .send() - .await? - .error_for_status()?; - let file_size = resp.header_content_length().unwrap_or_default(); + let file_size = if is_stream { + // B 站视频、音频流存在 HEAD 为 404 但 GET 正常的情况,此处假设支持分块,直接使用携带 Range 头的 GET 请求探测 + let resp = self + .client + .request(Method::GET, url, None) + .header(header::RANGE, "bytes=0-0") + .send() + .await? + .error_for_status()?; + if resp.status() != StatusCode::PARTIAL_CONTENT { + return self.fetch_serial(url, file).await; + } + resp.header_file_size() + } else { + // 对于普通文件,直接使用常规的 HEAD 请求探测 + let resp = self + .client + .request(Method::HEAD, url, None) + .send() + .await? + .error_for_status()?; + if resp + .headers() + .get(header::ACCEPT_RANGES) + // https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Accept-Ranges#none + .is_none_or(|v| v.to_str().unwrap_or_default() == "none") + { + return self.fetch_serial(url, file).await; + } + resp.header_content_length() + }; + let Some(file_size) = file_size else { + return self.fetch_serial(url, file).await; + }; let chunk_size = file_size / concurrency as u64; - if resp - .headers() - .get(header::ACCEPT_RANGES) - .is_none_or(|v| v.to_str().unwrap_or_default() == "none") // https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Accept-Ranges#none - || chunk_size < threshold - { + if chunk_size < threshold { return self.fetch_serial(url, file).await; } file.set_len(file_size).await?; @@ -238,7 +265,10 @@ impl Downloader { /// reqwest.content_length() 居然指的是 body_size 而非 content-length header,没办法自己实现一下 /// https://github.com/seanmonstar/reqwest/issues/1814 trait ResponseExt { + /// 获取 Content-Length 头的值 fn header_content_length(&self) -> Option; + /// 获取 Content-Range 头中的文件总大小部分 + fn header_file_size(&self) -> Option; } impl ResponseExt for reqwest::Response { @@ -248,6 +278,17 @@ impl ResponseExt for reqwest::Response { .and_then(|v| v.to_str().ok()) .and_then(|s| s.parse::().ok()) } + + fn header_file_size(&self) -> Option { + self.headers() + .get(header::CONTENT_RANGE) + .and_then(|v| v.to_str().ok()) + .and_then(|s| { + // Content-Range: bytes 0-0/800946 + s.rsplit_once('/') + }) + .and_then(|(_, size_str)| size_str.parse::().ok()) + } } #[cfg(test)] @@ -262,12 +303,12 @@ mod tests { use crate::downloader::Downloader; #[ignore = "only for manual test"] - #[tokio::test] + #[tokio::test(flavor = "multi_thread")] async fn test_parse_and_download_video() -> Result<()> { VersionedConfig::init_for_test(&setup_database(Path::new("./test.sqlite")).await?).await?; let config = VersionedConfig::get().read(); let client = BiliClient::new(); - let video = Video::new(&client, "BV14oCrBqEd2".to_owned(), &config.credential); + let video = Video::new(&client, "BV1QJmaYKEv4".to_owned(), &config.credential); let pages = video.get_pages().await.expect("failed to get pages"); let first_page = pages.into_iter().next().expect("no page found"); let mut page_analyzer = video