fix: 修复并行下载未正确触发的问题,根据文件是否为流做不同处理 (#586)

This commit is contained in:
ᴀᴍᴛᴏᴀᴇʀ
2025-12-31 11:52:38 +08:00
committed by GitHub
parent f24ee97b28
commit 0b5ae3d664

View File

@@ -6,7 +6,7 @@ use std::sync::Arc;
use anyhow::{Context, Result, bail, ensure};
use async_tempfile::TempFile;
use futures::TryStreamExt;
use reqwest::{Method, header};
use reqwest::{Method, StatusCode, header};
use tokio::fs::{self};
use tokio::io::{AsyncSeekExt, AsyncWriteExt};
use tokio::process::Command;
@@ -30,7 +30,8 @@ impl Downloader {
pub async fn fetch(&self, url: &str, path: &Path, concurrent_download: &ConcurrentDownloadLimit) -> Result<()> {
let mut temp_file = TempFile::new().await?;
self.fetch_internal(url, &mut temp_file, concurrent_download).await?;
self.fetch_internal(url, &mut temp_file, false, concurrent_download)
.await?;
if let Some(parent) = path.parent() {
fs::create_dir_all(parent).await?;
}
@@ -48,7 +49,7 @@ impl Downloader {
path: &Path,
concurrent_download: &ConcurrentDownloadLimit,
) -> Result<()> {
let temp_file = self.multi_fetch_internal(urls, concurrent_download).await?;
let temp_file = self.multi_fetch_internal(urls, true, concurrent_download).await?;
if let Some(parent) = path.parent() {
fs::create_dir_all(parent).await?;
}
@@ -65,8 +66,8 @@ impl Downloader {
concurrent_download: &ConcurrentDownloadLimit,
) -> Result<()> {
let (video_temp_file, audio_temp_file) = tokio::try_join!(
self.multi_fetch_internal(video_urls, concurrent_download),
self.multi_fetch_internal(audio_urls, concurrent_download)
self.multi_fetch_internal(video_urls, true, concurrent_download),
self.multi_fetch_internal(audio_urls, true, concurrent_download)
)?;
let final_temp_file = TempFile::new().await?;
let output = Command::new("ffmpeg")
@@ -105,6 +106,7 @@ impl Downloader {
async fn multi_fetch_internal(
&self,
urls: &[&str],
is_stream: bool,
concurrent_download: &ConcurrentDownloadLimit,
) -> Result<TempFile> {
if urls.is_empty() {
@@ -112,7 +114,10 @@ impl Downloader {
}
let mut temp_file = TempFile::new().await?;
for (idx, url) in urls.iter().enumerate() {
match self.fetch_internal(url, &mut temp_file, concurrent_download).await {
match self
.fetch_internal(url, &mut temp_file, is_stream, concurrent_download)
.await
{
Ok(_) => return Ok(temp_file),
Err(e) => {
if idx == urls.len() - 1 {
@@ -131,10 +136,11 @@ impl Downloader {
&self,
url: &str,
file: &mut TempFile,
is_stream: bool,
concurrent_download: &ConcurrentDownloadLimit,
) -> Result<()> {
if concurrent_download.enable {
self.fetch_parallel(url, file, concurrent_download).await
self.fetch_parallel(url, file, is_stream, concurrent_download).await
} else {
self.fetch_serial(url, file).await
}
@@ -166,25 +172,46 @@ impl Downloader {
&self,
url: &str,
file: &mut TempFile,
is_stream: bool,
concurrent_download: &ConcurrentDownloadLimit,
) -> Result<()> {
let (concurrency, threshold) = (concurrent_download.concurrency, concurrent_download.threshold);
// 有些 B 站视频 url GET 有内容但 HEAD 会返回 404此处使用 bytes=0-0 的 GET 代替 HEAD 以获取文件大小
let resp = self
.client
.request(Method::GET, url, None)
.header(header::RANGE, "bytes=0-0")
.send()
.await?
.error_for_status()?;
let file_size = resp.header_content_length().unwrap_or_default();
let file_size = if is_stream {
// B 站视频、音频流存在 HEAD 为 404 但 GET 正常的情况,此处假设支持分块,直接使用携带 Range 头的 GET 请求探测
let resp = self
.client
.request(Method::GET, url, None)
.header(header::RANGE, "bytes=0-0")
.send()
.await?
.error_for_status()?;
if resp.status() != StatusCode::PARTIAL_CONTENT {
return self.fetch_serial(url, file).await;
}
resp.header_file_size()
} else {
// 对于普通文件,直接使用常规的 HEAD 请求探测
let resp = self
.client
.request(Method::HEAD, url, None)
.send()
.await?
.error_for_status()?;
if resp
.headers()
.get(header::ACCEPT_RANGES)
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Accept-Ranges#none
.is_none_or(|v| v.to_str().unwrap_or_default() == "none")
{
return self.fetch_serial(url, file).await;
}
resp.header_content_length()
};
let Some(file_size) = file_size else {
return self.fetch_serial(url, file).await;
};
let chunk_size = file_size / concurrency as u64;
if resp
.headers()
.get(header::ACCEPT_RANGES)
.is_none_or(|v| v.to_str().unwrap_or_default() == "none") // https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Accept-Ranges#none
|| chunk_size < threshold
{
if chunk_size < threshold {
return self.fetch_serial(url, file).await;
}
file.set_len(file_size).await?;
@@ -238,7 +265,10 @@ impl Downloader {
/// `reqwest`'s `content_length()` actually reports the body size rather than the
/// `Content-Length` header, so the header lookups are implemented here by hand.
/// https://github.com/seanmonstar/reqwest/issues/1814
trait ResponseExt {
/// Returns the value of the `Content-Length` header.
fn header_content_length(&self) -> Option<u64>;
/// Returns the total file size portion of the `Content-Range` header.
fn header_file_size(&self) -> Option<u64>;
}
impl ResponseExt for reqwest::Response {
@@ -248,6 +278,17 @@ impl ResponseExt for reqwest::Response {
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<u64>().ok())
}
/// Reads the total file size out of the `Content-Range` header.
///
/// The header looks like `Content-Range: bytes 0-0/800946`; the size is the text
/// after the last `/`. Returns `None` when the header is absent, cannot be read
/// as a string, contains no `/`, or the trailing part does not parse as a `u64`.
fn header_file_size(&self) -> Option<u64> {
    let raw_value = self.headers().get(header::CONTENT_RANGE)?;
    let text = raw_value.to_str().ok()?;
    // Only the part after the final slash matters; the range itself is ignored.
    let (_, total) = text.rsplit_once('/')?;
    total.parse::<u64>().ok()
}
}
#[cfg(test)]
@@ -262,12 +303,12 @@ mod tests {
use crate::downloader::Downloader;
#[ignore = "only for manual test"]
#[tokio::test]
#[tokio::test(flavor = "multi_thread")]
async fn test_parse_and_download_video() -> Result<()> {
VersionedConfig::init_for_test(&setup_database(Path::new("./test.sqlite")).await?).await?;
let config = VersionedConfig::get().read();
let client = BiliClient::new();
let video = Video::new(&client, "BV14oCrBqEd2".to_owned(), &config.credential);
let video = Video::new(&client, "BV1QJmaYKEv4".to_owned(), &config.credential);
let pages = video.get_pages().await.expect("failed to get pages");
let first_page = pages.into_iter().next().expect("no page found");
let mut page_analyzer = video