fix: 修复并行下载未正确触发的问题,根据文件是否为流做不同处理 (#586)
This commit is contained in:
@@ -6,7 +6,7 @@ use std::sync::Arc;
|
||||
use anyhow::{Context, Result, bail, ensure};
|
||||
use async_tempfile::TempFile;
|
||||
use futures::TryStreamExt;
|
||||
use reqwest::{Method, header};
|
||||
use reqwest::{Method, StatusCode, header};
|
||||
use tokio::fs::{self};
|
||||
use tokio::io::{AsyncSeekExt, AsyncWriteExt};
|
||||
use tokio::process::Command;
|
||||
@@ -30,7 +30,8 @@ impl Downloader {
|
||||
|
||||
pub async fn fetch(&self, url: &str, path: &Path, concurrent_download: &ConcurrentDownloadLimit) -> Result<()> {
|
||||
let mut temp_file = TempFile::new().await?;
|
||||
self.fetch_internal(url, &mut temp_file, concurrent_download).await?;
|
||||
self.fetch_internal(url, &mut temp_file, false, concurrent_download)
|
||||
.await?;
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent).await?;
|
||||
}
|
||||
@@ -48,7 +49,7 @@ impl Downloader {
|
||||
path: &Path,
|
||||
concurrent_download: &ConcurrentDownloadLimit,
|
||||
) -> Result<()> {
|
||||
let temp_file = self.multi_fetch_internal(urls, concurrent_download).await?;
|
||||
let temp_file = self.multi_fetch_internal(urls, true, concurrent_download).await?;
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent).await?;
|
||||
}
|
||||
@@ -65,8 +66,8 @@ impl Downloader {
|
||||
concurrent_download: &ConcurrentDownloadLimit,
|
||||
) -> Result<()> {
|
||||
let (video_temp_file, audio_temp_file) = tokio::try_join!(
|
||||
self.multi_fetch_internal(video_urls, concurrent_download),
|
||||
self.multi_fetch_internal(audio_urls, concurrent_download)
|
||||
self.multi_fetch_internal(video_urls, true, concurrent_download),
|
||||
self.multi_fetch_internal(audio_urls, true, concurrent_download)
|
||||
)?;
|
||||
let final_temp_file = TempFile::new().await?;
|
||||
let output = Command::new("ffmpeg")
|
||||
@@ -105,6 +106,7 @@ impl Downloader {
|
||||
async fn multi_fetch_internal(
|
||||
&self,
|
||||
urls: &[&str],
|
||||
is_stream: bool,
|
||||
concurrent_download: &ConcurrentDownloadLimit,
|
||||
) -> Result<TempFile> {
|
||||
if urls.is_empty() {
|
||||
@@ -112,7 +114,10 @@ impl Downloader {
|
||||
}
|
||||
let mut temp_file = TempFile::new().await?;
|
||||
for (idx, url) in urls.iter().enumerate() {
|
||||
match self.fetch_internal(url, &mut temp_file, concurrent_download).await {
|
||||
match self
|
||||
.fetch_internal(url, &mut temp_file, is_stream, concurrent_download)
|
||||
.await
|
||||
{
|
||||
Ok(_) => return Ok(temp_file),
|
||||
Err(e) => {
|
||||
if idx == urls.len() - 1 {
|
||||
@@ -131,10 +136,11 @@ impl Downloader {
|
||||
&self,
|
||||
url: &str,
|
||||
file: &mut TempFile,
|
||||
is_stream: bool,
|
||||
concurrent_download: &ConcurrentDownloadLimit,
|
||||
) -> Result<()> {
|
||||
if concurrent_download.enable {
|
||||
self.fetch_parallel(url, file, concurrent_download).await
|
||||
self.fetch_parallel(url, file, is_stream, concurrent_download).await
|
||||
} else {
|
||||
self.fetch_serial(url, file).await
|
||||
}
|
||||
@@ -166,25 +172,46 @@ impl Downloader {
|
||||
&self,
|
||||
url: &str,
|
||||
file: &mut TempFile,
|
||||
is_stream: bool,
|
||||
concurrent_download: &ConcurrentDownloadLimit,
|
||||
) -> Result<()> {
|
||||
let (concurrency, threshold) = (concurrent_download.concurrency, concurrent_download.threshold);
|
||||
// 有些 B 站视频 url GET 有内容但 HEAD 会返回 404,此处使用 bytes=0-0 的 GET 代替 HEAD 以获取文件大小
|
||||
let resp = self
|
||||
.client
|
||||
.request(Method::GET, url, None)
|
||||
.header(header::RANGE, "bytes=0-0")
|
||||
.send()
|
||||
.await?
|
||||
.error_for_status()?;
|
||||
let file_size = resp.header_content_length().unwrap_or_default();
|
||||
let file_size = if is_stream {
|
||||
// B 站视频、音频流存在 HEAD 为 404 但 GET 正常的情况,此处假设支持分块,直接使用携带 Range 头的 GET 请求探测
|
||||
let resp = self
|
||||
.client
|
||||
.request(Method::GET, url, None)
|
||||
.header(header::RANGE, "bytes=0-0")
|
||||
.send()
|
||||
.await?
|
||||
.error_for_status()?;
|
||||
if resp.status() != StatusCode::PARTIAL_CONTENT {
|
||||
return self.fetch_serial(url, file).await;
|
||||
}
|
||||
resp.header_file_size()
|
||||
} else {
|
||||
// 对于普通文件,直接使用常规的 HEAD 请求探测
|
||||
let resp = self
|
||||
.client
|
||||
.request(Method::HEAD, url, None)
|
||||
.send()
|
||||
.await?
|
||||
.error_for_status()?;
|
||||
if resp
|
||||
.headers()
|
||||
.get(header::ACCEPT_RANGES)
|
||||
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Accept-Ranges#none
|
||||
.is_none_or(|v| v.to_str().unwrap_or_default() == "none")
|
||||
{
|
||||
return self.fetch_serial(url, file).await;
|
||||
}
|
||||
resp.header_content_length()
|
||||
};
|
||||
let Some(file_size) = file_size else {
|
||||
return self.fetch_serial(url, file).await;
|
||||
};
|
||||
let chunk_size = file_size / concurrency as u64;
|
||||
if resp
|
||||
.headers()
|
||||
.get(header::ACCEPT_RANGES)
|
||||
.is_none_or(|v| v.to_str().unwrap_or_default() == "none") // https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Accept-Ranges#none
|
||||
|| chunk_size < threshold
|
||||
{
|
||||
if chunk_size < threshold {
|
||||
return self.fetch_serial(url, file).await;
|
||||
}
|
||||
file.set_len(file_size).await?;
|
||||
@@ -238,7 +265,10 @@ impl Downloader {
|
||||
/// reqwest.content_length() 居然指的是 body_size 而非 content-length header,没办法自己实现一下
|
||||
/// https://github.com/seanmonstar/reqwest/issues/1814
|
||||
trait ResponseExt {
|
||||
/// 获取 Content-Length 头的值
|
||||
fn header_content_length(&self) -> Option<u64>;
|
||||
/// 获取 Content-Range 头中的文件总大小部分
|
||||
fn header_file_size(&self) -> Option<u64>;
|
||||
}
|
||||
|
||||
impl ResponseExt for reqwest::Response {
|
||||
@@ -248,6 +278,17 @@ impl ResponseExt for reqwest::Response {
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|s| s.parse::<u64>().ok())
|
||||
}
|
||||
|
||||
fn header_file_size(&self) -> Option<u64> {
|
||||
self.headers()
|
||||
.get(header::CONTENT_RANGE)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|s| {
|
||||
// Content-Range: bytes 0-0/800946
|
||||
s.rsplit_once('/')
|
||||
})
|
||||
.and_then(|(_, size_str)| size_str.parse::<u64>().ok())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -262,12 +303,12 @@ mod tests {
|
||||
use crate::downloader::Downloader;
|
||||
|
||||
#[ignore = "only for manual test"]
|
||||
#[tokio::test]
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
async fn test_parse_and_download_video() -> Result<()> {
|
||||
VersionedConfig::init_for_test(&setup_database(Path::new("./test.sqlite")).await?).await?;
|
||||
let config = VersionedConfig::get().read();
|
||||
let client = BiliClient::new();
|
||||
let video = Video::new(&client, "BV14oCrBqEd2".to_owned(), &config.credential);
|
||||
let video = Video::new(&client, "BV1QJmaYKEv4".to_owned(), &config.credential);
|
||||
let pages = video.get_pages().await.expect("failed to get pages");
|
||||
let first_page = pages.into_iter().next().expect("no page found");
|
||||
let mut page_analyzer = video
|
||||
|
||||
Reference in New Issue
Block a user