diff --git a/src/core/command.rs b/src/core/command.rs index 5b2ea85..6cd7e0a 100644 --- a/src/core/command.rs +++ b/src/core/command.rs @@ -3,7 +3,6 @@ use std::path::{Path, PathBuf}; use std::pin::Pin; use anyhow::{bail, Result}; -use dirs::config_dir; use entity::{favorite, page, video}; use filenamify::filenamify; use futures::stream::{FuturesOrdered, FuturesUnordered}; @@ -17,8 +16,11 @@ use tokio::fs; use tokio::sync::{Mutex, Semaphore}; use super::status::{PageStatus, VideoStatus}; -use super::utils::{unhandled_videos_pages, ModelWrapper, NFOMode, NFOSerializer, TEMPLATE}; +use super::utils::{ + unhandled_videos_pages, update_pages_model, update_videos_model, ModelWrapper, NFOMode, NFOSerializer, TEMPLATE, +}; use crate::bilibili::{BestStream, BiliClient, BiliError, FavoriteList, FilterOption, PageInfo, Video}; +use crate::config::CONFIG; use crate::core::utils::{ create_video_pages, create_videos, exist_labels, filter_unfilled_videos, handle_favorite_info, total_video_count, }; @@ -135,10 +137,15 @@ pub async fn download_unprocessed_videos( // 对于视频,允许五个同时下载(视频内还有分页、不同分页还有多种下载任务) let semaphore = Semaphore::new(5); let downloader = Downloader::default(); - let mut uppers_mutex: HashMap, Mutex<()>)> = HashMap::new(); + let mut uppers_mutex: HashMap, Mutex<()>)> = HashMap::new(); for (video_model, _) in &unhandled_videos_pages { uppers_mutex.insert(video_model.upper_id, (Mutex::new(()), Mutex::new(()))); } + let upper_path = { + let config = CONFIG.lock().unwrap(); + config.upper_path.clone() + }; + let upper_path = Path::new(&upper_path); let mut tasks = unhandled_videos_pages .into_iter() .map(|(video_model, pages_model)| { @@ -150,6 +157,7 @@ pub async fn download_unprocessed_videos( connection, &semaphore, &downloader, + upper_path, upper_mutex, ) }) @@ -169,18 +177,18 @@ pub async fn download_unprocessed_videos( } // 满十个就写入数据库 if models.len() == 10 { - video::Entity::insert_many(std::mem::replace(&mut models, Vec::with_capacity(10))) - .exec(connection) - .await?; + update_videos_model(std::mem::replace(&mut models, Vec::with_capacity(10)), connection).await?; } } if !models.is_empty() { - video::Entity::insert_many(models).exec(connection).await?; + update_videos_model(models, connection).await?; } info!("download videos in favorite: {} done.", favorite_model.f_id); Ok(()) } +/// 暂时这样做,后面提取成上下文 +#[allow(clippy::too_many_arguments)] pub async fn download_video_pages( bili_client: &BiliClient, video_model: video::Model, @@ -188,6 +196,7 @@ pub async fn download_video_pages( connection: &DatabaseConnection, semaphore: &Semaphore, downloader: &Downloader, + upper_path: &Path, upper_mutex: &(Mutex<()>, Mutex<()>), ) -> Result { let permit = semaphore.acquire().await; @@ -199,10 +208,7 @@ pub async fn download_video_pages( let base_path = Path::new(&video_model.path); let upper_id = video_model.upper_id.to_string(); - let base_upper_path = config_dir() - .unwrap() - .join("bili-sync") - .join("upper") + let base_upper_path = upper_path .join(upper_id.chars().next().unwrap().to_string()) .join(upper_id); @@ -292,12 +298,18 @@ pub async fn dispatch_download_page( .into_iter() .map(|page_model| download_page(bili_client, video_model, page_model, &child_semaphore, downloader)) .collect::>(); - // 任务结束会返回 Result - let mut models = Vec::with_capacity(10); + let mut should_error = false; while let Some(res) = tasks.next().await { match res { Ok(model) => { + if let Set(status) = model.download_status { + let status = PageStatus::new(status); + if status.should_run().iter().any(|v| *v) { + // 有一个分页没下载完成,就应该将视频本身标记为未完成 + should_error = true; + } + } models.push(model); } Err(e) => { @@ -308,13 +320,14 @@ pub async fn dispatch_download_page( } } if models.len() == 10 { - page::Entity::insert_many(std::mem::replace(&mut models, Vec::with_capacity(10))) - .exec(connection) - .await?; + update_pages_model(std::mem::replace(&mut models, Vec::with_capacity(10)), connection).await?; } } if !models.is_empty() { - page::Entity::insert_many(models).exec(connection).await?; + update_pages_model(models, connection).await?; + } + if should_error { + bail!("Some pages failed to download"); } Ok(()) } diff --git a/src/core/utils.rs b/src/core/utils.rs index 65e3bd1..c0fda00 100644 --- a/src/core/utils.rs +++ b/src/core/utils.rs @@ -225,6 +225,29 @@ pub async fn unhandled_videos_pages( .all(connection) .await?) } +/// 更新视频 model 的下载状态 +pub async fn update_videos_model(videos: Vec, connection: &DatabaseConnection) -> Result<()> { + video::Entity::insert_many(videos) + .on_conflict( + OnConflict::column(video::Column::Id) + .update_column(video::Column::DownloadStatus) + .to_owned(), + ) + .exec(connection) + .await?; + Ok(()) +} + +/// 更新视频页 model 的下载状态 +pub async fn update_pages_model(pages: Vec, connection: &DatabaseConnection) -> Result<()> { + let query = page::Entity::insert_many(pages).on_conflict( + OnConflict::column(page::Column::Id) + .update_columns([page::Column::DownloadStatus, page::Column::Path]) + .to_owned(), + ); + query.exec(connection).await?; + Ok(()) +} /// serde xml 似乎不太好用,先这么裸着写 /// (真是又臭又长啊