diff --git a/crates/bili_sync/src/config/item.rs b/crates/bili_sync/src/config/item.rs index dd4c31f..e2b877c 100644 --- a/crates/bili_sync/src/config/item.rs +++ b/crates/bili_sync/src/config/item.rs @@ -31,6 +31,25 @@ pub struct ConcurrentLimit { pub video: usize, pub page: usize, pub rate_limit: Option<RateLimit>, + #[serde(default)] + pub download: ConcurrentDownloadLimit, +} + +#[derive(Serialize, Deserialize)] +pub struct ConcurrentDownloadLimit { + pub enable: bool, + pub concurrency: usize, + pub threshold: u64, +} + +impl Default for ConcurrentDownloadLimit { + fn default() -> Self { + Self { + enable: true, + concurrency: 4, + threshold: 20 * (1 << 20), // 20 MB + } + } } #[derive(Serialize, Deserialize)] @@ -49,6 +68,7 @@ impl Default for ConcurrentLimit { limit: 4, duration: 250, }), + download: ConcurrentDownloadLimit::default(), } } } diff --git a/crates/bili_sync/src/downloader.rs b/crates/bili_sync/src/downloader.rs index 3594d6c..3c77647 100644 --- a/crates/bili_sync/src/downloader.rs +++ b/crates/bili_sync/src/downloader.rs @@ -1,14 +1,18 @@ use core::str; +use std::io::SeekFrom; use std::path::Path; +use std::sync::Arc; use anyhow::{Context, Result, bail, ensure}; use futures::TryStreamExt; -use reqwest::Method; -use tokio::fs::{self, File}; -use tokio::io::AsyncWriteExt; +use reqwest::{Method, header}; +use tokio::fs::{self, File, OpenOptions}; +use tokio::io::{AsyncSeekExt, AsyncWriteExt}; +use tokio::task::JoinSet; use tokio_util::io::StreamReader; use crate::bilibili::Client; +use crate::config::CONFIG; pub struct Downloader { client: Client, } @@ -22,26 +26,106 @@ impl Downloader { } pub async fn fetch(&self, url: &str, path: &Path) -> Result<()> { - if let Some(parent) = path.parent() { - fs::create_dir_all(parent).await?; + if CONFIG.concurrent_limit.download.enable { + self.fetch_parallel(url, path).await + } else { + self.fetch_serial(url, path).await } - let mut file = File::create(path).await?; + } + + async fn fetch_serial(&self, url: &str, 
path: &Path) -> Result<()> { let resp = self .client .request(Method::GET, url, None) .send() .await? .error_for_status()?; - let expected = resp.content_length().unwrap_or_default(); + let expected = resp.header_content_length(); + if let Some(parent) = path.parent() { + fs::create_dir_all(parent).await?; + } + let mut file = File::create(path).await?; let mut stream_reader = StreamReader::new(resp.bytes_stream().map_err(std::io::Error::other)); let received = tokio::io::copy(&mut stream_reader, &mut file).await?; file.flush().await?; - ensure!( - received >= expected, - "received {} bytes, expected {} bytes", - received, - expected - ); + if let Some(expected) = expected { + ensure!( + received == expected, + "downloaded bytes mismatch: expected {}, got {}", + expected, + received + ); + } + Ok(()) + } + + async fn fetch_parallel(&self, url: &str, path: &Path) -> Result<()> { + let resp = self + .client + .request(Method::HEAD, url, None) + .send() + .await? + .error_for_status()?; + let file_size = resp.header_content_length().unwrap_or_default(); + let chunk_size = file_size / CONFIG.concurrent_limit.download.concurrency as u64; + if resp + .headers() + .get(header::ACCEPT_RANGES) + .is_none_or(|v| v.to_str().unwrap_or_default() == "none") // https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Accept-Ranges#none + || chunk_size < CONFIG.concurrent_limit.download.threshold + { + return self.fetch_serial(url, path).await; + } + if let Some(parent) = path.parent() { + fs::create_dir_all(parent).await?; + } + let file = File::create(path).await?; + file.set_len(file_size).await?; + drop(file); + let mut tasks = JoinSet::new(); + let url = Arc::new(url.to_string()); + let path = Arc::new(path.to_path_buf()); + for i in 0..CONFIG.concurrent_limit.download.concurrency { + let start = i as u64 * chunk_size; + let end = if i == CONFIG.concurrent_limit.download.concurrency - 1 { + file_size + } else { + start + chunk_size + } - 1; + let (url_clone, 
path_clone, client_clone) = (url.clone(), path.clone(), self.client.clone()); + tasks.spawn(async move { + let mut file = OpenOptions::new().write(true).open(path_clone.as_ref()).await?; + file.seek(SeekFrom::Start(start)).await?; + let range_header = format!("bytes={}-{}", start, end); + let resp = client_clone + .request(Method::GET, &url_clone, None) + .header(header::RANGE, &range_header) + .send() + .await? + .error_for_status()?; + if let Some(content_length) = resp.header_content_length() { + ensure!( + content_length == end - start + 1, + "content length mismatch: expected {}, got {}", + end - start + 1, + content_length + ); + } + let mut stream_reader = StreamReader::new(resp.bytes_stream().map_err(std::io::Error::other)); + let received = tokio::io::copy(&mut stream_reader, &mut file).await?; + file.flush().await?; + ensure!( + received == end - start + 1, + "downloaded bytes mismatch: expected {}, got {}", + end - start + 1, + received, + ); + Ok(()) + }); + } + while let Some(res) = tasks.join_next().await { + res??; + } Ok(()) } @@ -83,3 +167,18 @@ impl Downloader { Ok(()) } } + +/// reqwest.content_length() 居然指的是 body_size 而非 content-length header,没办法自己实现一下 +/// https://github.com/seanmonstar/reqwest/issues/1814 +trait ResponseExt { + fn header_content_length(&self) -> Option<u64>; +} + +impl ResponseExt for reqwest::Response { + fn header_content_length(&self) -> Option<u64> { + self.headers() + .get(header::CONTENT_LENGTH) + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.parse::<u64>().ok()) + } +}