Files
ystp/src/api/tasks.rs

988 lines
33 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use crate::api::context;
use crate::api::envelope::Envelope;
use crate::error::{AppError, ErrorCode};
use crate::services::billing;
use crate::services::billing::{BillingContext, Plan};
use crate::services::compress;
use crate::services::compress::{CompressionLevel, ImageFmt};
use crate::services::idempotency;
use crate::state::AppState;
use axum::extract::{ConnectInfo, Multipart, Path, State};
use axum::http::HeaderMap;
use axum::routing::{delete, get, post};
use axum::{Json, Router};
use chrono::{DateTime, Duration, Utc};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use sqlx::FromRow;
use std::net::{IpAddr, SocketAddr};
use tokio::io::AsyncWriteExt;
use uuid::Uuid;
/// Builds the batch-compression task routes:
///
/// - `POST   /compress/batch` — create a batch task
/// - `GET    /compress/tasks/{task_id}` — poll a task's status
/// - `POST   /compress/tasks/{task_id}/cancel` — cancel a task
/// - `DELETE /compress/tasks/{task_id}` — delete a task and its files
pub fn router() -> Router<AppState> {
    const TASK_PATH: &str = "/compress/tasks/{task_id}";
    Router::new()
        .route("/compress/batch", post(create_batch_task))
        .route(TASK_PATH, get(get_task))
        .route("/compress/tasks/{task_id}/cancel", post(cancel_task))
        .route(TASK_PATH, delete(delete_task))
}
/// Response payload for a newly created (or idempotently replayed) batch task.
///
/// It is serialized into the idempotency store on success and deserialized again when a
/// duplicate request is replayed, hence both `Serialize` and `Deserialize`.
#[derive(Debug, Serialize, Deserialize)]
struct BatchCreateResponse {
    /// Id of the created task.
    task_id: Uuid,
    /// Number of uploaded files in the batch.
    total_files: i32,
    /// Initial task status (`"pending"` when created by this handler).
    status: String,
    /// Relative URL for polling the task (`/api/v1/compress/tasks/{task_id}`).
    status_url: String,
}
/// One uploaded file after `parse_batch_request` has written it to local storage.
#[derive(Debug)]
struct BatchFileInput {
    /// Freshly generated id; becomes `task_files.id` and the on-disk file stem.
    file_id: Uuid,
    /// Client-supplied file name (`"upload"` when the part had none).
    original_name: String,
    /// Format detected from the file bytes.
    original_format: ImageFmt,
    /// Format the file should be compressed to.
    output_format: ImageFmt,
    /// Size of the uploaded bytes.
    original_size: u64,
    /// Path the original bytes were written to.
    storage_path: String,
}
/// Options collected from the non-file multipart fields of a batch request.
#[derive(Debug)]
struct BatchOptions {
    /// Compression level; overridden from `compression_rate` when one is given.
    level: CompressionLevel,
    /// Explicit compression rate (`compression_rate` / `quality` field).
    compression_rate: Option<u8>,
    /// Requested output format; currently any value is rejected by `parse_batch_request`.
    output_format: Option<ImageFmt>,
    /// Optional maximum output width.
    max_width: Option<u32>,
    /// Optional maximum output height.
    max_height: Option<u32>,
    /// Parsed from the `preserve_metadata` field and included in the request hash.
    /// NOTE(review): `create_batch_task` binds `false` for this column regardless — confirm intent.
    preserve_metadata: bool,
}
async fn create_batch_task(
State(state): State<AppState>,
jar: axum_extra::extract::cookie::CookieJar,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
headers: HeaderMap,
mut multipart: Multipart,
) -> Result<(axum_extra::extract::cookie::CookieJar, Json<Envelope<BatchCreateResponse>>), AppError> {
let ip = context::client_ip(&headers, addr.ip());
let (jar, principal) = context::authenticate(&state, jar, &headers, ip).await?;
if state.config.storage_type.to_ascii_lowercase() != "local" {
return Err(AppError::new(
ErrorCode::StorageUnavailable,
"当前仅支持本地存储STORAGE_TYPE=local",
));
}
let idempotency_key = headers
.get("idempotency-key")
.and_then(|v| v.to_str().ok())
.map(str::trim)
.filter(|v| !v.is_empty())
.map(str::to_string);
let idempotency_scope = match &principal {
context::Principal::User { user_id, .. } => Some(idempotency::Scope::User(*user_id)),
context::Principal::ApiKey { api_key_id, .. } => Some(idempotency::Scope::ApiKey(*api_key_id)),
_ => None,
};
let task_id = Uuid::new_v4();
let (files, opts, request_hash) = parse_batch_request(&state, task_id, &mut multipart).await?;
if files.is_empty() {
cleanup_file_paths(&files).await;
return Err(AppError::new(ErrorCode::InvalidRequest, "缺少 files[]"));
}
let mut idem_acquired = false;
if let (Some(scope), Some(idem_key)) = (idempotency_scope, idempotency_key.as_deref()) {
match idempotency::begin(
&state,
scope,
idem_key,
&request_hash,
state.config.idempotency_ttl_hours as i64,
)
.await?
{
idempotency::BeginResult::Replay { response_body, .. } => {
cleanup_file_paths(&files).await;
let resp: BatchCreateResponse =
serde_json::from_value(response_body).map_err(|err| {
AppError::new(ErrorCode::Internal, "幂等结果解析失败").with_source(err)
})?;
return Ok((jar, Json(Envelope { success: true, data: resp })));
}
idempotency::BeginResult::InProgress => {
cleanup_file_paths(&files).await;
if let Some((_status, body)) = idempotency::wait_for_replay(
&state,
scope,
idem_key,
&request_hash,
10_000,
)
.await?
{
let resp: BatchCreateResponse =
serde_json::from_value(body).map_err(|err| {
AppError::new(ErrorCode::Internal, "幂等结果解析失败").with_source(err)
})?;
return Ok((jar, Json(Envelope { success: true, data: resp })));
}
return Err(AppError::new(
ErrorCode::InvalidRequest,
"请求正在处理中,请稍后重试",
));
}
idempotency::BeginResult::Acquired { .. } => {
idem_acquired = true;
}
}
}
let create_result: Result<BatchCreateResponse, AppError> = (async {
let (retention, task_owner, source) = match &principal {
context::Principal::Anonymous { session_id } => {
enforce_batch_limits_anonymous(&state, &files)?;
let remaining = anonymous_remaining_units(&state, session_id, ip).await?;
if remaining < files.len() as i64 {
return Err(AppError::new(
ErrorCode::QuotaExceeded,
"匿名试用次数已用完(每日 10 次)",
));
}
Ok((
Duration::hours(state.config.anon_retention_hours as i64),
TaskOwner::Anonymous {
session_id: session_id.clone(),
},
"web",
))
}
context::Principal::User {
user_id,
email_verified,
..
} => {
if !email_verified {
return Err(AppError::new(ErrorCode::EmailNotVerified, "请先验证邮箱"));
}
let billing = billing::get_user_billing(&state, *user_id).await?;
enforce_batch_limits_plan(&billing.plan, &files)?;
ensure_quota_available(&state, &billing, files.len() as i32).await?;
Ok((
Duration::days(billing.plan.retention_days as i64),
TaskOwner::User { user_id: *user_id },
"web",
))
}
context::Principal::ApiKey {
user_id,
api_key_id,
email_verified,
..
} => {
if !email_verified {
return Err(AppError::new(ErrorCode::EmailNotVerified, "请先验证邮箱"));
}
let billing = billing::get_user_billing(&state, *user_id).await?;
if !billing.plan.feature_api_enabled {
return Err(AppError::new(ErrorCode::Forbidden, "当前套餐未开通 API"));
}
enforce_batch_limits_plan(&billing.plan, &files)?;
ensure_quota_available(&state, &billing, files.len() as i32).await?;
Ok((
Duration::days(billing.plan.retention_days as i64),
TaskOwner::ApiKey {
user_id: *user_id,
api_key_id: *api_key_id,
},
"api",
))
}
}?;
tokio::fs::create_dir_all(&state.config.storage_path)
.await
.map_err(|err| {
AppError::new(ErrorCode::StorageUnavailable, "创建存储目录失败").with_source(err)
})?;
let expires_at = Utc::now() + retention;
let (user_id, session_id, api_key_id) = match &task_owner {
TaskOwner::Anonymous { session_id } => (None, Some(session_id.clone()), None),
TaskOwner::User { user_id } => (Some(*user_id), None, None),
TaskOwner::ApiKey { user_id, api_key_id } => (Some(*user_id), None, Some(*api_key_id)),
};
let total_original_size: i64 = files.iter().map(|f| f.original_size as i64).sum();
let mut tx = state
.db
.begin()
.await
.map_err(|err| AppError::new(ErrorCode::Internal, "开启事务失败").with_source(err))?;
sqlx::query(
r#"
INSERT INTO tasks (
id, user_id, session_id, api_key_id, client_ip, source, status,
compression_level, output_format, max_width, max_height, preserve_metadata,
compression_rate,
total_files, completed_files, failed_files,
total_original_size, total_compressed_size,
expires_at
) VALUES (
$1, $2, $3, $4, $5::inet, $6::task_source, 'pending',
$7::compression_level, $8, $9, $10, $11, $12,
$13, 0, 0,
$14, 0,
$15
)
"#,
)
.bind(task_id)
.bind(user_id)
.bind(session_id)
.bind(api_key_id)
.bind(ip.to_string())
.bind(source)
.bind(opts.level.as_str())
.bind(opts.output_format.map(|f| f.as_str()))
.bind(opts.max_width.map(|v| v as i32))
.bind(opts.max_height.map(|v| v as i32))
.bind(false)
.bind(opts.compression_rate.map(|v| v as i16))
.bind(files.len() as i32)
.bind(total_original_size)
.bind(expires_at)
.execute(&mut *tx)
.await
.map_err(|err| AppError::new(ErrorCode::Internal, "创建任务失败").with_source(err))?;
for file in &files {
sqlx::query(
r#"
INSERT INTO task_files (
id, task_id,
original_name, original_format, output_format,
original_size,
storage_path, status
) VALUES (
$1, $2,
$3, $4, $5,
$6,
$7, 'pending'
)
"#,
)
.bind(file.file_id)
.bind(task_id)
.bind(&file.original_name)
.bind(file.original_format.as_str())
.bind(file.output_format.as_str())
.bind(file.original_size as i64)
.bind(&file.storage_path)
.execute(&mut *tx)
.await
.map_err(|err| AppError::new(ErrorCode::Internal, "创建文件记录失败").with_source(err))?;
}
tx.commit()
.await
.map_err(|err| AppError::new(ErrorCode::Internal, "提交事务失败").with_source(err))?;
if let Err(err) = enqueue_task(&state, task_id).await {
let _ = sqlx::query("UPDATE tasks SET status = 'failed', error_message = $2 WHERE id = $1")
.bind(task_id)
.bind("队列提交失败")
.execute(&state.db)
.await;
return Err(err);
}
Ok(BatchCreateResponse {
task_id,
total_files: files.len() as i32,
status: "pending".to_string(),
status_url: format!("/api/v1/compress/tasks/{task_id}"),
})
})
.await;
match create_result {
Ok(resp) => {
if let (Some(scope), Some(idem_key)) = (idempotency_scope, idempotency_key.as_deref()) {
if idem_acquired {
let _ = idempotency::complete(
&state,
scope,
idem_key,
&request_hash,
200,
serde_json::to_value(&resp).unwrap_or(serde_json::Value::Null),
)
.await;
}
}
Ok((jar, Json(Envelope { success: true, data: resp })))
}
Err(err) => {
if let (Some(scope), Some(idem_key)) = (idempotency_scope, idempotency_key.as_deref()) {
if idem_acquired {
let _ = idempotency::abort(&state, scope, idem_key, &request_hash).await;
}
}
cleanup_file_paths(&files).await;
Err(err)
}
}
}
/// Who owns a newly created task. `create_batch_task` maps this onto the nullable
/// `user_id` / `session_id` / `api_key_id` columns of `tasks`.
#[derive(Debug)]
enum TaskOwner {
    /// Anonymous trial session; only `session_id` is stored.
    Anonymous { session_id: String },
    /// Logged-in web user; only `user_id` is stored.
    User { user_id: Uuid },
    /// Request made with an API key; both the key's user and the key id are stored.
    ApiKey { user_id: Uuid, api_key_id: Uuid },
}
/// Appends a job for `task_id` to the `stream:compress_jobs` Redis stream.
///
/// The command is `XADD` with a server-assigned (`*`) entry id. The entry carries the task id
/// and an RFC 3339 `created_at` timestamp.
///
/// # Errors
/// `ErrorCode::Internal` if the Redis command fails.
async fn enqueue_task(state: &AppState, task_id: Uuid) -> Result<(), AppError> {
    let created_at = Utc::now().to_rfc3339();
    let mut redis_conn = state.redis.clone();
    let mut xadd = redis::cmd("XADD");
    xadd.arg("stream:compress_jobs")
        .arg("*")
        .arg("task_id")
        .arg(task_id.to_string())
        .arg("created_at")
        .arg(created_at);
    match xadd.query_async::<_, redis::Value>(&mut redis_conn).await {
        Ok(_) => Ok(()),
        Err(err) => Err(AppError::new(ErrorCode::Internal, "写入队列失败").with_source(err)),
    }
}
async fn parse_batch_request(
state: &AppState,
task_id: Uuid,
multipart: &mut Multipart,
) -> Result<(Vec<BatchFileInput>, BatchOptions, String), AppError> {
let mut files: Vec<BatchFileInput> = Vec::new();
let mut file_digests: Vec<String> = Vec::new();
let mut opts = BatchOptions {
level: CompressionLevel::Medium,
compression_rate: None,
output_format: None,
max_width: None,
max_height: None,
preserve_metadata: false,
};
let base_dir = format!("{}/orig/{task_id}", state.config.storage_path);
tokio::fs::create_dir_all(&base_dir)
.await
.map_err(|err| AppError::new(ErrorCode::StorageUnavailable, "创建存储目录失败").with_source(err))?;
loop {
let next = multipart.next_field().await.map_err(|err| {
AppError::new(ErrorCode::InvalidRequest, "读取上传内容失败").with_source(err)
});
let field = match next {
Ok(v) => v,
Err(err) => {
cleanup_file_paths(&files).await;
return Err(err);
}
};
let Some(field) = field else { break };
let name = field.name().unwrap_or("").to_string();
if name == "files" || name == "files[]" {
let file_id = Uuid::new_v4();
let original_name = field.file_name().unwrap_or("upload").to_string();
let bytes = match field.bytes().await {
Ok(v) => v,
Err(err) => {
cleanup_file_paths(&files).await;
return Err(
AppError::new(ErrorCode::InvalidRequest, "读取文件失败").with_source(err)
);
}
};
let original_size = bytes.len() as u64;
let file_digest = {
let mut h = Sha256::new();
h.update(&bytes);
h.update(original_name.as_bytes());
hex::encode(h.finalize())
};
let original_format = match compress::detect_format(&bytes) {
Ok(v) => v,
Err(err) => {
cleanup_file_paths(&files).await;
return Err(err);
}
};
let output_format = opts.output_format.unwrap_or(original_format);
let path = format!("{base_dir}/{file_id}.{}", original_format.extension());
let mut f = match tokio::fs::File::create(&path).await {
Ok(v) => v,
Err(err) => {
cleanup_file_paths(&files).await;
return Err(
AppError::new(ErrorCode::StorageUnavailable, "写入文件失败").with_source(err),
);
}
};
if let Err(err) = f.write_all(&bytes).await {
let _ = tokio::fs::remove_file(&path).await;
cleanup_file_paths(&files).await;
return Err(
AppError::new(ErrorCode::StorageUnavailable, "写入文件失败").with_source(err),
);
}
files.push(BatchFileInput {
file_id,
original_name,
original_format,
output_format,
original_size,
storage_path: path,
});
file_digests.push(file_digest);
continue;
}
let text = match field.text().await {
Ok(v) => v,
Err(err) => {
cleanup_file_paths(&files).await;
return Err(
AppError::new(ErrorCode::InvalidRequest, "读取字段失败").with_source(err),
);
}
};
match name.as_str() {
"level" => {
opts.level = match compress::parse_level(&text) {
Ok(v) => v,
Err(err) => {
cleanup_file_paths(&files).await;
return Err(err);
}
}
}
"output_format" => {
let v = text.trim();
if !v.is_empty() {
opts.output_format = Some(match compress::parse_output_format(v) {
Ok(v) => v,
Err(err) => {
cleanup_file_paths(&files).await;
return Err(err);
}
});
}
}
"compression_rate" | "quality" => {
let v = text.trim();
if !v.is_empty() {
opts.compression_rate = Some(match compress::parse_compression_rate(v) {
Ok(v) => v,
Err(err) => {
cleanup_file_paths(&files).await;
return Err(err);
}
});
}
}
"max_width" => {
let v = text.trim();
if !v.is_empty() {
opts.max_width = Some(match v.parse::<u32>() {
Ok(n) => n,
Err(_) => {
cleanup_file_paths(&files).await;
return Err(AppError::new(ErrorCode::InvalidRequest, "max_width 格式错误"));
}
});
}
}
"max_height" => {
let v = text.trim();
if !v.is_empty() {
opts.max_height = Some(match v.parse::<u32>() {
Ok(n) => n,
Err(_) => {
cleanup_file_paths(&files).await;
return Err(AppError::new(ErrorCode::InvalidRequest, "max_height 格式错误"));
}
});
}
}
"preserve_metadata" => {
opts.preserve_metadata = matches!(
text.trim().to_ascii_lowercase().as_str(),
"1" | "true" | "yes" | "y" | "on"
);
}
_ => {}
}
}
if let Some(rate) = opts.compression_rate {
opts.level = compress::rate_to_level(rate);
}
if opts.output_format.is_some() {
cleanup_file_paths(&files).await;
return Err(AppError::new(
ErrorCode::InvalidRequest,
"当前仅支持保持原图片格式",
));
}
let mw = opts.max_width.map(|v| v.to_string()).unwrap_or_default();
let mh = opts.max_height.map(|v| v.to_string()).unwrap_or_default();
let out_fmt = opts.output_format.map(|f| f.as_str()).unwrap_or("");
let rate_key = opts
.compression_rate
.map(|v| v.to_string())
.unwrap_or_default();
let preserve = if opts.preserve_metadata { "1" } else { "0" };
let mut h = Sha256::new();
h.update(b"compress_batch_v1");
h.update(opts.level.as_str().as_bytes());
h.update(out_fmt.as_bytes());
h.update(rate_key.as_bytes());
h.update(mw.as_bytes());
h.update(mh.as_bytes());
h.update(preserve.as_bytes());
for d in &file_digests {
h.update(d.as_bytes());
}
let request_hash = hex::encode(h.finalize());
Ok((files, opts, request_hash))
}
/// Applies the anonymous-trial batch limits from the app config.
///
/// # Errors
/// - `InvalidRequest` when more than `anon_max_files_per_batch` files were uploaded;
/// - `FileTooLarge` when any file exceeds `anon_max_file_size_mb` MiB.
fn enforce_batch_limits_anonymous(state: &AppState, files: &[BatchFileInput]) -> Result<(), AppError> {
    let cfg = &state.config;
    let file_cap = cfg.anon_max_files_per_batch as usize;
    if files.len() > file_cap {
        return Err(AppError::new(
            ErrorCode::InvalidRequest,
            format!("匿名试用单次最多 {file_cap} 个文件"),
        ));
    }
    let byte_cap = cfg.anon_max_file_size_mb * 1024 * 1024;
    if files.iter().any(|file| file.original_size > byte_cap) {
        Err(AppError::new(
            ErrorCode::FileTooLarge,
            format!("匿名试用单文件最大 {} MB", cfg.anon_max_file_size_mb),
        ))
    } else {
        Ok(())
    }
}
/// Applies the subscription plan's per-batch file-count and per-file size limits.
///
/// # Errors
/// - `InvalidRequest` when more than `plan.max_files_per_batch` files were uploaded;
/// - `FileTooLarge` when any file exceeds `plan.max_file_size_mb` MiB.
fn enforce_batch_limits_plan(plan: &Plan, files: &[BatchFileInput]) -> Result<(), AppError> {
    let file_cap = plan.max_files_per_batch as usize;
    if files.len() > file_cap {
        return Err(AppError::new(
            ErrorCode::InvalidRequest,
            format!("当前套餐单次最多 {} 个文件", plan.max_files_per_batch),
        ));
    }
    let byte_cap = (plan.max_file_size_mb as u64) * 1024 * 1024;
    let has_oversized = files.iter().any(|file| file.original_size > byte_cap);
    if has_oversized {
        Err(AppError::new(
            ErrorCode::FileTooLarge,
            format!("当前套餐单文件最大 {} MB", plan.max_file_size_mb),
        ))
    } else {
        Ok(())
    }
}
/// Checks that the user's current billing period has at least `needed_units` units left.
///
/// The available total is the plan's `included_units_per_period` plus the period's
/// `bonus_units`, minus its `used_units`. A period with no `usage_periods` row counts as zero
/// used and zero bonus. Non-positive requests always pass.
///
/// # Errors
/// - `Internal` if the usage query fails;
/// - `QuotaExceeded` if the remaining units are insufficient.
async fn ensure_quota_available(
    state: &AppState,
    ctx: &BillingContext,
    needed_units: i32,
) -> Result<(), AppError> {
    #[derive(Debug, FromRow)]
    struct UsageRow {
        used_units: i32,
        bonus_units: i32,
    }

    if needed_units <= 0 {
        return Ok(());
    }
    let row = sqlx::query_as::<_, UsageRow>(
        r#"
SELECT used_units, bonus_units
FROM usage_periods
WHERE user_id = $1 AND period_start = $2 AND period_end = $3
"#,
    )
    .bind(ctx.user_id)
    .bind(ctx.period_start)
    .bind(ctx.period_end)
    .fetch_optional(&state.db)
    .await
    .map_err(|err| AppError::new(ErrorCode::Internal, "查询用量失败").with_source(err))?;
    let (used, bonus) = row.map_or((0, 0), |r| (r.used_units, r.bonus_units));
    let remaining = ctx.plan.included_units_per_period + bonus - used;
    if remaining >= needed_units {
        Ok(())
    } else {
        Err(AppError::new(ErrorCode::QuotaExceeded, "当期配额已用完"))
    }
}
/// Remaining anonymous trial units for today (UTC+8 calendar day): the smaller of the
/// per-session and per-IP allowances.
///
/// Reads the `anon_quota:{session_id}:{date}` and `anon_quota_ip:{ip}:{date}` counters and
/// subtracts each from `anon_daily_units`. A missing counter, a non-integer value and a Redis
/// error all count as zero usage. This function therefore never returns `Err` and fails open
/// when Redis is unavailable.
async fn anonymous_remaining_units(
    state: &AppState,
    session_id: &str,
    ip: IpAddr,
) -> Result<i64, AppError> {
    let day = utc8_date();
    let counter_keys = [
        format!("anon_quota:{session_id}:{day}"),
        format!("anon_quota_ip:{ip}:{day}"),
    ];
    let daily_limit = state.config.anon_daily_units as i64;
    let mut redis_conn = state.redis.clone();
    let mut remaining = i64::MAX;
    for key in counter_keys {
        let used: Option<i64> = redis::cmd("GET")
            .arg(key)
            .query_async(&mut redis_conn)
            .await
            .unwrap_or(None);
        remaining = remaining.min(daily_limit - used.unwrap_or(0));
    }
    Ok(remaining)
}
/// Today's calendar date at UTC+8, formatted as `YYYY-MM-DD`.
///
/// Used as the day bucket in the anonymous quota keys.
fn utc8_date() -> String {
    let utc_plus_8 = Utc::now() + Duration::hours(8);
    format!("{}", utc_plus_8.format("%Y-%m-%d"))
}
/// Subset of a `tasks` row loaded by the get/cancel/delete handlers.
///
/// `status` is selected as `status::text`. Ownership is checked with `user_id` (registered or
/// API-key owners) or `session_id` (anonymous owners).
/// NOTE(review): `id` is not read by any handler visible in this file.
#[derive(Debug, FromRow)]
struct TaskRow {
    id: Uuid,
    status: String,
    total_files: i32,
    completed_files: i32,
    failed_files: i32,
    created_at: DateTime<Utc>,
    completed_at: Option<DateTime<Utc>>,
    expires_at: DateTime<Utc>,
    user_id: Option<Uuid>,
    session_id: Option<String>,
}
/// Per-file row loaded by `get_task`.
///
/// `saved_percent` is cast to `float8` and `status` to text in the query. `compressed_size` and
/// `saved_percent` are nullable.
#[derive(Debug, FromRow)]
struct TaskFileRow {
    id: Uuid,
    original_name: String,
    original_size: i64,
    compressed_size: Option<i64>,
    saved_percent: Option<f64>,
    status: String,
}
/// JSON view of one file inside a task, as returned by `get_task`.
#[derive(Debug, Serialize)]
struct TaskFileView {
    file_id: Uuid,
    original_name: String,
    original_size: i64,
    compressed_size: Option<i64>,
    saved_percent: Option<f64>,
    status: String,
    /// `/downloads/{file_id}` when `status == "completed"`, otherwise `None`.
    download_url: Option<String>,
}
/// JSON view of a task, as returned by `get_task`.
#[derive(Debug, Serialize)]
struct TaskView {
    task_id: Uuid,
    status: String,
    /// Rounded percentage (0–100) of files that have either completed or failed.
    progress: i32,
    total_files: i32,
    completed_files: i32,
    failed_files: i32,
    files: Vec<TaskFileView>,
    /// `/downloads/tasks/{task_id}`.
    download_all_url: String,
    created_at: DateTime<Utc>,
    completed_at: Option<DateTime<Utc>>,
    expires_at: DateTime<Utc>,
}
async fn get_task(
State(state): State<AppState>,
jar: axum_extra::extract::cookie::CookieJar,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
headers: HeaderMap,
Path(task_id): Path<Uuid>,
) -> Result<(axum_extra::extract::cookie::CookieJar, Json<Envelope<TaskView>>), AppError> {
let ip = context::client_ip(&headers, addr.ip());
let (jar, principal) = context::authenticate(&state, jar, &headers, ip).await?;
let task = sqlx::query_as::<_, TaskRow>(
r#"
SELECT
id,
status::text AS status,
total_files,
completed_files,
failed_files,
created_at,
completed_at,
expires_at,
user_id,
session_id
FROM tasks
WHERE id = $1
"#,
)
.bind(task_id)
.fetch_optional(&state.db)
.await
.map_err(|err| AppError::new(ErrorCode::Internal, "查询任务失败").with_source(err))?
.ok_or_else(|| AppError::new(ErrorCode::NotFound, "任务不存在"))?;
if task.expires_at <= Utc::now() {
return Err(AppError::new(ErrorCode::NotFound, "任务已过期或不存在"));
}
authorize_task(&principal, task.user_id, task.session_id.as_deref().unwrap_or(""))?;
let files = sqlx::query_as::<_, TaskFileRow>(
r#"
SELECT
id,
original_name,
original_size,
compressed_size,
saved_percent::float8 AS saved_percent,
status::text AS status
FROM task_files
WHERE task_id = $1
ORDER BY created_at ASC
"#,
)
.bind(task_id)
.fetch_all(&state.db)
.await
.map_err(|err| AppError::new(ErrorCode::Internal, "查询任务文件失败").with_source(err))?;
let file_views = files
.into_iter()
.map(|f| TaskFileView {
file_id: f.id,
original_name: f.original_name,
original_size: f.original_size,
compressed_size: f.compressed_size,
saved_percent: f.saved_percent,
status: f.status.clone(),
download_url: if f.status == "completed" {
Some(format!("/downloads/{}", f.id))
} else {
None
},
})
.collect::<Vec<_>>();
let processed = task.completed_files + task.failed_files;
let progress = if task.total_files <= 0 {
0
} else {
((processed as f64) * 100.0 / (task.total_files as f64))
.round()
.clamp(0.0, 100.0) as i32
};
Ok((
jar,
Json(Envelope {
success: true,
data: TaskView {
task_id,
status: task.status,
progress,
total_files: task.total_files,
completed_files: task.completed_files,
failed_files: task.failed_files,
files: file_views,
download_all_url: format!("/downloads/tasks/{task_id}"),
created_at: task.created_at,
completed_at: task.completed_at,
expires_at: task.expires_at,
},
}),
))
}
/// `POST /compress/tasks/{task_id}/cancel` — cancels a task the caller owns.
///
/// Tasks already in `completed`, `failed` or `cancelled` are reported as finished: the call
/// succeeds and nothing changes. The UPDATE only matches `pending` / `processing` rows, so a
/// task that changed state between the lookup and the update is rejected as not cancellable.
///
/// # Errors
/// - `NotFound` if the task does not exist;
/// - `Forbidden` if the caller does not own it;
/// - `InvalidRequest` if it could not be moved to `cancelled`;
/// - `Internal` on database failures.
async fn cancel_task(
    State(state): State<AppState>,
    jar: axum_extra::extract::cookie::CookieJar,
    ConnectInfo(addr): ConnectInfo<SocketAddr>,
    headers: HeaderMap,
    Path(task_id): Path<Uuid>,
) -> Result<(axum_extra::extract::cookie::CookieJar, Json<Envelope<serde_json::Value>>), AppError> {
    let client_ip = context::client_ip(&headers, addr.ip());
    let (jar, principal) = context::authenticate(&state, jar, &headers, client_ip).await?;
    let task = sqlx::query_as::<_, TaskRow>(
        "SELECT id, status::text AS status, total_files, completed_files, failed_files, created_at, completed_at, expires_at, user_id, session_id FROM tasks WHERE id = $1",
    )
    .bind(task_id)
    .fetch_optional(&state.db)
    .await
    .map_err(|err| AppError::new(ErrorCode::Internal, "查询任务失败").with_source(err))?
    .ok_or_else(|| AppError::new(ErrorCode::NotFound, "任务不存在"))?;
    authorize_task(&principal, task.user_id, task.session_id.as_deref().unwrap_or(""))?;

    let message = match task.status.as_str() {
        "completed" | "failed" | "cancelled" => "任务已结束",
        _ => {
            let outcome = sqlx::query(
                "UPDATE tasks SET status = 'cancelled', completed_at = NOW() WHERE id = $1 AND status IN ('pending', 'processing')",
            )
            .bind(task_id)
            .execute(&state.db)
            .await
            .map_err(|err| AppError::new(ErrorCode::Internal, "取消任务失败").with_source(err))?;
            if outcome.rows_affected() == 0 {
                return Err(AppError::new(ErrorCode::InvalidRequest, "任务状态不可取消"));
            }
            "已取消"
        }
    };
    Ok((
        jar,
        Json(Envelope {
            success: true,
            data: serde_json::json!({ "message": message }),
        }),
    ))
}
/// `DELETE /compress/tasks/{task_id}` — deletes a task the caller owns, together with its
/// uploaded originals. With local storage, its ZIP archive and `orig/{task_id}` directory are
/// removed too.
///
/// A task in `processing` cannot be deleted; the client must cancel it first. The status is
/// re-checked inside the DELETE itself, so a worker that starts the task concurrently cannot
/// have its inputs removed. The row is deleted before any file, so a database failure never
/// leaves rows pointing at deleted files.
///
/// # Errors
/// - `NotFound` if the task does not exist or disappeared concurrently;
/// - `Forbidden` if the caller does not own it;
/// - `InvalidRequest` while it is processing;
/// - `Internal` on database failures.
async fn delete_task(
    State(state): State<AppState>,
    jar: axum_extra::extract::cookie::CookieJar,
    ConnectInfo(addr): ConnectInfo<SocketAddr>,
    headers: HeaderMap,
    Path(task_id): Path<Uuid>,
) -> Result<(axum_extra::extract::cookie::CookieJar, Json<Envelope<serde_json::Value>>), AppError> {
    let ip = context::client_ip(&headers, addr.ip());
    let (jar, principal) = context::authenticate(&state, jar, &headers, ip).await?;
    let task = sqlx::query_as::<_, TaskRow>(
        "SELECT id, status::text AS status, total_files, completed_files, failed_files, created_at, completed_at, expires_at, user_id, session_id FROM tasks WHERE id = $1",
    )
    .bind(task_id)
    .fetch_optional(&state.db)
    .await
    .map_err(|err| AppError::new(ErrorCode::Internal, "查询任务失败").with_source(err))?
    .ok_or_else(|| AppError::new(ErrorCode::NotFound, "任务不存在"))?;
    authorize_task(&principal, task.user_id, task.session_id.as_deref().unwrap_or(""))?;
    let still_processing = || {
        AppError::new(
            ErrorCode::InvalidRequest,
            "任务处理中,请先取消后再删除",
        )
    };
    if task.status == "processing" {
        return Err(still_processing());
    }
    // Snapshot the file paths first: the task_files rows may not survive deleting the task.
    let paths: Vec<Option<String>> =
        sqlx::query_scalar("SELECT storage_path FROM task_files WHERE task_id = $1")
            .bind(task_id)
            .fetch_all(&state.db)
            .await
            .map_err(|err| AppError::new(ErrorCode::Internal, "查询文件失败").with_source(err))?;
    // Delete the row before touching the disk, re-checking the status atomically: a worker may
    // have switched the task to `processing` since the SELECT above.
    let deleted = sqlx::query("DELETE FROM tasks WHERE id = $1 AND status <> 'processing'")
        .bind(task_id)
        .execute(&state.db)
        .await
        .map_err(|err| AppError::new(ErrorCode::Internal, "删除任务失败").with_source(err))?;
    if deleted.rows_affected() == 0 {
        // Distinguish "started processing meanwhile" from "already deleted".
        let current: Option<String> =
            sqlx::query_scalar("SELECT status::text FROM tasks WHERE id = $1")
                .bind(task_id)
                .fetch_optional(&state.db)
                .await
                .map_err(|err| {
                    AppError::new(ErrorCode::Internal, "删除任务失败").with_source(err)
                })?;
        return Err(if current.as_deref() == Some("processing") {
            still_processing()
        } else {
            AppError::new(ErrorCode::NotFound, "任务不存在")
        });
    }
    // The task row is gone; removing its on-disk artifacts is best effort (errors ignored).
    // NOTE(review): only `storage_path` (the uploaded originals) is removed per file. If
    // compressed outputs live elsewhere they are not deleted here — confirm.
    for p in paths.into_iter().flatten() {
        let _ = tokio::fs::remove_file(p).await;
    }
    if state.config.storage_type.to_ascii_lowercase() == "local" {
        let zip_path = format!("{}/zips/{task_id}.zip", state.config.storage_path);
        let _ = tokio::fs::remove_file(zip_path).await;
        let orig_dir = format!("{}/orig/{task_id}", state.config.storage_path);
        let _ = tokio::fs::remove_dir_all(orig_dir).await;
    }
    Ok((
        jar,
        Json(Envelope {
            success: true,
            data: serde_json::json!({ "message": "已删除" }),
        }),
    ))
}
/// Checks that `principal` owns a task.
///
/// - A task with a `user_id` is accessible to that user, via the web session or any of their
///   API keys.
/// - A task without a `user_id` is accessible only to the anonymous session whose id equals
///   `session_id`.
///
/// Callers pass `""` when the task row has no `session_id`. An empty owner session identifies
/// nobody, so it never grants access. Without this check, such a row would match any
/// anonymous principal whose session id is also empty.
///
/// # Errors
/// `Forbidden` when the principal is not the owner.
fn authorize_task(
    principal: &context::Principal,
    user_id: Option<Uuid>,
    session_id: &str,
) -> Result<(), AppError> {
    let allowed = match (user_id, principal) {
        (Some(owner), context::Principal::User { user_id: me, .. }) => *me == owner,
        (Some(owner), context::Principal::ApiKey { user_id: me, .. }) => *me == owner,
        (None, context::Principal::Anonymous { session_id: sid }) => {
            !session_id.is_empty() && sid == session_id
        }
        _ => false,
    };
    if allowed {
        Ok(())
    } else {
        Err(AppError::new(ErrorCode::Forbidden, "无权限访问该任务"))
    }
}
/// Best-effort removal of uploaded originals that are being discarded.
///
/// Each file is deleted first. Then each distinct parent directory (the per-task upload
/// folder) is removed, so failed or replayed requests do not leave empty directories behind.
/// `remove_dir` only succeeds on an empty directory, so a folder that still holds other files
/// stays in place. All I/O errors are ignored.
async fn cleanup_file_paths(files: &[BatchFileInput]) {
    let mut parents: Vec<&std::path::Path> = Vec::new();
    for f in files {
        let _ = tokio::fs::remove_file(&f.storage_path).await;
        if let Some(parent) = std::path::Path::new(&f.storage_path).parent() {
            if !parents.contains(&parent) {
                parents.push(parent);
            }
        }
    }
    for dir in parents {
        let _ = tokio::fs::remove_dir(dir).await;
    }
}