diff --git a/.env.example b/.env.example index 2c26b06..b505260 100644 --- a/.env.example +++ b/.env.example @@ -14,6 +14,9 @@ DATABASE_MAX_CONNECTIONS=10 # Redis REDIS_URL=redis://localhost:6379 +# Worker 并发(每个批量任务内同时处理的文件数) +WORKER_CONCURRENCY=4 + # JWT(网站/管理后台) JWT_SECRET=your-super-secret-key-change-in-production JWT_EXPIRY_HOURS=168 diff --git a/docs/deployment.md b/docs/deployment.md index 176abd0..3fb0089 100644 --- a/docs/deployment.md +++ b/docs/deployment.md @@ -113,6 +113,9 @@ DATABASE_URL=postgres://imageforge:devpassword@localhost:5432/imageforge # Redis REDIS_URL=redis://localhost:6379 +# Worker 并发(每个批量任务内同时处理的文件数) +WORKER_CONCURRENCY=4 + # JWT(网站/管理后台) JWT_SECRET=your-super-secret-key-change-in-production JWT_EXPIRY_HOURS=168 diff --git a/src/config.rs b/src/config.rs index d0bda07..59d73ea 100644 --- a/src/config.rs +++ b/src/config.rs @@ -12,6 +12,8 @@ pub struct Config { pub redis_url: String, + pub worker_concurrency: u32, + pub jwt_secret: String, pub jwt_expiry_hours: i64, @@ -60,6 +62,12 @@ impl Config { let redis_url = env_string("REDIS_URL") .ok_or_else(|| AppError::new(ErrorCode::InvalidRequest, "缺少环境变量 REDIS_URL"))?; + let worker_concurrency = env_u32("WORKER_CONCURRENCY").unwrap_or_else(|| { + std::thread::available_parallelism() + .map(|v| v.get() as u32) + .unwrap_or(4) + }); + let jwt_secret = env_string("JWT_SECRET") .ok_or_else(|| AppError::new(ErrorCode::InvalidRequest, "缺少环境变量 JWT_SECRET"))?; let jwt_expiry_hours = env_i64("JWT_EXPIRY_HOURS").unwrap_or(168); @@ -103,6 +111,7 @@ impl Config { database_url, database_max_connections, redis_url, + worker_concurrency, jwt_secret, jwt_expiry_hours, api_key_pepper, diff --git a/src/worker/mod.rs b/src/worker/mod.rs index 794db02..3aaf3f6 100644 --- a/src/worker/mod.rs +++ b/src/worker/mod.rs @@ -8,7 +8,10 @@ use redis::streams::StreamReadOptions; use redis::AsyncCommands; use sqlx::FromRow; use std::net::IpAddr; +use std::sync::Arc; use std::time::Instant; +use tokio::sync::Semaphore; +use tokio::task::JoinSet; use uuid::Uuid; const STREAM_KEY: &str = "stream:compress_jobs"; @@ -149,6 +152,16 @@ struct TaskFileProcRow { status: String, } +#[derive(Clone)] +struct TaskContext { + api_key_id: Option, + source: String, + preserve_metadata: bool, + session_id: Option, + anon_ip: Option, + is_anonymous: bool, +} + async fn process_task(state: &AppState, task_id: Uuid) -> Result<(), AppError> { let mut task: TaskProcRow = sqlx::query_as( r#" @@ -242,138 +255,54 @@ async fn process_task(state: &AppState, task_id: Uuid) -> Result<(), AppError> { .as_deref() .and_then(|s| s.parse::().ok()); - for file in &mut files { - // Stop early if cancelled. - let status: Option = sqlx::query_scalar("SELECT status::text FROM tasks WHERE id = $1") - .bind(task_id) - .fetch_optional(&state.db) - .await - .unwrap_or(None); - if matches!(status.as_deref(), Some("cancelled")) { - break; - } + let ctx = TaskContext { + api_key_id: task.api_key_id, + source: task.source.clone(), + preserve_metadata: task.preserve_metadata, + session_id: task.session_id.clone(), + anon_ip, + is_anonymous: task.user_id.is_none(), + }; + let concurrency = state.config.worker_concurrency.max(1) as usize; + let semaphore = Arc::new(Semaphore::new(concurrency)); + let mut join_set = JoinSet::new(); + + for file in files.drain(..) { if file.status != "pending" { continue; } - let updated = sqlx::query("UPDATE task_files SET status = 'processing' WHERE id = $1 AND status = 'pending'") - .bind(file.id) - .execute(&state.db) + let permit = semaphore.clone().acquire_owned().await.unwrap(); + let state = state.clone(); + let ctx = ctx.clone(); + let billing_ctx = billing_ctx.clone(); + let file_id = file.id; + + join_set.spawn(async move { + let _permit = permit; + if let Err(err) = process_task_file( + state, + task_id, + file, + level, + compression_rate, + max_width, + max_height, + ctx, + billing_ctx, + ) .await - .unwrap_or_else(|_| sqlx::postgres::PgQueryResult::default()); - if updated.rows_affected() == 0 { - continue; - } - - let Some(input_path) = file.storage_path.clone() else { - mark_file_failed(state, task_id, file.id, "原文件不存在").await?; - continue; - }; - - let input_bytes = match tokio::fs::read(&input_path).await { - Ok(v) => v, - Err(_) => { - mark_file_failed(state, task_id, file.id, "读取原文件失败").await?; - continue; + { + tracing::error!(task_id = %task_id, file_id = %file_id, error = %err, "file processing failed"); } - }; + }); + } - let format_in = parse_image_fmt(&file.original_format)?; - let format_out = parse_image_fmt(&file.output_format)?; - - let compressed = match compress::compress_image_bytes( - state, - &input_bytes, - format_in, - format_out, - level, - compression_rate, - max_width, - max_height, - task.preserve_metadata, - ) - .await - { - Ok(v) => v, - Err(err) => { - mark_file_failed(state, task_id, file.id, &err.message).await?; - let _ = tokio::fs::remove_file(&input_path).await; - continue; - } - }; - - let original_size = input_bytes.len() as u64; - let compressed_size = compressed.len() as u64; - let saved_percent = if original_size == 0 { - 0.0 - } else { - (original_size.saturating_sub(compressed_size) as f64) * 100.0 / (original_size as f64) - }; - let skip_charge = compression_rate == Some(100); - let charge_units = !skip_charge && compressed_size < original_size; - - // Anonymous quota enforcement requires session_id + client_ip. - if task.user_id.is_none() && charge_units { - let Some(session_id) = task.session_id.as_deref() else { - mark_file_failed(state, task_id, file.id, "匿名任务缺少 session_id").await?; - let _ = tokio::fs::remove_file(&input_path).await; - continue; - }; - let Some(ip) = anon_ip else { - mark_file_failed(state, task_id, file.id, "匿名任务缺少 client_ip").await?; - let _ = tokio::fs::remove_file(&input_path).await; - continue; - }; - if let Err(err) = quota::consume_anonymous_units(state, session_id, ip, 1).await { - mark_file_failed(state, task_id, file.id, &err.message).await?; - let _ = tokio::fs::remove_file(&input_path).await; - continue; - } + while let Some(result) = join_set.join_next().await { + if let Err(err) = result { + tracing::error!(task_id = %task_id, error = %err, "file worker panicked"); } - - let output_path = format!( - "{}/{}.{}", - state.config.storage_path, - file.id, - format_out.extension() - ); - if let Err(err) = tokio::fs::write(&output_path, &compressed).await { - mark_file_failed(state, task_id, file.id, "写入压缩文件失败").await?; - let _ = tokio::fs::remove_file(&input_path).await; - return Err(AppError::new(ErrorCode::StorageUnavailable, "写入压缩文件失败").with_source(err)); - } - - if let Err(err) = finalize_file( - state, - &billing_ctx, - task.api_key_id, - &task.source, - task_id, - file.id, - &output_path, - original_size as i64, - compressed_size as i64, - saved_percent, - format_in, - format_out, - charge_units, - ) - .await - { - // If quota exceeded for paid users, don't leave output behind. - if err.code == ErrorCode::QuotaExceeded { - let _ = tokio::fs::remove_file(&output_path).await; - mark_file_failed(state, task_id, file.id, &err.message).await?; - } else { - mark_file_failed(state, task_id, file.id, &err.message).await?; - } - let _ = tokio::fs::remove_file(&input_path).await; - continue; - } - - // Success: remove original. - let _ = tokio::fs::remove_file(&input_path).await; } finalize_task_status(state, task_id).await?; @@ -394,6 +323,165 @@ fn parse_image_fmt(value: &str) -> Result { } } +async fn is_task_cancelled(state: &AppState, task_id: Uuid) -> Result { + let status: Option = sqlx::query_scalar("SELECT status::text FROM tasks WHERE id = $1") + .bind(task_id) + .fetch_optional(&state.db) + .await + .map_err(|err| AppError::new(ErrorCode::Internal, "查询任务状态失败").with_source(err))?; + Ok(matches!(status.as_deref(), Some("cancelled"))) +} + +async fn process_task_file( + state: AppState, + task_id: Uuid, + file: TaskFileProcRow, + level: compress::CompressionLevel, + compression_rate: Option, + max_width: Option, + max_height: Option, + ctx: TaskContext, + billing_ctx: Option, +) -> Result<(), AppError> { + if is_task_cancelled(&state, task_id).await? { + return Ok(()); + } + + let updated = sqlx::query("UPDATE task_files SET status = 'processing' WHERE id = $1 AND status = 'pending'") + .bind(file.id) + .execute(&state.db) + .await + .unwrap_or_else(|_| sqlx::postgres::PgQueryResult::default()); + if updated.rows_affected() == 0 { + return Ok(()); + } + + if is_task_cancelled(&state, task_id).await? { + mark_file_failed(&state, task_id, file.id, "已取消").await?; + return Ok(()); + } + + let Some(input_path) = file.storage_path.clone() else { + mark_file_failed(&state, task_id, file.id, "原文件不存在").await?; + return Ok(()); + }; + + let input_bytes = match tokio::fs::read(&input_path).await { + Ok(v) => v, + Err(_) => { + mark_file_failed(&state, task_id, file.id, "读取原文件失败").await?; + return Ok(()); + } + }; + + let format_in = parse_image_fmt(&file.original_format)?; + let format_out = parse_image_fmt(&file.output_format)?; + + let compressed = match compress::compress_image_bytes( + &state, + &input_bytes, + format_in, + format_out, + level, + compression_rate, + max_width, + max_height, + ctx.preserve_metadata, + ) + .await + { + Ok(v) => v, + Err(err) => { + mark_file_failed(&state, task_id, file.id, &err.message).await?; + let _ = tokio::fs::remove_file(&input_path).await; + return Ok(()); + } + }; + + if is_task_cancelled(&state, task_id).await? { + mark_file_failed(&state, task_id, file.id, "已取消").await?; + let _ = tokio::fs::remove_file(&input_path).await; + return Ok(()); + } + + let original_size = input_bytes.len() as u64; + let compressed_size = compressed.len() as u64; + let saved_percent = if original_size == 0 { + 0.0 + } else { + (original_size.saturating_sub(compressed_size) as f64) * 100.0 / (original_size as f64) + }; + let skip_charge = compression_rate == Some(100); + let charge_units = !skip_charge && compressed_size < original_size; + + if ctx.is_anonymous && charge_units { + let Some(session_id) = ctx.session_id.as_deref() else { + mark_file_failed(&state, task_id, file.id, "匿名任务缺少 session_id").await?; + let _ = tokio::fs::remove_file(&input_path).await; + return Ok(()); + }; + let Some(ip) = ctx.anon_ip else { + mark_file_failed(&state, task_id, file.id, "匿名任务缺少 client_ip").await?; + let _ = tokio::fs::remove_file(&input_path).await; + return Ok(()); + }; + if let Err(err) = quota::consume_anonymous_units(&state, session_id, ip, 1).await { + mark_file_failed(&state, task_id, file.id, &err.message).await?; + let _ = tokio::fs::remove_file(&input_path).await; + return Ok(()); + } + } + + let output_path = format!( + "{}/{}.{}", + state.config.storage_path, + file.id, + format_out.extension() + ); + if let Err(err) = tokio::fs::write(&output_path, &compressed).await { + mark_file_failed(&state, task_id, file.id, "写入压缩文件失败").await?; + let _ = tokio::fs::remove_file(&input_path).await; + return Err(AppError::new(ErrorCode::StorageUnavailable, "写入压缩文件失败").with_source(err)); + } + + if is_task_cancelled(&state, task_id).await? { + let _ = tokio::fs::remove_file(&output_path).await; + mark_file_failed(&state, task_id, file.id, "已取消").await?; + let _ = tokio::fs::remove_file(&input_path).await; + return Ok(()); + } + + if let Err(err) = finalize_file( + &state, + &billing_ctx, + ctx.api_key_id, + &ctx.source, + task_id, + file.id, + &output_path, + original_size as i64, + compressed_size as i64, + saved_percent, + format_in, + format_out, + charge_units, + ) + .await + { + if err.code == ErrorCode::QuotaExceeded { + let _ = tokio::fs::remove_file(&output_path).await; + mark_file_failed(&state, task_id, file.id, &err.message).await?; + } else { + mark_file_failed(&state, task_id, file.id, &err.message).await?; + } + let _ = tokio::fs::remove_file(&input_path).await; + return Ok(()); + } + + let _ = tokio::fs::remove_file(&input_path).await; + Ok(()) +} + async fn finalize_file( state: &AppState, billing_ctx: &Option,