Parallelize worker batch processing

2025-12-20 18:37:27 +08:00
parent 24a4f81c41
commit df9c40e456
4 changed files with 226 additions and 123 deletions

View File

@@ -14,6 +14,9 @@ DATABASE_MAX_CONNECTIONS=10
# Redis
REDIS_URL=redis://localhost:6379
# Worker concurrency (number of files processed concurrently within one batch task)
WORKER_CONCURRENCY=4
# JWT (website / admin console)
JWT_SECRET=your-super-secret-key-change-in-production
JWT_EXPIRY_HOURS=168

View File

@@ -113,6 +113,9 @@ DATABASE_URL=postgres://imageforge:devpassword@localhost:5432/imageforge
# Redis
REDIS_URL=redis://localhost:6379
# Worker concurrency (number of files processed concurrently within one batch task)
WORKER_CONCURRENCY=4
# JWT (website / admin console)
JWT_SECRET=your-super-secret-key-change-in-production
JWT_EXPIRY_HOURS=168

View File

@@ -12,6 +12,8 @@ pub struct Config {
pub redis_url: String,
pub worker_concurrency: u32,
pub jwt_secret: String,
pub jwt_expiry_hours: i64,
@@ -60,6 +62,12 @@ impl Config {
let redis_url = env_string("REDIS_URL")
.ok_or_else(|| AppError::new(ErrorCode::InvalidRequest, "缺少环境变量 REDIS_URL"))?;
let worker_concurrency = env_u32("WORKER_CONCURRENCY").unwrap_or_else(|| {
std::thread::available_parallelism()
.map(|v| v.get() as u32)
.unwrap_or(4)
});
let jwt_secret = env_string("JWT_SECRET")
.ok_or_else(|| AppError::new(ErrorCode::InvalidRequest, "缺少环境变量 JWT_SECRET"))?;
let jwt_expiry_hours = env_i64("JWT_EXPIRY_HOURS").unwrap_or(168);
@@ -103,6 +111,7 @@ impl Config {
database_url,
database_max_connections,
redis_url,
worker_concurrency,
jwt_secret,
jwt_expiry_hours,
api_key_pepper,

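The loader above calls small helpers (env_string, env_u32, env_i64) whose definitions are not part of this diff. A minimal sketch of what they might look like, assuming they are thin wrappers over std::env::var plus parsing; the names are taken from the calls above, the bodies are hypothetical:

use std::env;

// Hypothetical helpers mirroring the calls in Config above; the real
// definitions live elsewhere in the crate, this is only a sketch.
fn env_string(key: &str) -> Option<String> {
    env::var(key).ok().filter(|v| !v.trim().is_empty())
}

fn env_u32(key: &str) -> Option<u32> {
    env::var(key).ok().and_then(|v| v.trim().parse::<u32>().ok())
}

fn env_i64(key: &str) -> Option<i64> {
    env::var(key).ok().and_then(|v| v.trim().parse::<i64>().ok())
}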
View File

@@ -8,7 +8,10 @@ use redis::streams::StreamReadOptions;
use redis::AsyncCommands;
use sqlx::FromRow;
use std::net::IpAddr;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::Semaphore;
use tokio::task::JoinSet;
use uuid::Uuid;
const STREAM_KEY: &str = "stream:compress_jobs";
@@ -149,6 +152,16 @@ struct TaskFileProcRow {
status: String,
}
#[derive(Clone)]
struct TaskContext {
api_key_id: Option<Uuid>,
source: String,
preserve_metadata: bool,
session_id: Option<String>,
anon_ip: Option<IpAddr>,
is_anonymous: bool,
}
async fn process_task(state: &AppState, task_id: Uuid) -> Result<(), AppError> {
let mut task: TaskProcRow = sqlx::query_as(
r#"
@@ -242,138 +255,54 @@ async fn process_task(state: &AppState, task_id: Uuid) -> Result<(), AppError> {
.as_deref()
.and_then(|s| s.parse::<IpAddr>().ok());
for file in &mut files {
// Stop early if cancelled.
let status: Option<String> = sqlx::query_scalar("SELECT status::text FROM tasks WHERE id = $1")
.bind(task_id)
.fetch_optional(&state.db)
.await
.unwrap_or(None);
if matches!(status.as_deref(), Some("cancelled")) {
break;
}
let ctx = TaskContext {
api_key_id: task.api_key_id,
source: task.source.clone(),
preserve_metadata: task.preserve_metadata,
session_id: task.session_id.clone(),
anon_ip,
is_anonymous: task.user_id.is_none(),
};
let concurrency = state.config.worker_concurrency.max(1) as usize;
let semaphore = Arc::new(Semaphore::new(concurrency));
let mut join_set = JoinSet::new();
for file in files.drain(..) {
if file.status != "pending" {
continue;
}
let updated = sqlx::query("UPDATE task_files SET status = 'processing' WHERE id = $1 AND status = 'pending'")
.bind(file.id)
.execute(&state.db)
.await
.unwrap_or_else(|_| sqlx::postgres::PgQueryResult::default());
if updated.rows_affected() == 0 {
continue;
}
let permit = semaphore.clone().acquire_owned().await.unwrap();
let state = state.clone();
let ctx = ctx.clone();
let billing_ctx = billing_ctx.clone();
let file_id = file.id;
let Some(input_path) = file.storage_path.clone() else {
mark_file_failed(state, task_id, file.id, "原文件不存在").await?;
continue;
};
let input_bytes = match tokio::fs::read(&input_path).await {
Ok(v) => v,
Err(_) => {
mark_file_failed(state, task_id, file.id, "读取原文件失败").await?;
continue;
}
};
let format_in = parse_image_fmt(&file.original_format)?;
let format_out = parse_image_fmt(&file.output_format)?;
let compressed = match compress::compress_image_bytes(
join_set.spawn(async move {
let _permit = permit;
if let Err(err) = process_task_file(
state,
&input_bytes,
format_in,
format_out,
task_id,
file,
level,
compression_rate,
max_width,
max_height,
task.preserve_metadata,
ctx,
billing_ctx,
)
.await
{
Ok(v) => v,
Err(err) => {
mark_file_failed(state, task_id, file.id, &err.message).await?;
let _ = tokio::fs::remove_file(&input_path).await;
continue;
}
};
let original_size = input_bytes.len() as u64;
let compressed_size = compressed.len() as u64;
let saved_percent = if original_size == 0 {
0.0
} else {
(original_size.saturating_sub(compressed_size) as f64) * 100.0 / (original_size as f64)
};
let skip_charge = compression_rate == Some(100);
let charge_units = !skip_charge && compressed_size < original_size;
// Anonymous quota enforcement requires session_id + client_ip.
if task.user_id.is_none() && charge_units {
let Some(session_id) = task.session_id.as_deref() else {
mark_file_failed(state, task_id, file.id, "匿名任务缺少 session_id").await?;
let _ = tokio::fs::remove_file(&input_path).await;
continue;
};
let Some(ip) = anon_ip else {
mark_file_failed(state, task_id, file.id, "匿名任务缺少 client_ip").await?;
let _ = tokio::fs::remove_file(&input_path).await;
continue;
};
if let Err(err) = quota::consume_anonymous_units(state, session_id, ip, 1).await {
mark_file_failed(state, task_id, file.id, &err.message).await?;
let _ = tokio::fs::remove_file(&input_path).await;
continue;
tracing::error!(task_id = %task_id, file_id = %file_id, error = %err, "file processing failed");
}
});
}
let output_path = format!(
"{}/{}.{}",
state.config.storage_path,
file.id,
format_out.extension()
);
if let Err(err) = tokio::fs::write(&output_path, &compressed).await {
mark_file_failed(state, task_id, file.id, "写入压缩文件失败").await?;
let _ = tokio::fs::remove_file(&input_path).await;
return Err(AppError::new(ErrorCode::StorageUnavailable, "写入压缩文件失败").with_source(err));
while let Some(result) = join_set.join_next().await {
if let Err(err) = result {
tracing::error!(task_id = %task_id, error = %err, "file worker panicked");
}
if let Err(err) = finalize_file(
state,
&billing_ctx,
task.api_key_id,
&task.source,
task_id,
file.id,
&output_path,
original_size as i64,
compressed_size as i64,
saved_percent,
format_in,
format_out,
charge_units,
)
.await
{
// If quota exceeded for paid users, don't leave output behind.
if err.code == ErrorCode::QuotaExceeded {
let _ = tokio::fs::remove_file(&output_path).await;
mark_file_failed(state, task_id, file.id, &err.message).await?;
} else {
mark_file_failed(state, task_id, file.id, &err.message).await?;
}
let _ = tokio::fs::remove_file(&input_path).await;
continue;
}
// Success: remove original.
let _ = tokio::fs::remove_file(&input_path).await;
}
finalize_task_status(state, task_id).await?;
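The reworked loop replaces the old sequential per-file processing with an Arc<Semaphore> that caps in-flight files at worker_concurrency and a JoinSet that drains the spawned tasks. In isolation the pattern looks roughly like this (a standalone sketch with a dummy workload, not the worker's actual code):

use std::sync::Arc;
use tokio::sync::Semaphore;
use tokio::task::JoinSet;

#[tokio::main]
async fn main() {
    let concurrency = 4usize;
    let semaphore = Arc::new(Semaphore::new(concurrency));
    let mut join_set = JoinSet::new();

    for item in 0..16u32 {
        // Acquire before spawning so at most `concurrency` tasks run at once.
        let permit = semaphore.clone().acquire_owned().await.unwrap();
        join_set.spawn(async move {
            let _permit = permit; // held until this task finishes
            // Stand-in for the per-file work (process_task_file in the worker).
            tokio::time::sleep(std::time::Duration::from_millis(50)).await;
            println!("processed item {item}");
        });
    }

    // Drain results; a JoinError here means the spawned task panicked.
    while let Some(result) = join_set.join_next().await {
        if let Err(err) = result {
            eprintln!("worker task panicked: {err}");
        }
    }
}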
@@ -394,6 +323,165 @@ fn parse_image_fmt(value: &str) -> Result<compress::ImageFmt, AppError> {
}
}
async fn is_task_cancelled(state: &AppState, task_id: Uuid) -> Result<bool, AppError> {
let status: Option<String> = sqlx::query_scalar("SELECT status::text FROM tasks WHERE id = $1")
.bind(task_id)
.fetch_optional(&state.db)
.await
.map_err(|err| AppError::new(ErrorCode::Internal, "查询任务状态失败").with_source(err))?;
Ok(matches!(status.as_deref(), Some("cancelled")))
}
async fn process_task_file(
state: AppState,
task_id: Uuid,
file: TaskFileProcRow,
level: compress::CompressionLevel,
compression_rate: Option<u8>,
max_width: Option<u32>,
max_height: Option<u32>,
ctx: TaskContext,
billing_ctx: Option<billing::BillingContext>,
) -> Result<(), AppError> {
if is_task_cancelled(&state, task_id).await? {
return Ok(());
}
let updated = sqlx::query("UPDATE task_files SET status = 'processing' WHERE id = $1 AND status = 'pending'")
.bind(file.id)
.execute(&state.db)
.await
.unwrap_or_else(|_| sqlx::postgres::PgQueryResult::default());
if updated.rows_affected() == 0 {
return Ok(());
}
if is_task_cancelled(&state, task_id).await? {
mark_file_failed(&state, task_id, file.id, "已取消").await?;
return Ok(());
}
let Some(input_path) = file.storage_path.clone() else {
mark_file_failed(&state, task_id, file.id, "原文件不存在").await?;
return Ok(());
};
let input_bytes = match tokio::fs::read(&input_path).await {
Ok(v) => v,
Err(_) => {
mark_file_failed(&state, task_id, file.id, "读取原文件失败").await?;
return Ok(());
}
};
let format_in = parse_image_fmt(&file.original_format)?;
let format_out = parse_image_fmt(&file.output_format)?;
let compressed = match compress::compress_image_bytes(
&state,
&input_bytes,
format_in,
format_out,
level,
compression_rate,
max_width,
max_height,
ctx.preserve_metadata,
)
.await
{
Ok(v) => v,
Err(err) => {
mark_file_failed(&state, task_id, file.id, &err.message).await?;
let _ = tokio::fs::remove_file(&input_path).await;
return Ok(());
}
};
if is_task_cancelled(&state, task_id).await? {
mark_file_failed(&state, task_id, file.id, "已取消").await?;
let _ = tokio::fs::remove_file(&input_path).await;
return Ok(());
}
let original_size = input_bytes.len() as u64;
let compressed_size = compressed.len() as u64;
let saved_percent = if original_size == 0 {
0.0
} else {
(original_size.saturating_sub(compressed_size) as f64) * 100.0 / (original_size as f64)
};
let skip_charge = compression_rate == Some(100);
let charge_units = !skip_charge && compressed_size < original_size;
if ctx.is_anonymous && charge_units {
let Some(session_id) = ctx.session_id.as_deref() else {
mark_file_failed(&state, task_id, file.id, "匿名任务缺少 session_id").await?;
let _ = tokio::fs::remove_file(&input_path).await;
return Ok(());
};
let Some(ip) = ctx.anon_ip else {
mark_file_failed(&state, task_id, file.id, "匿名任务缺少 client_ip").await?;
let _ = tokio::fs::remove_file(&input_path).await;
return Ok(());
};
if let Err(err) = quota::consume_anonymous_units(&state, session_id, ip, 1).await {
mark_file_failed(&state, task_id, file.id, &err.message).await?;
let _ = tokio::fs::remove_file(&input_path).await;
return Ok(());
}
}
let output_path = format!(
"{}/{}.{}",
state.config.storage_path,
file.id,
format_out.extension()
);
if let Err(err) = tokio::fs::write(&output_path, &compressed).await {
mark_file_failed(&state, task_id, file.id, "写入压缩文件失败").await?;
let _ = tokio::fs::remove_file(&input_path).await;
return Err(AppError::new(ErrorCode::StorageUnavailable, "写入压缩文件失败").with_source(err));
}
if is_task_cancelled(&state, task_id).await? {
let _ = tokio::fs::remove_file(&output_path).await;
mark_file_failed(&state, task_id, file.id, "已取消").await?;
let _ = tokio::fs::remove_file(&input_path).await;
return Ok(());
}
if let Err(err) = finalize_file(
&state,
&billing_ctx,
ctx.api_key_id,
&ctx.source,
task_id,
file.id,
&output_path,
original_size as i64,
compressed_size as i64,
saved_percent,
format_in,
format_out,
charge_units,
)
.await
{
if err.code == ErrorCode::QuotaExceeded {
let _ = tokio::fs::remove_file(&output_path).await;
mark_file_failed(&state, task_id, file.id, &err.message).await?;
} else {
mark_file_failed(&state, task_id, file.id, &err.message).await?;
}
let _ = tokio::fs::remove_file(&input_path).await;
return Ok(());
}
let _ = tokio::fs::remove_file(&input_path).await;
Ok(())
}
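As a quick sanity check of the savings and charging rules above, with hypothetical sizes: a 1,000-byte original compressed to 600 bytes saves (1000 - 600) * 100 / 1000 = 40.0%, and it is charged because the output shrank and compression_rate is not 100; output that grows, or a run with compression_rate = Some(100), is never charged. A self-contained sketch of that arithmetic:

// Hypothetical standalone check of the saved_percent / charge_units logic above.
fn savings(original_size: u64, compressed_size: u64, compression_rate: Option<u8>) -> (f64, bool) {
    let saved_percent = if original_size == 0 {
        0.0
    } else {
        (original_size.saturating_sub(compressed_size) as f64) * 100.0 / (original_size as f64)
    };
    let skip_charge = compression_rate == Some(100);
    let charge_units = !skip_charge && compressed_size < original_size;
    (saved_percent, charge_units)
}

fn main() {
    assert_eq!(savings(1_000, 600, None), (40.0, true));        // smaller output: charged
    assert_eq!(savings(1_000, 1_200, None), (0.0, false));      // output grew: free
    assert_eq!(savings(1_000, 600, Some(100)), (40.0, false));  // rate 100: never charged
    assert_eq!(savings(0, 0, None), (0.0, false));              // empty-input guard
}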
async fn finalize_file(
state: &AppState,
billing_ctx: &Option<billing::BillingContext>,