Implement compression quota refunds and admin manual subscription

commit 11f48fd3dd
2025-12-19 23:28:32 +08:00
106 changed files with 27848 additions and 0 deletions

src/api/tasks.rs (new file, 987 lines)

@@ -0,0 +1,987 @@
use crate::api::context;
use crate::api::envelope::Envelope;
use crate::error::{AppError, ErrorCode};
use crate::services::billing;
use crate::services::billing::{BillingContext, Plan};
use crate::services::compress;
use crate::services::compress::{CompressionLevel, ImageFmt};
use crate::services::idempotency;
use crate::state::AppState;
use axum::extract::{ConnectInfo, Multipart, Path, State};
use axum::http::HeaderMap;
use axum::routing::{delete, get, post};
use axum::{Json, Router};
use chrono::{DateTime, Duration, Utc};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use sqlx::FromRow;
use std::net::{IpAddr, SocketAddr};
use tokio::io::AsyncWriteExt;
use uuid::Uuid;
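/// Routes for the compression task API: batch creation, status polling,
/// cancellation, and deletion.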
pub fn router() -> Router<AppState> {
Router::new()
.route("/compress/batch", post(create_batch_task))
.route("/compress/tasks/{task_id}", get(get_task))
.route("/compress/tasks/{task_id}/cancel", post(cancel_task))
.route("/compress/tasks/{task_id}", delete(delete_task))
}
#[derive(Debug, Serialize, Deserialize)]
struct BatchCreateResponse {
task_id: Uuid,
total_files: i32,
status: String,
status_url: String,
}
#[derive(Debug)]
struct BatchFileInput {
file_id: Uuid,
original_name: String,
original_format: ImageFmt,
output_format: ImageFmt,
original_size: u64,
storage_path: String,
}
#[derive(Debug)]
struct BatchOptions {
level: CompressionLevel,
compression_rate: Option<u8>,
output_format: Option<ImageFmt>,
max_width: Option<u32>,
max_height: Option<u32>,
preserve_metadata: bool,
}
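/// POST /compress/batch: authenticates the caller, optionally replays or
/// acquires an idempotency slot (Idempotency-Key header), enforces
/// per-plan or anonymous limits and quota, persists the task and its files
/// in one transaction, then enqueues the task for processing.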
async fn create_batch_task(
State(state): State<AppState>,
jar: axum_extra::extract::cookie::CookieJar,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
headers: HeaderMap,
mut multipart: Multipart,
) -> Result<(axum_extra::extract::cookie::CookieJar, Json<Envelope<BatchCreateResponse>>), AppError> {
let ip = context::client_ip(&headers, addr.ip());
let (jar, principal) = context::authenticate(&state, jar, &headers, ip).await?;
if state.config.storage_type.to_ascii_lowercase() != "local" {
return Err(AppError::new(
ErrorCode::StorageUnavailable,
"当前仅支持本地存储STORAGE_TYPE=local",
));
}
let idempotency_key = headers
.get("idempotency-key")
.and_then(|v| v.to_str().ok())
.map(str::trim)
.filter(|v| !v.is_empty())
.map(str::to_string);
let idempotency_scope = match &principal {
context::Principal::User { user_id, .. } => Some(idempotency::Scope::User(*user_id)),
context::Principal::ApiKey { api_key_id, .. } => Some(idempotency::Scope::ApiKey(*api_key_id)),
_ => None,
};
let task_id = Uuid::new_v4();
let (files, opts, request_hash) = parse_batch_request(&state, task_id, &mut multipart).await?;
if files.is_empty() {
cleanup_file_paths(&files).await;
return Err(AppError::new(ErrorCode::InvalidRequest, "missing files[]"));
}
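// Idempotency handling: Replay returns the stored response as-is,
// InProgress waits up to 10s for the first request to finish, and
// Acquired means this request owns the slot and must complete or abort it
// after the task is created.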
let mut idem_acquired = false;
if let (Some(scope), Some(idem_key)) = (idempotency_scope, idempotency_key.as_deref()) {
match idempotency::begin(
&state,
scope,
idem_key,
&request_hash,
state.config.idempotency_ttl_hours as i64,
)
.await?
{
idempotency::BeginResult::Replay { response_body, .. } => {
cleanup_file_paths(&files).await;
let resp: BatchCreateResponse =
serde_json::from_value(response_body).map_err(|err| {
AppError::new(ErrorCode::Internal, "failed to parse idempotent replay result").with_source(err)
})?;
return Ok((jar, Json(Envelope { success: true, data: resp })));
}
idempotency::BeginResult::InProgress => {
cleanup_file_paths(&files).await;
if let Some((_status, body)) = idempotency::wait_for_replay(
&state,
scope,
idem_key,
&request_hash,
10_000,
)
.await?
{
let resp: BatchCreateResponse =
serde_json::from_value(body).map_err(|err| {
AppError::new(ErrorCode::Internal, "failed to parse idempotent replay result").with_source(err)
})?;
return Ok((jar, Json(Envelope { success: true, data: resp })));
}
return Err(AppError::new(
ErrorCode::InvalidRequest,
"请求正在处理中,请稍后重试",
));
}
idempotency::BeginResult::Acquired { .. } => {
idem_acquired = true;
}
}
}
let create_result: Result<BatchCreateResponse, AppError> = (async {
let (retention, task_owner, source) = match &principal {
context::Principal::Anonymous { session_id } => {
enforce_batch_limits_anonymous(&state, &files)?;
let remaining = anonymous_remaining_units(&state, session_id, ip).await?;
if remaining < files.len() as i64 {
return Err(AppError::new(
ErrorCode::QuotaExceeded,
"匿名试用次数已用完(每日 10 次)",
));
}
Ok((
Duration::hours(state.config.anon_retention_hours as i64),
TaskOwner::Anonymous {
session_id: session_id.clone(),
},
"web",
))
}
context::Principal::User {
user_id,
email_verified,
..
} => {
if !email_verified {
return Err(AppError::new(ErrorCode::EmailNotVerified, "please verify your email first"));
}
let billing = billing::get_user_billing(&state, *user_id).await?;
enforce_batch_limits_plan(&billing.plan, &files)?;
ensure_quota_available(&state, &billing, files.len() as i32).await?;
Ok((
Duration::days(billing.plan.retention_days as i64),
TaskOwner::User { user_id: *user_id },
"web",
))
}
context::Principal::ApiKey {
user_id,
api_key_id,
email_verified,
..
} => {
if !email_verified {
return Err(AppError::new(ErrorCode::EmailNotVerified, "please verify your email first"));
}
let billing = billing::get_user_billing(&state, *user_id).await?;
if !billing.plan.feature_api_enabled {
return Err(AppError::new(ErrorCode::Forbidden, "current plan does not include API access"));
}
enforce_batch_limits_plan(&billing.plan, &files)?;
ensure_quota_available(&state, &billing, files.len() as i32).await?;
Ok((
Duration::days(billing.plan.retention_days as i64),
TaskOwner::ApiKey {
user_id: *user_id,
api_key_id: *api_key_id,
},
"api",
))
}
}?;
tokio::fs::create_dir_all(&state.config.storage_path)
.await
.map_err(|err| {
AppError::new(ErrorCode::StorageUnavailable, "failed to create storage directory").with_source(err)
})?;
let expires_at = Utc::now() + retention;
let (user_id, session_id, api_key_id) = match &task_owner {
TaskOwner::Anonymous { session_id } => (None, Some(session_id.clone()), None),
TaskOwner::User { user_id } => (Some(*user_id), None, None),
TaskOwner::ApiKey { user_id, api_key_id } => (Some(*user_id), None, Some(*api_key_id)),
};
let total_original_size: i64 = files.iter().map(|f| f.original_size as i64).sum();
let mut tx = state
.db
.begin()
.await
.map_err(|err| AppError::new(ErrorCode::Internal, "failed to begin transaction").with_source(err))?;
sqlx::query(
r#"
INSERT INTO tasks (
id, user_id, session_id, api_key_id, client_ip, source, status,
compression_level, output_format, max_width, max_height, preserve_metadata,
compression_rate,
total_files, completed_files, failed_files,
total_original_size, total_compressed_size,
expires_at
) VALUES (
$1, $2, $3, $4, $5::inet, $6::task_source, 'pending',
$7::compression_level, $8, $9, $10, $11, $12,
$13, 0, 0,
$14, 0,
$15
)
"#,
)
.bind(task_id)
.bind(user_id)
.bind(session_id)
.bind(api_key_id)
.bind(ip.to_string())
.bind(source)
.bind(opts.level.as_str())
.bind(opts.output_format.map(|f| f.as_str()))
.bind(opts.max_width.map(|v| v as i32))
.bind(opts.max_height.map(|v| v as i32))
.bind(opts.preserve_metadata)
.bind(opts.compression_rate.map(|v| v as i16))
.bind(files.len() as i32)
.bind(total_original_size)
.bind(expires_at)
.execute(&mut *tx)
.await
.map_err(|err| AppError::new(ErrorCode::Internal, "failed to create task").with_source(err))?;
for file in &files {
sqlx::query(
r#"
INSERT INTO task_files (
id, task_id,
original_name, original_format, output_format,
original_size,
storage_path, status
) VALUES (
$1, $2,
$3, $4, $5,
$6,
$7, 'pending'
)
"#,
)
.bind(file.file_id)
.bind(task_id)
.bind(&file.original_name)
.bind(file.original_format.as_str())
.bind(file.output_format.as_str())
.bind(file.original_size as i64)
.bind(&file.storage_path)
.execute(&mut *tx)
.await
.map_err(|err| AppError::new(ErrorCode::Internal, "failed to create file record").with_source(err))?;
}
tx.commit()
.await
.map_err(|err| AppError::new(ErrorCode::Internal, "failed to commit transaction").with_source(err))?;
if let Err(err) = enqueue_task(&state, task_id).await {
let _ = sqlx::query("UPDATE tasks SET status = 'failed', error_message = $2 WHERE id = $1")
.bind(task_id)
.bind("队列提交失败")
.execute(&state.db)
.await;
return Err(err);
}
Ok(BatchCreateResponse {
task_id,
total_files: files.len() as i32,
status: "pending".to_string(),
status_url: format!("/api/v1/compress/tasks/{task_id}"),
})
})
.await;
match create_result {
Ok(resp) => {
if let (Some(scope), Some(idem_key)) = (idempotency_scope, idempotency_key.as_deref()) {
if idem_acquired {
let _ = idempotency::complete(
&state,
scope,
idem_key,
&request_hash,
200,
serde_json::to_value(&resp).unwrap_or(serde_json::Value::Null),
)
.await;
}
}
Ok((jar, Json(Envelope { success: true, data: resp })))
}
Err(err) => {
if let (Some(scope), Some(idem_key)) = (idempotency_scope, idempotency_key.as_deref()) {
if idem_acquired {
let _ = idempotency::abort(&state, scope, idem_key, &request_hash).await;
}
}
cleanup_file_paths(&files).await;
Err(err)
}
}
}
#[derive(Debug)]
enum TaskOwner {
Anonymous { session_id: String },
User { user_id: Uuid },
ApiKey { user_id: Uuid, api_key_id: Uuid },
}
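/// Appends the task to the `stream:compress_jobs` Redis stream via XADD;
/// entries are presumably consumed by the compression worker.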
async fn enqueue_task(state: &AppState, task_id: Uuid) -> Result<(), AppError> {
let mut conn = state.redis.clone();
let now = Utc::now().to_rfc3339();
redis::cmd("XADD")
.arg("stream:compress_jobs")
.arg("*")
.arg("task_id")
.arg(task_id.to_string())
.arg("created_at")
.arg(now)
.query_async::<_, redis::Value>(&mut conn)
.await
.map_err(|err| AppError::new(ErrorCode::Internal, "failed to write to queue").with_source(err))?;
Ok(())
}
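/// Streams the multipart body: each `files`/`files[]` part is hashed,
/// format-detected, and written under `orig/{task_id}/`; other fields
/// populate `BatchOptions`. Returns the stored files, the parsed options,
/// and a SHA-256 request hash used for idempotency matching.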
async fn parse_batch_request(
state: &AppState,
task_id: Uuid,
multipart: &mut Multipart,
) -> Result<(Vec<BatchFileInput>, BatchOptions, String), AppError> {
let mut files: Vec<BatchFileInput> = Vec::new();
let mut file_digests: Vec<String> = Vec::new();
let mut opts = BatchOptions {
level: CompressionLevel::Medium,
compression_rate: None,
output_format: None,
max_width: None,
max_height: None,
preserve_metadata: false,
};
let base_dir = format!("{}/orig/{task_id}", state.config.storage_path);
tokio::fs::create_dir_all(&base_dir)
.await
.map_err(|err| AppError::new(ErrorCode::StorageUnavailable, "failed to create storage directory").with_source(err))?;
loop {
let next = multipart.next_field().await.map_err(|err| {
AppError::new(ErrorCode::InvalidRequest, "failed to read upload body").with_source(err)
});
let field = match next {
Ok(v) => v,
Err(err) => {
cleanup_file_paths(&files).await;
return Err(err);
}
};
let Some(field) = field else { break };
let name = field.name().unwrap_or("").to_string();
if name == "files" || name == "files[]" {
let file_id = Uuid::new_v4();
let original_name = field.file_name().unwrap_or("upload").to_string();
let bytes = match field.bytes().await {
Ok(v) => v,
Err(err) => {
cleanup_file_paths(&files).await;
return Err(
AppError::new(ErrorCode::InvalidRequest, "failed to read file").with_source(err)
);
}
};
let original_size = bytes.len() as u64;
let file_digest = {
let mut h = Sha256::new();
h.update(&bytes);
h.update(original_name.as_bytes());
hex::encode(h.finalize())
};
let original_format = match compress::detect_format(&bytes) {
Ok(v) => v,
Err(err) => {
cleanup_file_paths(&files).await;
return Err(err);
}
};
let output_format = opts.output_format.unwrap_or(original_format);
let path = format!("{base_dir}/{file_id}.{}", original_format.extension());
let mut f = match tokio::fs::File::create(&path).await {
Ok(v) => v,
Err(err) => {
cleanup_file_paths(&files).await;
return Err(
AppError::new(ErrorCode::StorageUnavailable, "failed to write file").with_source(err),
);
}
};
if let Err(err) = f.write_all(&bytes).await {
let _ = tokio::fs::remove_file(&path).await;
cleanup_file_paths(&files).await;
return Err(
AppError::new(ErrorCode::StorageUnavailable, "failed to write file").with_source(err),
);
}
files.push(BatchFileInput {
file_id,
original_name,
original_format,
output_format,
original_size,
storage_path: path,
});
file_digests.push(file_digest);
continue;
}
let text = match field.text().await {
Ok(v) => v,
Err(err) => {
cleanup_file_paths(&files).await;
return Err(
AppError::new(ErrorCode::InvalidRequest, "failed to read field").with_source(err),
);
}
};
match name.as_str() {
"level" => {
opts.level = match compress::parse_level(&text) {
Ok(v) => v,
Err(err) => {
cleanup_file_paths(&files).await;
return Err(err);
}
}
}
"output_format" => {
let v = text.trim();
if !v.is_empty() {
opts.output_format = Some(match compress::parse_output_format(v) {
Ok(v) => v,
Err(err) => {
cleanup_file_paths(&files).await;
return Err(err);
}
});
}
}
"compression_rate" | "quality" => {
let v = text.trim();
if !v.is_empty() {
opts.compression_rate = Some(match compress::parse_compression_rate(v) {
Ok(v) => v,
Err(err) => {
cleanup_file_paths(&files).await;
return Err(err);
}
});
}
}
"max_width" => {
let v = text.trim();
if !v.is_empty() {
opts.max_width = Some(match v.parse::<u32>() {
Ok(n) => n,
Err(_) => {
cleanup_file_paths(&files).await;
return Err(AppError::new(ErrorCode::InvalidRequest, "max_width 格式错误"));
}
});
}
}
"max_height" => {
let v = text.trim();
if !v.is_empty() {
opts.max_height = Some(match v.parse::<u32>() {
Ok(n) => n,
Err(_) => {
cleanup_file_paths(&files).await;
return Err(AppError::new(ErrorCode::InvalidRequest, "max_height 格式错误"));
}
});
}
}
"preserve_metadata" => {
opts.preserve_metadata = matches!(
text.trim().to_ascii_lowercase().as_str(),
"1" | "true" | "yes" | "y" | "on"
);
}
_ => {}
}
}
if let Some(rate) = opts.compression_rate {
opts.level = compress::rate_to_level(rate);
}
if opts.output_format.is_some() {
cleanup_file_paths(&files).await;
return Err(AppError::new(
ErrorCode::InvalidRequest,
"当前仅支持保持原图片格式",
));
}
let mw = opts.max_width.map(|v| v.to_string()).unwrap_or_default();
let mh = opts.max_height.map(|v| v.to_string()).unwrap_or_default();
let out_fmt = opts.output_format.map(|f| f.as_str()).unwrap_or("");
let rate_key = opts
.compression_rate
.map(|v| v.to_string())
.unwrap_or_default();
let preserve = if opts.preserve_metadata { "1" } else { "0" };
let mut h = Sha256::new();
h.update(b"compress_batch_v1");
h.update(opts.level.as_str().as_bytes());
h.update(out_fmt.as_bytes());
h.update(rate_key.as_bytes());
h.update(mw.as_bytes());
h.update(mh.as_bytes());
h.update(preserve.as_bytes());
for d in &file_digests {
h.update(d.as_bytes());
}
let request_hash = hex::encode(h.finalize());
Ok((files, opts, request_hash))
}
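/// Per-batch limits for anonymous callers (file count and per-file size
/// from server config).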
fn enforce_batch_limits_anonymous(state: &AppState, files: &[BatchFileInput]) -> Result<(), AppError> {
let max_files = state.config.anon_max_files_per_batch as usize;
if files.len() > max_files {
return Err(AppError::new(
ErrorCode::InvalidRequest,
format!("匿名试用单次最多 {} 个文件", max_files),
));
}
let max_bytes = state.config.anon_max_file_size_mb * 1024 * 1024;
for f in files {
if f.original_size > max_bytes {
return Err(AppError::new(
ErrorCode::FileTooLarge,
format!("匿名试用单文件最大 {} MB", state.config.anon_max_file_size_mb),
));
}
}
Ok(())
}
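/// Per-batch limits for authenticated callers, taken from the plan.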
fn enforce_batch_limits_plan(plan: &Plan, files: &[BatchFileInput]) -> Result<(), AppError> {
let max_files = plan.max_files_per_batch as usize;
if files.len() > max_files {
return Err(AppError::new(
ErrorCode::InvalidRequest,
format!("当前套餐单次最多 {} 个文件", plan.max_files_per_batch),
));
}
let max_bytes = (plan.max_file_size_mb as u64) * 1024 * 1024;
for f in files {
if f.original_size > max_bytes {
return Err(AppError::new(
ErrorCode::FileTooLarge,
format!("当前套餐单文件最大 {} MB", plan.max_file_size_mb),
));
}
}
Ok(())
}
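/// Read-only quota pre-check: plan allowance plus bonus units minus usage
/// in the current billing period must cover `needed_units`.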
async fn ensure_quota_available(
state: &AppState,
ctx: &BillingContext,
needed_units: i32,
) -> Result<(), AppError> {
if needed_units <= 0 {
return Ok(());
}
#[derive(Debug, FromRow)]
struct UsageRow {
used_units: i32,
bonus_units: i32,
}
let usage = sqlx::query_as::<_, UsageRow>(
r#"
SELECT used_units, bonus_units
FROM usage_periods
WHERE user_id = $1 AND period_start = $2 AND period_end = $3
"#,
)
.bind(ctx.user_id)
.bind(ctx.period_start)
.bind(ctx.period_end)
.fetch_optional(&state.db)
.await
.map_err(|err| AppError::new(ErrorCode::Internal, "failed to query usage").with_source(err))?
.unwrap_or(UsageRow {
used_units: 0,
bonus_units: 0,
});
let total_units = ctx.plan.included_units_per_period + usage.bonus_units;
let remaining = total_units - usage.used_units;
if remaining < needed_units {
return Err(AppError::new(ErrorCode::QuotaExceeded, "quota for the current period is exhausted"));
}
Ok(())
}
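/// Remaining anonymous units for today (UTC+8 day), tracked in Redis under
/// both a session key and an IP key; the stricter counter wins. Redis read
/// errors are treated as zero usage, i.e. the check fails open.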
async fn anonymous_remaining_units(
state: &AppState,
session_id: &str,
ip: IpAddr,
) -> Result<i64, AppError> {
let date = utc8_date();
let session_key = format!("anon_quota:{session_id}:{date}");
let ip_key = format!("anon_quota_ip:{ip}:{date}");
let mut conn = state.redis.clone();
let v1: Option<i64> = redis::cmd("GET")
.arg(session_key)
.query_async(&mut conn)
.await
.unwrap_or(None);
let v2: Option<i64> = redis::cmd("GET")
.arg(ip_key)
.query_async(&mut conn)
.await
.unwrap_or(None);
let limit = state.config.anon_daily_units as i64;
Ok(std::cmp::min(limit - v1.unwrap_or(0), limit - v2.unwrap_or(0)))
}
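/// Current calendar date shifted to UTC+8, used as the daily quota bucket.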
fn utc8_date() -> String {
let now = Utc::now() + Duration::hours(8);
now.format("%Y-%m-%d").to_string()
}
#[derive(Debug, FromRow)]
struct TaskRow {
id: Uuid,
status: String,
total_files: i32,
completed_files: i32,
failed_files: i32,
created_at: DateTime<Utc>,
completed_at: Option<DateTime<Utc>>,
expires_at: DateTime<Utc>,
user_id: Option<Uuid>,
session_id: Option<String>,
}
#[derive(Debug, FromRow)]
struct TaskFileRow {
id: Uuid,
original_name: String,
original_size: i64,
compressed_size: Option<i64>,
saved_percent: Option<f64>,
status: String,
}
#[derive(Debug, Serialize)]
struct TaskFileView {
file_id: Uuid,
original_name: String,
original_size: i64,
compressed_size: Option<i64>,
saved_percent: Option<f64>,
status: String,
download_url: Option<String>,
}
#[derive(Debug, Serialize)]
struct TaskView {
task_id: Uuid,
status: String,
progress: i32,
total_files: i32,
completed_files: i32,
failed_files: i32,
files: Vec<TaskFileView>,
download_all_url: String,
created_at: DateTime<Utc>,
completed_at: Option<DateTime<Utc>>,
expires_at: DateTime<Utc>,
}
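/// GET /compress/tasks/{task_id}: per-file status plus overall progress,
/// computed as (completed + failed) / total rounded to a whole percent.
/// Expired tasks are reported as not found.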
async fn get_task(
State(state): State<AppState>,
jar: axum_extra::extract::cookie::CookieJar,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
headers: HeaderMap,
Path(task_id): Path<Uuid>,
) -> Result<(axum_extra::extract::cookie::CookieJar, Json<Envelope<TaskView>>), AppError> {
let ip = context::client_ip(&headers, addr.ip());
let (jar, principal) = context::authenticate(&state, jar, &headers, ip).await?;
let task = sqlx::query_as::<_, TaskRow>(
r#"
SELECT
id,
status::text AS status,
total_files,
completed_files,
failed_files,
created_at,
completed_at,
expires_at,
user_id,
session_id
FROM tasks
WHERE id = $1
"#,
)
.bind(task_id)
.fetch_optional(&state.db)
.await
.map_err(|err| AppError::new(ErrorCode::Internal, "failed to query task").with_source(err))?
.ok_or_else(|| AppError::new(ErrorCode::NotFound, "task not found"))?;
if task.expires_at <= Utc::now() {
return Err(AppError::new(ErrorCode::NotFound, "task has expired or does not exist"));
}
authorize_task(&principal, task.user_id, task.session_id.as_deref().unwrap_or(""))?;
let files = sqlx::query_as::<_, TaskFileRow>(
r#"
SELECT
id,
original_name,
original_size,
compressed_size,
saved_percent::float8 AS saved_percent,
status::text AS status
FROM task_files
WHERE task_id = $1
ORDER BY created_at ASC
"#,
)
.bind(task_id)
.fetch_all(&state.db)
.await
.map_err(|err| AppError::new(ErrorCode::Internal, "failed to query task files").with_source(err))?;
let file_views = files
.into_iter()
.map(|f| TaskFileView {
file_id: f.id,
original_name: f.original_name,
original_size: f.original_size,
compressed_size: f.compressed_size,
saved_percent: f.saved_percent,
status: f.status.clone(),
download_url: if f.status == "completed" {
Some(format!("/downloads/{}", f.id))
} else {
None
},
})
.collect::<Vec<_>>();
let processed = task.completed_files + task.failed_files;
let progress = if task.total_files <= 0 {
0
} else {
((processed as f64) * 100.0 / (task.total_files as f64))
.round()
.clamp(0.0, 100.0) as i32
};
Ok((
jar,
Json(Envelope {
success: true,
data: TaskView {
task_id,
status: task.status,
progress,
total_files: task.total_files,
completed_files: task.completed_files,
failed_files: task.failed_files,
files: file_views,
download_all_url: format!("/downloads/tasks/{task_id}"),
created_at: task.created_at,
completed_at: task.completed_at,
expires_at: task.expires_at,
},
}),
))
}
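/// POST /compress/tasks/{task_id}/cancel: idempotent; finished tasks
/// return success immediately, and only 'pending' or 'processing' rows
/// are flipped to 'cancelled'.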
async fn cancel_task(
State(state): State<AppState>,
jar: axum_extra::extract::cookie::CookieJar,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
headers: HeaderMap,
Path(task_id): Path<Uuid>,
) -> Result<(axum_extra::extract::cookie::CookieJar, Json<Envelope<serde_json::Value>>), AppError> {
let ip = context::client_ip(&headers, addr.ip());
let (jar, principal) = context::authenticate(&state, jar, &headers, ip).await?;
let task = sqlx::query_as::<_, TaskRow>(
"SELECT id, status::text AS status, total_files, completed_files, failed_files, created_at, completed_at, expires_at, user_id, session_id FROM tasks WHERE id = $1",
)
.bind(task_id)
.fetch_optional(&state.db)
.await
.map_err(|err| AppError::new(ErrorCode::Internal, "failed to query task").with_source(err))?
.ok_or_else(|| AppError::new(ErrorCode::NotFound, "task not found"))?;
authorize_task(&principal, task.user_id, task.session_id.as_deref().unwrap_or(""))?;
if matches!(task.status.as_str(), "completed" | "failed" | "cancelled") {
return Ok((
jar,
Json(Envelope {
success: true,
data: serde_json::json!({ "message": "任务已结束" }),
}),
));
}
let updated = sqlx::query(
"UPDATE tasks SET status = 'cancelled', completed_at = NOW() WHERE id = $1 AND status IN ('pending', 'processing')",
)
.bind(task_id)
.execute(&state.db)
.await
.map_err(|err| AppError::new(ErrorCode::Internal, "failed to cancel task").with_source(err))?;
if updated.rows_affected() == 0 {
return Err(AppError::new(ErrorCode::InvalidRequest, "task is not in a cancellable state"));
}
Ok((
jar,
Json(Envelope {
success: true,
data: serde_json::json!({ "message": "已取消" }),
}),
))
}
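/// DELETE /compress/tasks/{task_id}: rejects tasks that are still
/// processing, best-effort removes stored files, the task zip, and the
/// originals directory, then deletes the task row.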
async fn delete_task(
State(state): State<AppState>,
jar: axum_extra::extract::cookie::CookieJar,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
headers: HeaderMap,
Path(task_id): Path<Uuid>,
) -> Result<(axum_extra::extract::cookie::CookieJar, Json<Envelope<serde_json::Value>>), AppError> {
let ip = context::client_ip(&headers, addr.ip());
let (jar, principal) = context::authenticate(&state, jar, &headers, ip).await?;
let task = sqlx::query_as::<_, TaskRow>(
"SELECT id, status::text AS status, total_files, completed_files, failed_files, created_at, completed_at, expires_at, user_id, session_id FROM tasks WHERE id = $1",
)
.bind(task_id)
.fetch_optional(&state.db)
.await
.map_err(|err| AppError::new(ErrorCode::Internal, "failed to query task").with_source(err))?
.ok_or_else(|| AppError::new(ErrorCode::NotFound, "task not found"))?;
authorize_task(&principal, task.user_id, task.session_id.as_deref().unwrap_or(""))?;
if task.status == "processing" {
return Err(AppError::new(
ErrorCode::InvalidRequest,
"任务处理中,请先取消后再删除",
));
}
let paths: Vec<Option<String>> =
sqlx::query_scalar("SELECT storage_path FROM task_files WHERE task_id = $1")
.bind(task_id)
.fetch_all(&state.db)
.await
.map_err(|err| AppError::new(ErrorCode::Internal, "failed to query files").with_source(err))?;
for p in paths.into_iter().flatten() {
let _ = tokio::fs::remove_file(p).await;
}
if state.config.storage_type.to_ascii_lowercase() == "local" {
let zip_path = format!("{}/zips/{task_id}.zip", state.config.storage_path);
let _ = tokio::fs::remove_file(zip_path).await;
let orig_dir = format!("{}/orig/{task_id}", state.config.storage_path);
let _ = tokio::fs::remove_dir_all(orig_dir).await;
}
let deleted = sqlx::query("DELETE FROM tasks WHERE id = $1")
.bind(task_id)
.execute(&state.db)
.await
.map_err(|err| AppError::new(ErrorCode::Internal, "failed to delete task").with_source(err))?;
if deleted.rows_affected() == 0 {
return Err(AppError::new(ErrorCode::NotFound, "task not found"));
}
Ok((
jar,
Json(Envelope {
success: true,
data: serde_json::json!({ "message": "已删除" }),
}),
))
}
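/// Ownership check: user-owned tasks are visible to the owning user or an
/// API key belonging to that user; anonymous tasks only to the creating
/// session.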
fn authorize_task(
principal: &context::Principal,
user_id: Option<Uuid>,
session_id: &str,
) -> Result<(), AppError> {
if let Some(owner) = user_id {
match principal {
context::Principal::User { user_id: me, .. } if *me == owner => Ok(()),
context::Principal::ApiKey { user_id: me, .. } if *me == owner => Ok(()),
_ => Err(AppError::new(ErrorCode::Forbidden, "no permission to access this task")),
}
} else {
match principal {
context::Principal::Anonymous { session_id: sid } if sid == session_id => Ok(()),
_ => Err(AppError::new(ErrorCode::Forbidden, "no permission to access this task")),
}
}
}
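/// Best-effort removal of uploads already written to disk when request
/// handling fails partway through.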
async fn cleanup_file_paths(files: &[BatchFileInput]) {
for f in files {
let _ = tokio::fs::remove_file(&f.storage_path).await;
}
}