feat: 支持单个文件的并发下载 (#343)

This commit is contained in:
ᴀᴍᴛᴏᴀᴇʀ
2025-05-30 02:19:23 +08:00
committed by GitHub
parent 6383730706
commit a9f604a07d
2 changed files with 132 additions and 13 deletions

View File

@@ -31,6 +31,25 @@ pub struct ConcurrentLimit {
pub video: usize,
pub page: usize,
pub rate_limit: Option<RateLimit>,
#[serde(default)]
pub download: ConcurrentDownloadLimit,
}
#[derive(Serialize, Deserialize)]
/// Configuration for concurrent (ranged) download of a single file.
pub struct ConcurrentDownloadLimit {
    // Whether chunked concurrent download is enabled at all; when false,
    // files are always downloaded with a single serial GET request.
    pub enable: bool,
    // Number of concurrent range requests issued per file.
    pub concurrency: usize,
    // Minimum per-chunk size in bytes: when file_size / concurrency falls
    // below this, the download falls back to the serial path.
    pub threshold: u64,
}
impl Default for ConcurrentDownloadLimit {
    /// Enabled out of the box: four parallel range requests, engaged only
    /// for files whose per-chunk size reaches 20 MiB.
    fn default() -> Self {
        // 20 MiB, written out explicitly rather than via a shift.
        const DEFAULT_THRESHOLD: u64 = 20 * 1024 * 1024;
        Self {
            enable: true,
            concurrency: 4,
            threshold: DEFAULT_THRESHOLD,
        }
    }
}
#[derive(Serialize, Deserialize)]
@@ -49,6 +68,7 @@ impl Default for ConcurrentLimit {
limit: 4,
duration: 250,
}),
download: ConcurrentDownloadLimit::default(),
}
}
}

View File

@@ -1,14 +1,18 @@
use core::str;
use std::io::SeekFrom;
use std::path::Path;
use std::sync::Arc;
use anyhow::{Context, Result, bail, ensure};
use futures::TryStreamExt;
use reqwest::Method;
use tokio::fs::{self, File};
use tokio::io::AsyncWriteExt;
use reqwest::{Method, header};
use tokio::fs::{self, File, OpenOptions};
use tokio::io::{AsyncSeekExt, AsyncWriteExt};
use tokio::task::JoinSet;
use tokio_util::io::StreamReader;
use crate::bilibili::Client;
use crate::config::CONFIG;
/// Thin wrapper around the shared HTTP [`Client`] that downloads a URL
/// to a local file, either serially or in concurrent ranged chunks.
pub struct Downloader {
    client: Client,
}
@@ -22,26 +26,106 @@ impl Downloader {
}
pub async fn fetch(&self, url: &str, path: &Path) -> Result<()> {
if let Some(parent) = path.parent() {
fs::create_dir_all(parent).await?;
if CONFIG.concurrent_limit.download.enable {
self.fetch_parallel(url, path).await
} else {
self.fetch_serial(url, path).await
}
let mut file = File::create(path).await?;
}
/// Download the whole file with a single GET request, streaming the body
/// straight to disk.
///
/// # Errors
/// Fails on HTTP errors, I/O errors, or when the number of received bytes
/// does not match the advertised `Content-Length`.
async fn fetch_serial(&self, url: &str, path: &Path) -> Result<()> {
    let resp = self
        .client
        .request(Method::GET, url, None)
        .send()
        .await?
        .error_for_status()?;
    // reqwest's `content_length()` reports the decoded body size rather than
    // the Content-Length header, so use the `ResponseExt` helper instead.
    let expected = resp.header_content_length();
    if let Some(parent) = path.parent() {
        fs::create_dir_all(parent).await?;
    }
    let mut file = File::create(path).await?;
    let mut stream_reader = StreamReader::new(resp.bytes_stream().map_err(std::io::Error::other));
    let received = tokio::io::copy(&mut stream_reader, &mut file).await?;
    file.flush().await?;
    // Only verify the size when the server actually advertised one.
    if let Some(expected) = expected {
        ensure!(
            received == expected,
            "downloaded bytes mismatch: expected {}, got {}",
            expected,
            received
        );
    }
    Ok(())
}
/// Download the file with multiple concurrent ranged GET requests.
///
/// Falls back to [`fetch_serial`] when:
/// - the configured concurrency is 0 (a chunked download is impossible and
///   the chunk-size division below would otherwise panic),
/// - the server does not advertise range support via `Accept-Ranges`, or
/// - the per-chunk size is below the configured threshold.
async fn fetch_parallel(&self, url: &str, path: &Path) -> Result<()> {
    // Probe file size and range support with a HEAD request first.
    let resp = self
        .client
        .request(Method::HEAD, url, None)
        .send()
        .await?
        .error_for_status()?;
    let concurrency = CONFIG.concurrent_limit.download.concurrency;
    if concurrency == 0 {
        // Misconfigured: avoid the divide-by-zero below and just go serial.
        return self.fetch_serial(url, path).await;
    }
    let file_size = resp.header_content_length().unwrap_or_default();
    let chunk_size = file_size / concurrency as u64;
    if resp
        .headers()
        .get(header::ACCEPT_RANGES)
        // https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Accept-Ranges#none
        .is_none_or(|v| v.to_str().unwrap_or_default() == "none")
        || chunk_size < CONFIG.concurrent_limit.download.threshold
    {
        return self.fetch_serial(url, path).await;
    }
    if let Some(parent) = path.parent() {
        fs::create_dir_all(parent).await?;
    }
    // Pre-size the file so every task can seek to and write its own range.
    let file = File::create(path).await?;
    file.set_len(file_size).await?;
    drop(file);
    let mut tasks = JoinSet::new();
    let url = Arc::new(url.to_string());
    let path = Arc::new(path.to_path_buf());
    for i in 0..concurrency {
        let start = i as u64 * chunk_size;
        // The last chunk absorbs the remainder of the integer division.
        let end = if i == concurrency - 1 {
            file_size
        } else {
            start + chunk_size
        } - 1;
        let (url_clone, path_clone, client_clone) = (url.clone(), path.clone(), self.client.clone());
        tasks.spawn(async move {
            let mut file = OpenOptions::new().write(true).open(path_clone.as_ref()).await?;
            file.seek(SeekFrom::Start(start)).await?;
            let range_header = format!("bytes={}-{}", start, end);
            let resp = client_clone
                .request(Method::GET, &url_clone, None)
                .header(header::RANGE, &range_header)
                .send()
                .await?
                .error_for_status()?;
            // Bail out early if the server answers with a range of a
            // different size than the one we asked for.
            if let Some(content_length) = resp.header_content_length() {
                ensure!(
                    content_length == end - start + 1,
                    "content length mismatch: expected {}, got {}",
                    end - start + 1,
                    content_length
                );
            }
            let mut stream_reader = StreamReader::new(resp.bytes_stream().map_err(std::io::Error::other));
            let received = tokio::io::copy(&mut stream_reader, &mut file).await?;
            file.flush().await?;
            ensure!(
                received == end - start + 1,
                "downloaded bytes mismatch: expected {}, got {}",
                end - start + 1,
                received,
            );
            Ok(())
        });
    }
    // Propagate the first join failure or task error.
    while let Some(res) = tasks.join_next().await {
        res??;
    }
    Ok(())
}
@@ -83,3 +167,18 @@ impl Downloader {
Ok(())
}
}
/// reqwest's `Response::content_length()` surprisingly refers to the decoded
/// body size rather than the `Content-Length` header, so we have to read the
/// header ourselves.
/// https://github.com/seanmonstar/reqwest/issues/1814
trait ResponseExt {
    /// Parse the `Content-Length` header, if present and numeric.
    fn header_content_length(&self) -> Option<u64>;
}
impl ResponseExt for reqwest::Response {
    fn header_content_length(&self) -> Option<u64> {
        // `?` short-circuits on a missing header or a non-visible-ASCII
        // value; a non-numeric value falls out as `None` via `ok()`.
        let raw = self.headers().get(header::CONTENT_LENGTH)?;
        raw.to_str().ok()?.parse::<u64>().ok()
    }
}