feat: 支持单个文件的并发下载 (#343)
This commit is contained in:
@@ -31,6 +31,25 @@ pub struct ConcurrentLimit {
|
||||
pub video: usize,
|
||||
pub page: usize,
|
||||
pub rate_limit: Option<RateLimit>,
|
||||
#[serde(default)]
|
||||
pub download: ConcurrentDownloadLimit,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct ConcurrentDownloadLimit {
|
||||
pub enable: bool,
|
||||
pub concurrency: usize,
|
||||
pub threshold: u64,
|
||||
}
|
||||
|
||||
impl Default for ConcurrentDownloadLimit {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enable: true,
|
||||
concurrency: 4,
|
||||
threshold: 20 * (1 << 20), // 20 MB
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
@@ -49,6 +68,7 @@ impl Default for ConcurrentLimit {
|
||||
limit: 4,
|
||||
duration: 250,
|
||||
}),
|
||||
download: ConcurrentDownloadLimit::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,14 +1,18 @@
|
||||
use core::str;
|
||||
use std::io::SeekFrom;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{Context, Result, bail, ensure};
|
||||
use futures::TryStreamExt;
|
||||
use reqwest::Method;
|
||||
use tokio::fs::{self, File};
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use reqwest::{Method, header};
|
||||
use tokio::fs::{self, File, OpenOptions};
|
||||
use tokio::io::{AsyncSeekExt, AsyncWriteExt};
|
||||
use tokio::task::JoinSet;
|
||||
use tokio_util::io::StreamReader;
|
||||
|
||||
use crate::bilibili::Client;
|
||||
use crate::config::CONFIG;
|
||||
pub struct Downloader {
|
||||
client: Client,
|
||||
}
|
||||
@@ -22,26 +26,106 @@ impl Downloader {
|
||||
}
|
||||
|
||||
pub async fn fetch(&self, url: &str, path: &Path) -> Result<()> {
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent).await?;
|
||||
if CONFIG.concurrent_limit.download.enable {
|
||||
self.fetch_parallel(url, path).await
|
||||
} else {
|
||||
self.fetch_serial(url, path).await
|
||||
}
|
||||
let mut file = File::create(path).await?;
|
||||
}
|
||||
|
||||
async fn fetch_serial(&self, url: &str, path: &Path) -> Result<()> {
|
||||
let resp = self
|
||||
.client
|
||||
.request(Method::GET, url, None)
|
||||
.send()
|
||||
.await?
|
||||
.error_for_status()?;
|
||||
let expected = resp.content_length().unwrap_or_default();
|
||||
let expected = resp.header_content_length();
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent).await?;
|
||||
}
|
||||
let mut file = File::create(path).await?;
|
||||
let mut stream_reader = StreamReader::new(resp.bytes_stream().map_err(std::io::Error::other));
|
||||
let received = tokio::io::copy(&mut stream_reader, &mut file).await?;
|
||||
file.flush().await?;
|
||||
ensure!(
|
||||
received >= expected,
|
||||
"received {} bytes, expected {} bytes",
|
||||
received,
|
||||
expected
|
||||
);
|
||||
if let Some(expected) = expected {
|
||||
ensure!(
|
||||
received == expected,
|
||||
"downloaded bytes mismatch: expected {}, got {}",
|
||||
expected,
|
||||
received
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn fetch_parallel(&self, url: &str, path: &Path) -> Result<()> {
|
||||
let resp = self
|
||||
.client
|
||||
.request(Method::HEAD, url, None)
|
||||
.send()
|
||||
.await?
|
||||
.error_for_status()?;
|
||||
let file_size = resp.header_content_length().unwrap_or_default();
|
||||
let chunk_size = file_size / CONFIG.concurrent_limit.download.concurrency as u64;
|
||||
if resp
|
||||
.headers()
|
||||
.get(header::ACCEPT_RANGES)
|
||||
.is_none_or(|v| v.to_str().unwrap_or_default() == "none") // https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Accept-Ranges#none
|
||||
|| chunk_size < CONFIG.concurrent_limit.download.threshold
|
||||
{
|
||||
return self.fetch_serial(url, path).await;
|
||||
}
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent).await?;
|
||||
}
|
||||
let file = File::create(path).await?;
|
||||
file.set_len(file_size).await?;
|
||||
drop(file);
|
||||
let mut tasks = JoinSet::new();
|
||||
let url = Arc::new(url.to_string());
|
||||
let path = Arc::new(path.to_path_buf());
|
||||
for i in 0..CONFIG.concurrent_limit.download.concurrency {
|
||||
let start = i as u64 * chunk_size;
|
||||
let end = if i == CONFIG.concurrent_limit.download.concurrency - 1 {
|
||||
file_size
|
||||
} else {
|
||||
start + chunk_size
|
||||
} - 1;
|
||||
let (url_clone, path_clone, client_clone) = (url.clone(), path.clone(), self.client.clone());
|
||||
tasks.spawn(async move {
|
||||
let mut file = OpenOptions::new().write(true).open(path_clone.as_ref()).await?;
|
||||
file.seek(SeekFrom::Start(start)).await?;
|
||||
let range_header = format!("bytes={}-{}", start, end);
|
||||
let resp = client_clone
|
||||
.request(Method::GET, &url_clone, None)
|
||||
.header(header::RANGE, &range_header)
|
||||
.send()
|
||||
.await?
|
||||
.error_for_status()?;
|
||||
if let Some(content_length) = resp.header_content_length() {
|
||||
ensure!(
|
||||
content_length == end - start + 1,
|
||||
"content length mismatch: expected {}, got {}",
|
||||
end - start + 1,
|
||||
content_length
|
||||
);
|
||||
}
|
||||
let mut stream_reader = StreamReader::new(resp.bytes_stream().map_err(std::io::Error::other));
|
||||
let received = tokio::io::copy(&mut stream_reader, &mut file).await?;
|
||||
file.flush().await?;
|
||||
ensure!(
|
||||
received == end - start + 1,
|
||||
"downloaded bytes mismatch: expected {}, got {}",
|
||||
end - start + 1,
|
||||
received,
|
||||
);
|
||||
Ok(())
|
||||
});
|
||||
}
|
||||
while let Some(res) = tasks.join_next().await {
|
||||
res??;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -83,3 +167,18 @@ impl Downloader {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// reqwest's `content_length()` surprisingly reports the body size rather than
/// the `Content-Length` header, so we have to read the header ourselves.
/// https://github.com/seanmonstar/reqwest/issues/1814
trait ResponseExt {
    /// The parsed `Content-Length` header, if present and well-formed.
    fn header_content_length(&self) -> Option<u64>;
}
|
||||
|
||||
impl ResponseExt for reqwest::Response {
|
||||
fn header_content_length(&self) -> Option<u64> {
|
||||
self.headers()
|
||||
.get(header::CONTENT_LENGTH)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|s| s.parse::<u64>().ok())
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user