988 lines
33 KiB
Rust
988 lines
33 KiB
Rust
use crate::api::context;
|
||
use crate::api::envelope::Envelope;
|
||
use crate::error::{AppError, ErrorCode};
|
||
use crate::services::billing;
|
||
use crate::services::billing::{BillingContext, Plan};
|
||
use crate::services::compress;
|
||
use crate::services::compress::{CompressionLevel, ImageFmt};
|
||
use crate::services::idempotency;
|
||
use crate::state::AppState;
|
||
|
||
use axum::extract::{ConnectInfo, Multipart, Path, State};
|
||
use axum::http::HeaderMap;
|
||
use axum::routing::{delete, get, post};
|
||
use axum::{Json, Router};
|
||
use chrono::{DateTime, Duration, Utc};
|
||
use serde::{Deserialize, Serialize};
|
||
use sha2::{Digest, Sha256};
|
||
use sqlx::FromRow;
|
||
use std::net::{IpAddr, SocketAddr};
|
||
use tokio::io::AsyncWriteExt;
|
||
use uuid::Uuid;
|
||
|
||
pub fn router() -> Router<AppState> {
|
||
Router::new()
|
||
.route("/compress/batch", post(create_batch_task))
|
||
.route("/compress/tasks/{task_id}", get(get_task))
|
||
.route("/compress/tasks/{task_id}/cancel", post(cancel_task))
|
||
.route("/compress/tasks/{task_id}", delete(delete_task))
|
||
}
|
||
|
||
#[derive(Debug, Serialize, Deserialize)]
|
||
struct BatchCreateResponse {
|
||
task_id: Uuid,
|
||
total_files: i32,
|
||
status: String,
|
||
status_url: String,
|
||
}
|
||
|
||
#[derive(Debug)]
|
||
struct BatchFileInput {
|
||
file_id: Uuid,
|
||
original_name: String,
|
||
original_format: ImageFmt,
|
||
output_format: ImageFmt,
|
||
original_size: u64,
|
||
storage_path: String,
|
||
}
|
||
|
||
#[derive(Debug)]
|
||
struct BatchOptions {
|
||
level: CompressionLevel,
|
||
compression_rate: Option<u8>,
|
||
output_format: Option<ImageFmt>,
|
||
max_width: Option<u32>,
|
||
max_height: Option<u32>,
|
||
preserve_metadata: bool,
|
||
}
|
||
|
||
async fn create_batch_task(
|
||
State(state): State<AppState>,
|
||
jar: axum_extra::extract::cookie::CookieJar,
|
||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||
headers: HeaderMap,
|
||
mut multipart: Multipart,
|
||
) -> Result<(axum_extra::extract::cookie::CookieJar, Json<Envelope<BatchCreateResponse>>), AppError> {
|
||
let ip = context::client_ip(&headers, addr.ip());
|
||
let (jar, principal) = context::authenticate(&state, jar, &headers, ip).await?;
|
||
|
||
if state.config.storage_type.to_ascii_lowercase() != "local" {
|
||
return Err(AppError::new(
|
||
ErrorCode::StorageUnavailable,
|
||
"当前仅支持本地存储(STORAGE_TYPE=local)",
|
||
));
|
||
}
|
||
|
||
let idempotency_key = headers
|
||
.get("idempotency-key")
|
||
.and_then(|v| v.to_str().ok())
|
||
.map(str::trim)
|
||
.filter(|v| !v.is_empty())
|
||
.map(str::to_string);
|
||
let idempotency_scope = match &principal {
|
||
context::Principal::User { user_id, .. } => Some(idempotency::Scope::User(*user_id)),
|
||
context::Principal::ApiKey { api_key_id, .. } => Some(idempotency::Scope::ApiKey(*api_key_id)),
|
||
_ => None,
|
||
};
|
||
|
||
let task_id = Uuid::new_v4();
|
||
let (files, opts, request_hash) = parse_batch_request(&state, task_id, &mut multipart).await?;
|
||
|
||
if files.is_empty() {
|
||
cleanup_file_paths(&files).await;
|
||
return Err(AppError::new(ErrorCode::InvalidRequest, "缺少 files[]"));
|
||
}
|
||
|
||
let mut idem_acquired = false;
|
||
if let (Some(scope), Some(idem_key)) = (idempotency_scope, idempotency_key.as_deref()) {
|
||
match idempotency::begin(
|
||
&state,
|
||
scope,
|
||
idem_key,
|
||
&request_hash,
|
||
state.config.idempotency_ttl_hours as i64,
|
||
)
|
||
.await?
|
||
{
|
||
idempotency::BeginResult::Replay { response_body, .. } => {
|
||
cleanup_file_paths(&files).await;
|
||
let resp: BatchCreateResponse =
|
||
serde_json::from_value(response_body).map_err(|err| {
|
||
AppError::new(ErrorCode::Internal, "幂等结果解析失败").with_source(err)
|
||
})?;
|
||
return Ok((jar, Json(Envelope { success: true, data: resp })));
|
||
}
|
||
idempotency::BeginResult::InProgress => {
|
||
cleanup_file_paths(&files).await;
|
||
if let Some((_status, body)) = idempotency::wait_for_replay(
|
||
&state,
|
||
scope,
|
||
idem_key,
|
||
&request_hash,
|
||
10_000,
|
||
)
|
||
.await?
|
||
{
|
||
let resp: BatchCreateResponse =
|
||
serde_json::from_value(body).map_err(|err| {
|
||
AppError::new(ErrorCode::Internal, "幂等结果解析失败").with_source(err)
|
||
})?;
|
||
return Ok((jar, Json(Envelope { success: true, data: resp })));
|
||
}
|
||
return Err(AppError::new(
|
||
ErrorCode::InvalidRequest,
|
||
"请求正在处理中,请稍后重试",
|
||
));
|
||
}
|
||
idempotency::BeginResult::Acquired { .. } => {
|
||
idem_acquired = true;
|
||
}
|
||
}
|
||
}
|
||
|
||
let create_result: Result<BatchCreateResponse, AppError> = (async {
|
||
let (retention, task_owner, source) = match &principal {
|
||
context::Principal::Anonymous { session_id } => {
|
||
enforce_batch_limits_anonymous(&state, &files)?;
|
||
let remaining = anonymous_remaining_units(&state, session_id, ip).await?;
|
||
if remaining < files.len() as i64 {
|
||
return Err(AppError::new(
|
||
ErrorCode::QuotaExceeded,
|
||
"匿名试用次数已用完(每日 10 次)",
|
||
));
|
||
}
|
||
Ok((
|
||
Duration::hours(state.config.anon_retention_hours as i64),
|
||
TaskOwner::Anonymous {
|
||
session_id: session_id.clone(),
|
||
},
|
||
"web",
|
||
))
|
||
}
|
||
context::Principal::User {
|
||
user_id,
|
||
email_verified,
|
||
..
|
||
} => {
|
||
if !email_verified {
|
||
return Err(AppError::new(ErrorCode::EmailNotVerified, "请先验证邮箱"));
|
||
}
|
||
let billing = billing::get_user_billing(&state, *user_id).await?;
|
||
enforce_batch_limits_plan(&billing.plan, &files)?;
|
||
ensure_quota_available(&state, &billing, files.len() as i32).await?;
|
||
Ok((
|
||
Duration::days(billing.plan.retention_days as i64),
|
||
TaskOwner::User { user_id: *user_id },
|
||
"web",
|
||
))
|
||
}
|
||
context::Principal::ApiKey {
|
||
user_id,
|
||
api_key_id,
|
||
email_verified,
|
||
..
|
||
} => {
|
||
if !email_verified {
|
||
return Err(AppError::new(ErrorCode::EmailNotVerified, "请先验证邮箱"));
|
||
}
|
||
let billing = billing::get_user_billing(&state, *user_id).await?;
|
||
if !billing.plan.feature_api_enabled {
|
||
return Err(AppError::new(ErrorCode::Forbidden, "当前套餐未开通 API"));
|
||
}
|
||
enforce_batch_limits_plan(&billing.plan, &files)?;
|
||
ensure_quota_available(&state, &billing, files.len() as i32).await?;
|
||
Ok((
|
||
Duration::days(billing.plan.retention_days as i64),
|
||
TaskOwner::ApiKey {
|
||
user_id: *user_id,
|
||
api_key_id: *api_key_id,
|
||
},
|
||
"api",
|
||
))
|
||
}
|
||
}?;
|
||
|
||
tokio::fs::create_dir_all(&state.config.storage_path)
|
||
.await
|
||
.map_err(|err| {
|
||
AppError::new(ErrorCode::StorageUnavailable, "创建存储目录失败").with_source(err)
|
||
})?;
|
||
|
||
let expires_at = Utc::now() + retention;
|
||
let (user_id, session_id, api_key_id) = match &task_owner {
|
||
TaskOwner::Anonymous { session_id } => (None, Some(session_id.clone()), None),
|
||
TaskOwner::User { user_id } => (Some(*user_id), None, None),
|
||
TaskOwner::ApiKey { user_id, api_key_id } => (Some(*user_id), None, Some(*api_key_id)),
|
||
};
|
||
|
||
let total_original_size: i64 = files.iter().map(|f| f.original_size as i64).sum();
|
||
|
||
let mut tx = state
|
||
.db
|
||
.begin()
|
||
.await
|
||
.map_err(|err| AppError::new(ErrorCode::Internal, "开启事务失败").with_source(err))?;
|
||
|
||
sqlx::query(
|
||
r#"
|
||
INSERT INTO tasks (
|
||
id, user_id, session_id, api_key_id, client_ip, source, status,
|
||
compression_level, output_format, max_width, max_height, preserve_metadata,
|
||
compression_rate,
|
||
total_files, completed_files, failed_files,
|
||
total_original_size, total_compressed_size,
|
||
expires_at
|
||
) VALUES (
|
||
$1, $2, $3, $4, $5::inet, $6::task_source, 'pending',
|
||
$7::compression_level, $8, $9, $10, $11, $12,
|
||
$13, 0, 0,
|
||
$14, 0,
|
||
$15
|
||
)
|
||
"#,
|
||
)
|
||
.bind(task_id)
|
||
.bind(user_id)
|
||
.bind(session_id)
|
||
.bind(api_key_id)
|
||
.bind(ip.to_string())
|
||
.bind(source)
|
||
.bind(opts.level.as_str())
|
||
.bind(opts.output_format.map(|f| f.as_str()))
|
||
.bind(opts.max_width.map(|v| v as i32))
|
||
.bind(opts.max_height.map(|v| v as i32))
|
||
.bind(false)
|
||
.bind(opts.compression_rate.map(|v| v as i16))
|
||
.bind(files.len() as i32)
|
||
.bind(total_original_size)
|
||
.bind(expires_at)
|
||
.execute(&mut *tx)
|
||
.await
|
||
.map_err(|err| AppError::new(ErrorCode::Internal, "创建任务失败").with_source(err))?;
|
||
|
||
for file in &files {
|
||
sqlx::query(
|
||
r#"
|
||
INSERT INTO task_files (
|
||
id, task_id,
|
||
original_name, original_format, output_format,
|
||
original_size,
|
||
storage_path, status
|
||
) VALUES (
|
||
$1, $2,
|
||
$3, $4, $5,
|
||
$6,
|
||
$7, 'pending'
|
||
)
|
||
"#,
|
||
)
|
||
.bind(file.file_id)
|
||
.bind(task_id)
|
||
.bind(&file.original_name)
|
||
.bind(file.original_format.as_str())
|
||
.bind(file.output_format.as_str())
|
||
.bind(file.original_size as i64)
|
||
.bind(&file.storage_path)
|
||
.execute(&mut *tx)
|
||
.await
|
||
.map_err(|err| AppError::new(ErrorCode::Internal, "创建文件记录失败").with_source(err))?;
|
||
}
|
||
|
||
tx.commit()
|
||
.await
|
||
.map_err(|err| AppError::new(ErrorCode::Internal, "提交事务失败").with_source(err))?;
|
||
|
||
if let Err(err) = enqueue_task(&state, task_id).await {
|
||
let _ = sqlx::query("UPDATE tasks SET status = 'failed', error_message = $2 WHERE id = $1")
|
||
.bind(task_id)
|
||
.bind("队列提交失败")
|
||
.execute(&state.db)
|
||
.await;
|
||
return Err(err);
|
||
}
|
||
|
||
Ok(BatchCreateResponse {
|
||
task_id,
|
||
total_files: files.len() as i32,
|
||
status: "pending".to_string(),
|
||
status_url: format!("/api/v1/compress/tasks/{task_id}"),
|
||
})
|
||
})
|
||
.await;
|
||
|
||
match create_result {
|
||
Ok(resp) => {
|
||
if let (Some(scope), Some(idem_key)) = (idempotency_scope, idempotency_key.as_deref()) {
|
||
if idem_acquired {
|
||
let _ = idempotency::complete(
|
||
&state,
|
||
scope,
|
||
idem_key,
|
||
&request_hash,
|
||
200,
|
||
serde_json::to_value(&resp).unwrap_or(serde_json::Value::Null),
|
||
)
|
||
.await;
|
||
}
|
||
}
|
||
Ok((jar, Json(Envelope { success: true, data: resp })))
|
||
}
|
||
Err(err) => {
|
||
if let (Some(scope), Some(idem_key)) = (idempotency_scope, idempotency_key.as_deref()) {
|
||
if idem_acquired {
|
||
let _ = idempotency::abort(&state, scope, idem_key, &request_hash).await;
|
||
}
|
||
}
|
||
cleanup_file_paths(&files).await;
|
||
Err(err)
|
||
}
|
||
}
|
||
}
|
||
|
||
#[derive(Debug)]
|
||
enum TaskOwner {
|
||
Anonymous { session_id: String },
|
||
User { user_id: Uuid },
|
||
ApiKey { user_id: Uuid, api_key_id: Uuid },
|
||
}
|
||
|
||
async fn enqueue_task(state: &AppState, task_id: Uuid) -> Result<(), AppError> {
|
||
let mut conn = state.redis.clone();
|
||
let now = Utc::now().to_rfc3339();
|
||
redis::cmd("XADD")
|
||
.arg("stream:compress_jobs")
|
||
.arg("*")
|
||
.arg("task_id")
|
||
.arg(task_id.to_string())
|
||
.arg("created_at")
|
||
.arg(now)
|
||
.query_async::<_, redis::Value>(&mut conn)
|
||
.await
|
||
.map_err(|err| AppError::new(ErrorCode::Internal, "写入队列失败").with_source(err))?;
|
||
Ok(())
|
||
}
|
||
|
||
async fn parse_batch_request(
|
||
state: &AppState,
|
||
task_id: Uuid,
|
||
multipart: &mut Multipart,
|
||
) -> Result<(Vec<BatchFileInput>, BatchOptions, String), AppError> {
|
||
let mut files: Vec<BatchFileInput> = Vec::new();
|
||
let mut file_digests: Vec<String> = Vec::new();
|
||
let mut opts = BatchOptions {
|
||
level: CompressionLevel::Medium,
|
||
compression_rate: None,
|
||
output_format: None,
|
||
max_width: None,
|
||
max_height: None,
|
||
preserve_metadata: false,
|
||
};
|
||
|
||
let base_dir = format!("{}/orig/{task_id}", state.config.storage_path);
|
||
tokio::fs::create_dir_all(&base_dir)
|
||
.await
|
||
.map_err(|err| AppError::new(ErrorCode::StorageUnavailable, "创建存储目录失败").with_source(err))?;
|
||
|
||
loop {
|
||
let next = multipart.next_field().await.map_err(|err| {
|
||
AppError::new(ErrorCode::InvalidRequest, "读取上传内容失败").with_source(err)
|
||
});
|
||
|
||
let field = match next {
|
||
Ok(v) => v,
|
||
Err(err) => {
|
||
cleanup_file_paths(&files).await;
|
||
return Err(err);
|
||
}
|
||
};
|
||
|
||
let Some(field) = field else { break };
|
||
|
||
let name = field.name().unwrap_or("").to_string();
|
||
if name == "files" || name == "files[]" {
|
||
let file_id = Uuid::new_v4();
|
||
let original_name = field.file_name().unwrap_or("upload").to_string();
|
||
let bytes = match field.bytes().await {
|
||
Ok(v) => v,
|
||
Err(err) => {
|
||
cleanup_file_paths(&files).await;
|
||
return Err(
|
||
AppError::new(ErrorCode::InvalidRequest, "读取文件失败").with_source(err)
|
||
);
|
||
}
|
||
};
|
||
|
||
let original_size = bytes.len() as u64;
|
||
let file_digest = {
|
||
let mut h = Sha256::new();
|
||
h.update(&bytes);
|
||
h.update(original_name.as_bytes());
|
||
hex::encode(h.finalize())
|
||
};
|
||
let original_format = match compress::detect_format(&bytes) {
|
||
Ok(v) => v,
|
||
Err(err) => {
|
||
cleanup_file_paths(&files).await;
|
||
return Err(err);
|
||
}
|
||
};
|
||
|
||
let output_format = opts.output_format.unwrap_or(original_format);
|
||
let path = format!("{base_dir}/{file_id}.{}", original_format.extension());
|
||
|
||
let mut f = match tokio::fs::File::create(&path).await {
|
||
Ok(v) => v,
|
||
Err(err) => {
|
||
cleanup_file_paths(&files).await;
|
||
return Err(
|
||
AppError::new(ErrorCode::StorageUnavailable, "写入文件失败").with_source(err),
|
||
);
|
||
}
|
||
};
|
||
if let Err(err) = f.write_all(&bytes).await {
|
||
let _ = tokio::fs::remove_file(&path).await;
|
||
cleanup_file_paths(&files).await;
|
||
return Err(
|
||
AppError::new(ErrorCode::StorageUnavailable, "写入文件失败").with_source(err),
|
||
);
|
||
}
|
||
|
||
files.push(BatchFileInput {
|
||
file_id,
|
||
original_name,
|
||
original_format,
|
||
output_format,
|
||
original_size,
|
||
storage_path: path,
|
||
});
|
||
file_digests.push(file_digest);
|
||
continue;
|
||
}
|
||
|
||
let text = match field.text().await {
|
||
Ok(v) => v,
|
||
Err(err) => {
|
||
cleanup_file_paths(&files).await;
|
||
return Err(
|
||
AppError::new(ErrorCode::InvalidRequest, "读取字段失败").with_source(err),
|
||
);
|
||
}
|
||
};
|
||
match name.as_str() {
|
||
"level" => {
|
||
opts.level = match compress::parse_level(&text) {
|
||
Ok(v) => v,
|
||
Err(err) => {
|
||
cleanup_file_paths(&files).await;
|
||
return Err(err);
|
||
}
|
||
}
|
||
}
|
||
"output_format" => {
|
||
let v = text.trim();
|
||
if !v.is_empty() {
|
||
opts.output_format = Some(match compress::parse_output_format(v) {
|
||
Ok(v) => v,
|
||
Err(err) => {
|
||
cleanup_file_paths(&files).await;
|
||
return Err(err);
|
||
}
|
||
});
|
||
}
|
||
}
|
||
"compression_rate" | "quality" => {
|
||
let v = text.trim();
|
||
if !v.is_empty() {
|
||
opts.compression_rate = Some(match compress::parse_compression_rate(v) {
|
||
Ok(v) => v,
|
||
Err(err) => {
|
||
cleanup_file_paths(&files).await;
|
||
return Err(err);
|
||
}
|
||
});
|
||
}
|
||
}
|
||
"max_width" => {
|
||
let v = text.trim();
|
||
if !v.is_empty() {
|
||
opts.max_width = Some(match v.parse::<u32>() {
|
||
Ok(n) => n,
|
||
Err(_) => {
|
||
cleanup_file_paths(&files).await;
|
||
return Err(AppError::new(ErrorCode::InvalidRequest, "max_width 格式错误"));
|
||
}
|
||
});
|
||
}
|
||
}
|
||
"max_height" => {
|
||
let v = text.trim();
|
||
if !v.is_empty() {
|
||
opts.max_height = Some(match v.parse::<u32>() {
|
||
Ok(n) => n,
|
||
Err(_) => {
|
||
cleanup_file_paths(&files).await;
|
||
return Err(AppError::new(ErrorCode::InvalidRequest, "max_height 格式错误"));
|
||
}
|
||
});
|
||
}
|
||
}
|
||
"preserve_metadata" => {
|
||
opts.preserve_metadata = matches!(
|
||
text.trim().to_ascii_lowercase().as_str(),
|
||
"1" | "true" | "yes" | "y" | "on"
|
||
);
|
||
}
|
||
_ => {}
|
||
}
|
||
}
|
||
|
||
if let Some(rate) = opts.compression_rate {
|
||
opts.level = compress::rate_to_level(rate);
|
||
}
|
||
|
||
if opts.output_format.is_some() {
|
||
cleanup_file_paths(&files).await;
|
||
return Err(AppError::new(
|
||
ErrorCode::InvalidRequest,
|
||
"当前仅支持保持原图片格式",
|
||
));
|
||
}
|
||
|
||
let mw = opts.max_width.map(|v| v.to_string()).unwrap_or_default();
|
||
let mh = opts.max_height.map(|v| v.to_string()).unwrap_or_default();
|
||
let out_fmt = opts.output_format.map(|f| f.as_str()).unwrap_or("");
|
||
let rate_key = opts
|
||
.compression_rate
|
||
.map(|v| v.to_string())
|
||
.unwrap_or_default();
|
||
let preserve = if opts.preserve_metadata { "1" } else { "0" };
|
||
|
||
let mut h = Sha256::new();
|
||
h.update(b"compress_batch_v1");
|
||
h.update(opts.level.as_str().as_bytes());
|
||
h.update(out_fmt.as_bytes());
|
||
h.update(rate_key.as_bytes());
|
||
h.update(mw.as_bytes());
|
||
h.update(mh.as_bytes());
|
||
h.update(preserve.as_bytes());
|
||
for d in &file_digests {
|
||
h.update(d.as_bytes());
|
||
}
|
||
let request_hash = hex::encode(h.finalize());
|
||
|
||
Ok((files, opts, request_hash))
|
||
}
|
||
|
||
fn enforce_batch_limits_anonymous(state: &AppState, files: &[BatchFileInput]) -> Result<(), AppError> {
|
||
let max_files = state.config.anon_max_files_per_batch as usize;
|
||
if files.len() > max_files {
|
||
return Err(AppError::new(
|
||
ErrorCode::InvalidRequest,
|
||
format!("匿名试用单次最多 {} 个文件", max_files),
|
||
));
|
||
}
|
||
|
||
let max_bytes = state.config.anon_max_file_size_mb * 1024 * 1024;
|
||
for f in files {
|
||
if f.original_size > max_bytes {
|
||
return Err(AppError::new(
|
||
ErrorCode::FileTooLarge,
|
||
format!("匿名试用单文件最大 {} MB", state.config.anon_max_file_size_mb),
|
||
));
|
||
}
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
|
||
fn enforce_batch_limits_plan(plan: &Plan, files: &[BatchFileInput]) -> Result<(), AppError> {
|
||
let max_files = plan.max_files_per_batch as usize;
|
||
if files.len() > max_files {
|
||
return Err(AppError::new(
|
||
ErrorCode::InvalidRequest,
|
||
format!("当前套餐单次最多 {} 个文件", plan.max_files_per_batch),
|
||
));
|
||
}
|
||
|
||
let max_bytes = (plan.max_file_size_mb as u64) * 1024 * 1024;
|
||
for f in files {
|
||
if f.original_size > max_bytes {
|
||
return Err(AppError::new(
|
||
ErrorCode::FileTooLarge,
|
||
format!("当前套餐单文件最大 {} MB", plan.max_file_size_mb),
|
||
));
|
||
}
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
|
||
async fn ensure_quota_available(
|
||
state: &AppState,
|
||
ctx: &BillingContext,
|
||
needed_units: i32,
|
||
) -> Result<(), AppError> {
|
||
if needed_units <= 0 {
|
||
return Ok(());
|
||
}
|
||
|
||
#[derive(Debug, FromRow)]
|
||
struct UsageRow {
|
||
used_units: i32,
|
||
bonus_units: i32,
|
||
}
|
||
|
||
let usage = sqlx::query_as::<_, UsageRow>(
|
||
r#"
|
||
SELECT used_units, bonus_units
|
||
FROM usage_periods
|
||
WHERE user_id = $1 AND period_start = $2 AND period_end = $3
|
||
"#,
|
||
)
|
||
.bind(ctx.user_id)
|
||
.bind(ctx.period_start)
|
||
.bind(ctx.period_end)
|
||
.fetch_optional(&state.db)
|
||
.await
|
||
.map_err(|err| AppError::new(ErrorCode::Internal, "查询用量失败").with_source(err))?
|
||
.unwrap_or(UsageRow {
|
||
used_units: 0,
|
||
bonus_units: 0,
|
||
});
|
||
|
||
let total_units = ctx.plan.included_units_per_period + usage.bonus_units;
|
||
let remaining = total_units - usage.used_units;
|
||
if remaining < needed_units {
|
||
return Err(AppError::new(ErrorCode::QuotaExceeded, "当期配额已用完"));
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
|
||
async fn anonymous_remaining_units(
|
||
state: &AppState,
|
||
session_id: &str,
|
||
ip: IpAddr,
|
||
) -> Result<i64, AppError> {
|
||
let date = utc8_date();
|
||
let session_key = format!("anon_quota:{session_id}:{date}");
|
||
let ip_key = format!("anon_quota_ip:{ip}:{date}");
|
||
|
||
let mut conn = state.redis.clone();
|
||
let v1: Option<i64> = redis::cmd("GET")
|
||
.arg(session_key)
|
||
.query_async(&mut conn)
|
||
.await
|
||
.unwrap_or(None);
|
||
let v2: Option<i64> = redis::cmd("GET")
|
||
.arg(ip_key)
|
||
.query_async(&mut conn)
|
||
.await
|
||
.unwrap_or(None);
|
||
|
||
let limit = state.config.anon_daily_units as i64;
|
||
Ok(std::cmp::min(limit - v1.unwrap_or(0), limit - v2.unwrap_or(0)))
|
||
}
|
||
|
||
fn utc8_date() -> String {
|
||
let now = Utc::now() + Duration::hours(8);
|
||
now.format("%Y-%m-%d").to_string()
|
||
}
|
||
|
||
#[derive(Debug, FromRow)]
|
||
struct TaskRow {
|
||
id: Uuid,
|
||
status: String,
|
||
total_files: i32,
|
||
completed_files: i32,
|
||
failed_files: i32,
|
||
created_at: DateTime<Utc>,
|
||
completed_at: Option<DateTime<Utc>>,
|
||
expires_at: DateTime<Utc>,
|
||
user_id: Option<Uuid>,
|
||
session_id: Option<String>,
|
||
}
|
||
|
||
#[derive(Debug, FromRow)]
|
||
struct TaskFileRow {
|
||
id: Uuid,
|
||
original_name: String,
|
||
original_size: i64,
|
||
compressed_size: Option<i64>,
|
||
saved_percent: Option<f64>,
|
||
status: String,
|
||
}
|
||
|
||
#[derive(Debug, Serialize)]
|
||
struct TaskFileView {
|
||
file_id: Uuid,
|
||
original_name: String,
|
||
original_size: i64,
|
||
compressed_size: Option<i64>,
|
||
saved_percent: Option<f64>,
|
||
status: String,
|
||
download_url: Option<String>,
|
||
}
|
||
|
||
#[derive(Debug, Serialize)]
|
||
struct TaskView {
|
||
task_id: Uuid,
|
||
status: String,
|
||
progress: i32,
|
||
total_files: i32,
|
||
completed_files: i32,
|
||
failed_files: i32,
|
||
files: Vec<TaskFileView>,
|
||
download_all_url: String,
|
||
created_at: DateTime<Utc>,
|
||
completed_at: Option<DateTime<Utc>>,
|
||
expires_at: DateTime<Utc>,
|
||
}
|
||
|
||
async fn get_task(
|
||
State(state): State<AppState>,
|
||
jar: axum_extra::extract::cookie::CookieJar,
|
||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||
headers: HeaderMap,
|
||
Path(task_id): Path<Uuid>,
|
||
) -> Result<(axum_extra::extract::cookie::CookieJar, Json<Envelope<TaskView>>), AppError> {
|
||
let ip = context::client_ip(&headers, addr.ip());
|
||
let (jar, principal) = context::authenticate(&state, jar, &headers, ip).await?;
|
||
|
||
let task = sqlx::query_as::<_, TaskRow>(
|
||
r#"
|
||
SELECT
|
||
id,
|
||
status::text AS status,
|
||
total_files,
|
||
completed_files,
|
||
failed_files,
|
||
created_at,
|
||
completed_at,
|
||
expires_at,
|
||
user_id,
|
||
session_id
|
||
FROM tasks
|
||
WHERE id = $1
|
||
"#,
|
||
)
|
||
.bind(task_id)
|
||
.fetch_optional(&state.db)
|
||
.await
|
||
.map_err(|err| AppError::new(ErrorCode::Internal, "查询任务失败").with_source(err))?
|
||
.ok_or_else(|| AppError::new(ErrorCode::NotFound, "任务不存在"))?;
|
||
|
||
if task.expires_at <= Utc::now() {
|
||
return Err(AppError::new(ErrorCode::NotFound, "任务已过期或不存在"));
|
||
}
|
||
|
||
authorize_task(&principal, task.user_id, task.session_id.as_deref().unwrap_or(""))?;
|
||
|
||
let files = sqlx::query_as::<_, TaskFileRow>(
|
||
r#"
|
||
SELECT
|
||
id,
|
||
original_name,
|
||
original_size,
|
||
compressed_size,
|
||
saved_percent::float8 AS saved_percent,
|
||
status::text AS status
|
||
FROM task_files
|
||
WHERE task_id = $1
|
||
ORDER BY created_at ASC
|
||
"#,
|
||
)
|
||
.bind(task_id)
|
||
.fetch_all(&state.db)
|
||
.await
|
||
.map_err(|err| AppError::new(ErrorCode::Internal, "查询任务文件失败").with_source(err))?;
|
||
|
||
let file_views = files
|
||
.into_iter()
|
||
.map(|f| TaskFileView {
|
||
file_id: f.id,
|
||
original_name: f.original_name,
|
||
original_size: f.original_size,
|
||
compressed_size: f.compressed_size,
|
||
saved_percent: f.saved_percent,
|
||
status: f.status.clone(),
|
||
download_url: if f.status == "completed" {
|
||
Some(format!("/downloads/{}", f.id))
|
||
} else {
|
||
None
|
||
},
|
||
})
|
||
.collect::<Vec<_>>();
|
||
|
||
let processed = task.completed_files + task.failed_files;
|
||
let progress = if task.total_files <= 0 {
|
||
0
|
||
} else {
|
||
((processed as f64) * 100.0 / (task.total_files as f64))
|
||
.round()
|
||
.clamp(0.0, 100.0) as i32
|
||
};
|
||
|
||
Ok((
|
||
jar,
|
||
Json(Envelope {
|
||
success: true,
|
||
data: TaskView {
|
||
task_id,
|
||
status: task.status,
|
||
progress,
|
||
total_files: task.total_files,
|
||
completed_files: task.completed_files,
|
||
failed_files: task.failed_files,
|
||
files: file_views,
|
||
download_all_url: format!("/downloads/tasks/{task_id}"),
|
||
created_at: task.created_at,
|
||
completed_at: task.completed_at,
|
||
expires_at: task.expires_at,
|
||
},
|
||
}),
|
||
))
|
||
}
|
||
|
||
async fn cancel_task(
|
||
State(state): State<AppState>,
|
||
jar: axum_extra::extract::cookie::CookieJar,
|
||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||
headers: HeaderMap,
|
||
Path(task_id): Path<Uuid>,
|
||
) -> Result<(axum_extra::extract::cookie::CookieJar, Json<Envelope<serde_json::Value>>), AppError> {
|
||
let ip = context::client_ip(&headers, addr.ip());
|
||
let (jar, principal) = context::authenticate(&state, jar, &headers, ip).await?;
|
||
|
||
let task = sqlx::query_as::<_, TaskRow>(
|
||
"SELECT id, status::text AS status, total_files, completed_files, failed_files, created_at, completed_at, expires_at, user_id, session_id FROM tasks WHERE id = $1",
|
||
)
|
||
.bind(task_id)
|
||
.fetch_optional(&state.db)
|
||
.await
|
||
.map_err(|err| AppError::new(ErrorCode::Internal, "查询任务失败").with_source(err))?
|
||
.ok_or_else(|| AppError::new(ErrorCode::NotFound, "任务不存在"))?;
|
||
|
||
authorize_task(&principal, task.user_id, task.session_id.as_deref().unwrap_or(""))?;
|
||
|
||
if matches!(task.status.as_str(), "completed" | "failed" | "cancelled") {
|
||
return Ok((
|
||
jar,
|
||
Json(Envelope {
|
||
success: true,
|
||
data: serde_json::json!({ "message": "任务已结束" }),
|
||
}),
|
||
));
|
||
}
|
||
|
||
let updated = sqlx::query(
|
||
"UPDATE tasks SET status = 'cancelled', completed_at = NOW() WHERE id = $1 AND status IN ('pending', 'processing')",
|
||
)
|
||
.bind(task_id)
|
||
.execute(&state.db)
|
||
.await
|
||
.map_err(|err| AppError::new(ErrorCode::Internal, "取消任务失败").with_source(err))?;
|
||
|
||
if updated.rows_affected() == 0 {
|
||
return Err(AppError::new(ErrorCode::InvalidRequest, "任务状态不可取消"));
|
||
}
|
||
|
||
Ok((
|
||
jar,
|
||
Json(Envelope {
|
||
success: true,
|
||
data: serde_json::json!({ "message": "已取消" }),
|
||
}),
|
||
))
|
||
}
|
||
|
||
async fn delete_task(
|
||
State(state): State<AppState>,
|
||
jar: axum_extra::extract::cookie::CookieJar,
|
||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||
headers: HeaderMap,
|
||
Path(task_id): Path<Uuid>,
|
||
) -> Result<(axum_extra::extract::cookie::CookieJar, Json<Envelope<serde_json::Value>>), AppError> {
|
||
let ip = context::client_ip(&headers, addr.ip());
|
||
let (jar, principal) = context::authenticate(&state, jar, &headers, ip).await?;
|
||
|
||
let task = sqlx::query_as::<_, TaskRow>(
|
||
"SELECT id, status::text AS status, total_files, completed_files, failed_files, created_at, completed_at, expires_at, user_id, session_id FROM tasks WHERE id = $1",
|
||
)
|
||
.bind(task_id)
|
||
.fetch_optional(&state.db)
|
||
.await
|
||
.map_err(|err| AppError::new(ErrorCode::Internal, "查询任务失败").with_source(err))?
|
||
.ok_or_else(|| AppError::new(ErrorCode::NotFound, "任务不存在"))?;
|
||
|
||
authorize_task(&principal, task.user_id, task.session_id.as_deref().unwrap_or(""))?;
|
||
|
||
if task.status == "processing" {
|
||
return Err(AppError::new(
|
||
ErrorCode::InvalidRequest,
|
||
"任务处理中,请先取消后再删除",
|
||
));
|
||
}
|
||
|
||
let paths: Vec<Option<String>> =
|
||
sqlx::query_scalar("SELECT storage_path FROM task_files WHERE task_id = $1")
|
||
.bind(task_id)
|
||
.fetch_all(&state.db)
|
||
.await
|
||
.map_err(|err| AppError::new(ErrorCode::Internal, "查询文件失败").with_source(err))?;
|
||
|
||
for p in paths.into_iter().flatten() {
|
||
let _ = tokio::fs::remove_file(p).await;
|
||
}
|
||
|
||
if state.config.storage_type.to_ascii_lowercase() == "local" {
|
||
let zip_path = format!("{}/zips/{task_id}.zip", state.config.storage_path);
|
||
let _ = tokio::fs::remove_file(zip_path).await;
|
||
let orig_dir = format!("{}/orig/{task_id}", state.config.storage_path);
|
||
let _ = tokio::fs::remove_dir_all(orig_dir).await;
|
||
}
|
||
|
||
let deleted = sqlx::query("DELETE FROM tasks WHERE id = $1")
|
||
.bind(task_id)
|
||
.execute(&state.db)
|
||
.await
|
||
.map_err(|err| AppError::new(ErrorCode::Internal, "删除任务失败").with_source(err))?;
|
||
|
||
if deleted.rows_affected() == 0 {
|
||
return Err(AppError::new(ErrorCode::NotFound, "任务不存在"));
|
||
}
|
||
|
||
Ok((
|
||
jar,
|
||
Json(Envelope {
|
||
success: true,
|
||
data: serde_json::json!({ "message": "已删除" }),
|
||
}),
|
||
))
|
||
}
|
||
|
||
fn authorize_task(
|
||
principal: &context::Principal,
|
||
user_id: Option<Uuid>,
|
||
session_id: &str,
|
||
) -> Result<(), AppError> {
|
||
if let Some(owner) = user_id {
|
||
match principal {
|
||
context::Principal::User { user_id: me, .. } if *me == owner => Ok(()),
|
||
context::Principal::ApiKey { user_id: me, .. } if *me == owner => Ok(()),
|
||
_ => Err(AppError::new(ErrorCode::Forbidden, "无权限访问该任务")),
|
||
}
|
||
} else {
|
||
match principal {
|
||
context::Principal::Anonymous { session_id: sid } if sid == session_id => Ok(()),
|
||
_ => Err(AppError::new(ErrorCode::Forbidden, "无权限访问该任务")),
|
||
}
|
||
}
|
||
}
|
||
|
||
async fn cleanup_file_paths(files: &[BatchFileInput]) {
|
||
for f in files {
|
||
let _ = tokio::fs::remove_file(&f.storage_path).await;
|
||
}
|
||
}
|