feat: 事件推送由 SSE 切换到 WebSocket (#386)

This commit is contained in:
ᴀᴍᴛᴏᴀᴇʀ
2025-07-11 00:14:20 +08:00
committed by GitHub
parent cc25749445
commit dd23d1db58
21 changed files with 783 additions and 307 deletions

View File

@@ -5,4 +5,4 @@ mod response;
mod routes;
mod wrapper;
pub use routes::{MAX_HISTORY_LOGS, MpscWriter, router};
pub use routes::{LogHelper, MAX_HISTORY_LOGS, router};

View File

@@ -158,7 +158,7 @@ pub struct DashBoardResponse {
}
#[derive(Serialize)]
pub struct SysInfoResponse {
pub struct SysInfo {
pub total_memory: u64,
pub used_memory: u64,
pub process_memory: u64,

View File

@@ -9,7 +9,6 @@ use axum::response::{IntoResponse, Response};
use axum::routing::get;
use axum::{Router, middleware};
use reqwest::{Method, StatusCode, header};
use url::Url;
use super::request::ImageProxyParams;
use crate::api::wrapper::ApiResponse;
@@ -19,11 +18,11 @@ use crate::config::VersionedConfig;
mod config;
mod dashboard;
mod me;
mod sse;
mod video_sources;
mod videos;
mod ws;
pub use sse::{MAX_HISTORY_LOGS, MpscWriter};
pub use ws::{LogHelper, MAX_HISTORY_LOGS};
pub fn router() -> Router {
Router::new().route("/image-proxy", get(image_proxy)).nest(
@@ -33,7 +32,7 @@ pub fn router() -> Router {
.merge(video_sources::router())
.merge(videos::router())
.merge(dashboard::router())
.merge(sse::router())
.merge(ws::router())
.layer(middleware::from_fn(auth)),
)
}
@@ -45,8 +44,9 @@ pub async fn auth(headers: HeaderMap, request: Request, next: Next) -> Result<Re
if headers
.get("Authorization")
.is_some_and(|v| v.to_str().is_ok_and(|s| s == token))
|| Url::parse(&format!("http://example.com/{}", request.uri()))
.is_ok_and(|url| url.query_pairs().any(|(k, v)| k == "token" && v == token))
|| headers
.get("Sec-WebSocket-Protocol")
.is_some_and(|v| v.to_str().is_ok_and(|s| s == token))
{
return Ok(next.run(request).await);
}

View File

@@ -1,98 +0,0 @@
mod mpsc;
use std::convert::Infallible;
use std::time::Duration;
use axum::response::Sse;
use axum::response::sse::{Event, KeepAlive};
use axum::routing::get;
use axum::{Extension, Router};
use futures::{Stream, StreamExt};
pub use mpsc::{MAX_HISTORY_LOGS, MpscWriter};
use sysinfo::{
CpuRefreshKind, DiskRefreshKind, Disks, MemoryRefreshKind, ProcessRefreshKind, RefreshKind, System, get_current_pid,
};
use tokio_stream::wrappers::{BroadcastStream, IntervalStream, WatchStream};
use crate::api::response::SysInfoResponse;
use crate::utils::task_notifier::TASK_STATUS_NOTIFIER;
pub(super) fn router() -> Router {
Router::new()
.route("/sse/logs", get(logs))
.route("/sse/tasks", get(get_tasks))
.route("/sse/sysinfo", get(get_sysinfo))
}
async fn get_tasks() -> Sse<impl futures::Stream<Item = Result<Event, Infallible>>> {
let stream = WatchStream::new(TASK_STATUS_NOTIFIER.subscribe()).filter_map(|status| async move {
match serde_json::to_string(&status) {
Ok(status) => Some(Ok(Event::default().data(status))),
Err(_) => None,
}
});
Sse::new(stream).keep_alive(KeepAlive::default())
}
async fn logs(Extension(log_writer): Extension<MpscWriter>) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
let history = log_writer.log_history.lock();
let rx = log_writer.sender.subscribe();
let history_logs: Vec<String> = history.iter().cloned().collect();
drop(history);
let history_stream = { futures::stream::iter(history_logs.into_iter().map(|msg| Ok(Event::default().data(msg)))) };
let stream = BroadcastStream::new(rx).filter_map(async |msg| match msg {
Ok(log_message) => Some(Ok(Event::default().data(log_message))),
Err(e) => {
error!("Log stream error: {:?}", e);
None
}
});
Sse::new(history_stream.chain(stream)).keep_alive(KeepAlive::default())
}
async fn get_sysinfo() -> Sse<impl futures::Stream<Item = Result<Event, Infallible>>> {
let sys_refresh_kind = sys_refresh_kind();
let disk_refresh_kind = disk_refresh_kind();
let mut system = System::new();
let mut disks = Disks::new();
// safety: this functions always returns Ok on Linux/MacOS/Windows
let self_pid = get_current_pid().unwrap();
let stream = IntervalStream::new(tokio::time::interval(Duration::from_secs(2))).filter_map(move |_| {
system.refresh_specifics(sys_refresh_kind);
disks.refresh_specifics(true, disk_refresh_kind);
let process = match system.process(self_pid) {
Some(p) => p,
None => return futures::future::ready(None),
};
let info = SysInfoResponse {
total_memory: system.total_memory(),
used_memory: system.used_memory(),
process_memory: process.memory(),
used_cpu: system.global_cpu_usage(),
process_cpu: process.cpu_usage() / system.cpus().len() as f32,
total_disk: disks.iter().map(|d| d.total_space()).sum(),
available_disk: disks.iter().map(|d| d.available_space()).sum(),
};
match serde_json::to_string(&info) {
Ok(json) => futures::future::ready(Some(Ok(Event::default().data(json)))),
Err(_) => {
error!("Failed to serialize system info");
futures::future::ready(None)
}
}
});
Sse::new(stream).keep_alive(KeepAlive::default())
}
fn sys_refresh_kind() -> RefreshKind {
RefreshKind::nothing()
.with_cpu(CpuRefreshKind::nothing().with_cpu_usage())
.with_memory(MemoryRefreshKind::nothing().with_ram())
.with_processes(ProcessRefreshKind::nothing().with_cpu().with_memory())
}
fn disk_refresh_kind() -> DiskRefreshKind {
DiskRefreshKind::nothing().with_storage()
}

View File

@@ -7,18 +7,19 @@ use tracing_subscriber::fmt::MakeWriter;
pub const MAX_HISTORY_LOGS: usize = 30;
pub struct MpscWriter {
/// LogHelper 维护了日志发送器和一个日志历史记录的缓冲区
pub struct LogHelper {
pub sender: broadcast::Sender<String>,
pub log_history: Arc<Mutex<VecDeque<String>>>,
}
impl MpscWriter {
impl LogHelper {
pub fn new(sender: broadcast::Sender<String>, log_history: Arc<Mutex<VecDeque<String>>>) -> Self {
MpscWriter { sender, log_history }
LogHelper { sender, log_history }
}
}
impl<'a> MakeWriter<'a> for MpscWriter {
impl<'a> MakeWriter<'a> for LogHelper {
type Writer = Self;
fn make_writer(&'a self) -> Self::Writer {
@@ -26,7 +27,7 @@ impl<'a> MakeWriter<'a> for MpscWriter {
}
}
impl std::io::Write for MpscWriter {
impl std::io::Write for LogHelper {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
let log_message = String::from_utf8_lossy(buf).to_string();
let _ = self.sender.send(log_message.clone());
@@ -43,9 +44,9 @@ impl std::io::Write for MpscWriter {
}
}
impl Clone for MpscWriter {
impl Clone for LogHelper {
fn clone(&self) -> Self {
MpscWriter {
LogHelper {
sender: self.sender.clone(),
log_history: self.log_history.clone(),
}

View File

@@ -0,0 +1,263 @@
mod log_helper;
use std::sync::{Arc, LazyLock};
use std::time::Duration;
use axum::extract::WebSocketUpgrade;
use axum::extract::ws::{Message, WebSocket};
use axum::response::IntoResponse;
use axum::routing::any;
use axum::{Extension, Router};
use dashmap::DashMap;
use futures::stream::{SplitSink, SplitStream};
use futures::{SinkExt, StreamExt, future};
pub use log_helper::{LogHelper, MAX_HISTORY_LOGS};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use sysinfo::{
CpuRefreshKind, DiskRefreshKind, Disks, MemoryRefreshKind, ProcessRefreshKind, RefreshKind, System, get_current_pid,
};
use tokio::pin;
use tokio::task::JoinHandle;
use tokio_stream::wrappers::{BroadcastStream, IntervalStream, WatchStream};
use uuid::Uuid;
use crate::api::response::SysInfo;
use crate::utils::task_notifier::{TASK_STATUS_NOTIFIER, TaskStatus};
static WEBSOCKET_HANDLER: LazyLock<WebSocketHandler> = LazyLock::new(WebSocketHandler::new);
pub(super) fn router() -> Router {
Router::new().route("/ws", any(websocket_handler))
}
async fn websocket_handler(ws: WebSocketUpgrade, Extension(log_writer): Extension<LogHelper>) -> impl IntoResponse {
ws.on_upgrade(|socket| handle_socket(socket, log_writer))
}
// 事件类型枚举
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
enum EventType {
Logs,
Tasks,
SysInfo,
}
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
enum ClientEvent {
Subscribe(EventType),
Unsubscribe(EventType),
}
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
enum ServerEvent {
Logs(String),
Tasks(Arc<TaskStatus>),
SysInfo(Arc<SysInfo>),
}
struct WebSocketHandler {
sysinfo_subscribers: Arc<DashMap<Uuid, tokio::sync::mpsc::Sender<ServerEvent>>>,
sysinfo_handles: RwLock<Option<JoinHandle<()>>>,
}
impl WebSocketHandler {
fn new() -> Self {
Self {
sysinfo_subscribers: Arc::new(DashMap::new()),
sysinfo_handles: RwLock::new(None),
}
}
async fn handle_sender(
&self,
mut sender: SplitSink<WebSocket, Message>,
mut rx: tokio::sync::mpsc::Receiver<ServerEvent>,
) {
while let Some(event) = rx.recv().await {
match serde_json::to_string(&event) {
Ok(text) => {
if let Err(e) = sender.send(Message::Text(text.into())).await {
error!("Failed to send message: {:?}", e);
break;
}
}
Err(e) => {
error!("Failed to serialize event: {:?}", e);
}
}
}
}
async fn handle_receiver(
&self,
mut receiver: SplitStream<WebSocket>,
tx: tokio::sync::mpsc::Sender<ServerEvent>,
uuid: Uuid,
log_writer: LogHelper,
) {
// 日志和任务状态的处理本身就是由 stream 驱动的,可以直接为每个 ws 连接维护独立的任务处理器
// 系统信息是服务端轮询然后推送的,如果单独维护会导致每个连接都独立轮询系统信息,造成不必要的浪费
// 因此采用了全局的订阅者管理,所有连接共享同一个系统信息轮询任务
let (mut log_handle, mut task_handle) = (None, None);
while let Some(Ok(msg)) = receiver.next().await {
if let Message::Text(text) = msg {
match serde_json::from_str::<ClientEvent>(&text) {
Ok(ClientEvent::Subscribe(event_type)) => match event_type {
EventType::Logs => {
if log_handle.as_ref().is_none_or(|h: &JoinHandle<()>| h.is_finished()) {
let log_writer_clone = log_writer.clone();
let tx_clone = tx.clone();
let history = log_writer_clone.log_history.lock();
let history_logs: Vec<String> = history.iter().cloned().collect();
drop(history);
log_handle = Some(tokio::spawn(async move {
let rx = log_writer_clone.sender.subscribe();
let log_stream = futures::stream::iter(history_logs.into_iter())
.chain(BroadcastStream::new(rx).filter_map(async |msg| msg.ok()))
.map(|msg| ServerEvent::Logs(msg));
pin!(log_stream);
while let Some(event) = log_stream.next().await {
if let Err(e) = tx_clone.send(event).await {
error!("Failed to send log event: {:?}", e);
break;
}
}
}));
}
}
EventType::Tasks => {
if task_handle.as_ref().is_none_or(|h: &JoinHandle<()>| h.is_finished()) {
let tx_clone = tx.clone();
task_handle = Some(tokio::spawn(async move {
let mut stream = WatchStream::new(TASK_STATUS_NOTIFIER.subscribe())
.map(|status| ServerEvent::Tasks(status));
while let Some(event) = stream.next().await {
if let Err(e) = tx_clone.send(event).await {
error!("Failed to send task status: {:?}", e);
break;
}
}
}));
}
}
EventType::SysInfo => self.add_sysinfo_subscriber(uuid, tx.clone()).await,
},
Ok(ClientEvent::Unsubscribe(event_type)) => match event_type {
EventType::Logs => {
if let Some(handle) = log_handle.take() {
handle.abort();
}
}
EventType::Tasks => {
if let Some(handle) = task_handle.take() {
handle.abort();
}
}
EventType::SysInfo => {
self.remove_sysinfo_subscriber(uuid).await;
}
},
Err(e) => {
error!("Failed to parse client message: {:?}", e);
}
}
}
}
if let Some(handle) = log_handle {
handle.abort();
}
if let Some(handle) = task_handle {
handle.abort();
}
self.remove_sysinfo_subscriber(uuid).await;
}
// 添加订阅者
async fn add_sysinfo_subscriber(&self, uuid: Uuid, sender: tokio::sync::mpsc::Sender<ServerEvent>) {
self.sysinfo_subscribers.insert(uuid, sender);
if self.sysinfo_subscribers.len() > 0
&& self
.sysinfo_handles
.read()
.as_ref()
.is_none_or(|h: &JoinHandle<()>| h.is_finished())
{
let sysinfo_subscribers = self.sysinfo_subscribers.clone();
let mut write_guard = self.sysinfo_handles.write();
if write_guard.as_ref().is_some_and(|h: &JoinHandle<()>| !h.is_finished()) {
return;
}
*write_guard = Some(tokio::spawn(async move {
let mut system = System::new();
let mut disks = Disks::new();
let sys_refresh_kind = sys_refresh_kind();
let disk_refresh_kind = disk_refresh_kind();
// 对于 linux/mac/windows 平台,该方法永远返回 Some(pid)expect 基本是安全的
let self_pid = get_current_pid().expect("Unsupported platform");
let mut stream =
IntervalStream::new(tokio::time::interval(Duration::from_secs(2))).filter_map(move |_| {
system.refresh_specifics(sys_refresh_kind);
disks.refresh_specifics(true, disk_refresh_kind);
let process = match system.process(self_pid) {
Some(p) => p,
None => return futures::future::ready(None),
};
futures::future::ready(Some(SysInfo {
total_memory: system.total_memory(),
used_memory: system.used_memory(),
process_memory: process.memory(),
used_cpu: system.global_cpu_usage(),
process_cpu: process.cpu_usage() / system.cpus().len() as f32,
total_disk: disks.iter().map(|d| d.total_space()).sum(),
available_disk: disks.iter().map(|d| d.available_space()).sum(),
}))
});
while let Some(sys_info) = stream.next().await {
let sys_info = Arc::new(sys_info);
future::join_all(sysinfo_subscribers.iter().map(async |subscriber| {
if let Err(e) = subscriber.send(ServerEvent::SysInfo(sys_info.clone())).await {
error!(
"Failed to send sysinfo event to subscriber {}: {:?}",
subscriber.key(),
e
);
}
}))
.await;
}
}));
}
}
async fn remove_sysinfo_subscriber(&self, uuid: Uuid) {
self.sysinfo_subscribers.remove(&uuid);
if self.sysinfo_subscribers.is_empty() {
if let Some(handle) = self.sysinfo_handles.write().take() {
handle.abort();
}
}
}
}
async fn handle_socket(socket: WebSocket, log_writer: LogHelper) {
let (ws_sender, ws_receiver) = socket.split();
let uuid = Uuid::new_v4();
let (tx, rx) = tokio::sync::mpsc::channel(100);
tokio::spawn(WEBSOCKET_HANDLER.handle_sender(ws_sender, rx));
tokio::spawn(WEBSOCKET_HANDLER.handle_receiver(ws_receiver, tx, uuid, log_writer));
}
fn sys_refresh_kind() -> RefreshKind {
RefreshKind::nothing()
.with_cpu(CpuRefreshKind::nothing().with_cpu_usage())
.with_memory(MemoryRefreshKind::nothing().with_ram())
.with_processes(ProcessRefreshKind::nothing().with_cpu().with_memory())
}
fn disk_refresh_kind() -> DiskRefreshKind {
DiskRefreshKind::nothing().with_storage()
}

View File

@@ -24,7 +24,7 @@ use task::{http_server, video_downloader};
use tokio_util::sync::CancellationToken;
use tokio_util::task::TaskTracker;
use crate::api::{MAX_HISTORY_LOGS, MpscWriter};
use crate::api::{LogHelper, MAX_HISTORY_LOGS};
use crate::config::{ARGS, VersionedConfig};
use crate::database::setup_database;
use crate::utils::init_logger;
@@ -77,10 +77,10 @@ fn spawn_task(
}
/// 初始化日志系统、打印欢迎信息,初始化数据库连接和全局配置
async fn init() -> (Arc<DatabaseConnection>, MpscWriter) {
async fn init() -> (Arc<DatabaseConnection>, LogHelper) {
let (tx, _rx) = tokio::sync::broadcast::channel(30);
let log_history = Arc::new(Mutex::new(VecDeque::with_capacity(MAX_HISTORY_LOGS + 1)));
let log_writer = MpscWriter::new(tx, log_history.clone());
let log_writer = LogHelper::new(tx, log_history.clone());
init_logger(&ARGS.log_level, Some(log_writer.clone()));
info!("欢迎使用 Bili-Sync当前程序版本{}", config::version());

View File

@@ -11,7 +11,7 @@ use reqwest::StatusCode;
use rust_embed_for_web::{EmbedableFile, RustEmbed};
use sea_orm::DatabaseConnection;
use crate::api::{MpscWriter, router};
use crate::api::{LogHelper, router};
use crate::bilibili::BiliClient;
use crate::config::VersionedConfig;
@@ -23,7 +23,7 @@ struct Asset;
pub async fn http_server(
database_connection: Arc<DatabaseConnection>,
bili_client: Arc<BiliClient>,
log_writer: MpscWriter,
log_writer: LogHelper,
) -> Result<()> {
let app = router()
.fallback_service(get(frontend_files))

View File

@@ -11,9 +11,9 @@ use tracing_subscriber::fmt;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use crate::api::MpscWriter;
use crate::api::LogHelper;
pub fn init_logger(log_level: &str, log_writer: Option<MpscWriter>) {
pub fn init_logger(log_level: &str, log_writer: Option<LogHelper>) {
let log = tracing_subscriber::fmt::Subscriber::builder()
.compact()
.with_env_filter(tracing_subscriber::EnvFilter::builder().parse_lossy(log_level))