Implement compression quota refunds and admin manual subscription
This commit is contained in:
987
src/api/tasks.rs
Normal file
987
src/api/tasks.rs
Normal file
@@ -0,0 +1,987 @@
|
||||
use crate::api::context;
|
||||
use crate::api::envelope::Envelope;
|
||||
use crate::error::{AppError, ErrorCode};
|
||||
use crate::services::billing;
|
||||
use crate::services::billing::{BillingContext, Plan};
|
||||
use crate::services::compress;
|
||||
use crate::services::compress::{CompressionLevel, ImageFmt};
|
||||
use crate::services::idempotency;
|
||||
use crate::state::AppState;
|
||||
|
||||
use axum::extract::{ConnectInfo, Multipart, Path, State};
|
||||
use axum::http::HeaderMap;
|
||||
use axum::routing::{delete, get, post};
|
||||
use axum::{Json, Router};
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sha2::{Digest, Sha256};
|
||||
use sqlx::FromRow;
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use uuid::Uuid;
|
||||
|
||||
pub fn router() -> Router<AppState> {
|
||||
Router::new()
|
||||
.route("/compress/batch", post(create_batch_task))
|
||||
.route("/compress/tasks/{task_id}", get(get_task))
|
||||
.route("/compress/tasks/{task_id}/cancel", post(cancel_task))
|
||||
.route("/compress/tasks/{task_id}", delete(delete_task))
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct BatchCreateResponse {
|
||||
task_id: Uuid,
|
||||
total_files: i32,
|
||||
status: String,
|
||||
status_url: String,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct BatchFileInput {
|
||||
file_id: Uuid,
|
||||
original_name: String,
|
||||
original_format: ImageFmt,
|
||||
output_format: ImageFmt,
|
||||
original_size: u64,
|
||||
storage_path: String,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct BatchOptions {
|
||||
level: CompressionLevel,
|
||||
compression_rate: Option<u8>,
|
||||
output_format: Option<ImageFmt>,
|
||||
max_width: Option<u32>,
|
||||
max_height: Option<u32>,
|
||||
preserve_metadata: bool,
|
||||
}
|
||||
|
||||
async fn create_batch_task(
|
||||
State(state): State<AppState>,
|
||||
jar: axum_extra::extract::cookie::CookieJar,
|
||||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||
headers: HeaderMap,
|
||||
mut multipart: Multipart,
|
||||
) -> Result<(axum_extra::extract::cookie::CookieJar, Json<Envelope<BatchCreateResponse>>), AppError> {
|
||||
let ip = context::client_ip(&headers, addr.ip());
|
||||
let (jar, principal) = context::authenticate(&state, jar, &headers, ip).await?;
|
||||
|
||||
if state.config.storage_type.to_ascii_lowercase() != "local" {
|
||||
return Err(AppError::new(
|
||||
ErrorCode::StorageUnavailable,
|
||||
"当前仅支持本地存储(STORAGE_TYPE=local)",
|
||||
));
|
||||
}
|
||||
|
||||
let idempotency_key = headers
|
||||
.get("idempotency-key")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(str::trim)
|
||||
.filter(|v| !v.is_empty())
|
||||
.map(str::to_string);
|
||||
let idempotency_scope = match &principal {
|
||||
context::Principal::User { user_id, .. } => Some(idempotency::Scope::User(*user_id)),
|
||||
context::Principal::ApiKey { api_key_id, .. } => Some(idempotency::Scope::ApiKey(*api_key_id)),
|
||||
_ => None,
|
||||
};
|
||||
|
||||
let task_id = Uuid::new_v4();
|
||||
let (files, opts, request_hash) = parse_batch_request(&state, task_id, &mut multipart).await?;
|
||||
|
||||
if files.is_empty() {
|
||||
cleanup_file_paths(&files).await;
|
||||
return Err(AppError::new(ErrorCode::InvalidRequest, "缺少 files[]"));
|
||||
}
|
||||
|
||||
let mut idem_acquired = false;
|
||||
if let (Some(scope), Some(idem_key)) = (idempotency_scope, idempotency_key.as_deref()) {
|
||||
match idempotency::begin(
|
||||
&state,
|
||||
scope,
|
||||
idem_key,
|
||||
&request_hash,
|
||||
state.config.idempotency_ttl_hours as i64,
|
||||
)
|
||||
.await?
|
||||
{
|
||||
idempotency::BeginResult::Replay { response_body, .. } => {
|
||||
cleanup_file_paths(&files).await;
|
||||
let resp: BatchCreateResponse =
|
||||
serde_json::from_value(response_body).map_err(|err| {
|
||||
AppError::new(ErrorCode::Internal, "幂等结果解析失败").with_source(err)
|
||||
})?;
|
||||
return Ok((jar, Json(Envelope { success: true, data: resp })));
|
||||
}
|
||||
idempotency::BeginResult::InProgress => {
|
||||
cleanup_file_paths(&files).await;
|
||||
if let Some((_status, body)) = idempotency::wait_for_replay(
|
||||
&state,
|
||||
scope,
|
||||
idem_key,
|
||||
&request_hash,
|
||||
10_000,
|
||||
)
|
||||
.await?
|
||||
{
|
||||
let resp: BatchCreateResponse =
|
||||
serde_json::from_value(body).map_err(|err| {
|
||||
AppError::new(ErrorCode::Internal, "幂等结果解析失败").with_source(err)
|
||||
})?;
|
||||
return Ok((jar, Json(Envelope { success: true, data: resp })));
|
||||
}
|
||||
return Err(AppError::new(
|
||||
ErrorCode::InvalidRequest,
|
||||
"请求正在处理中,请稍后重试",
|
||||
));
|
||||
}
|
||||
idempotency::BeginResult::Acquired { .. } => {
|
||||
idem_acquired = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let create_result: Result<BatchCreateResponse, AppError> = (async {
|
||||
let (retention, task_owner, source) = match &principal {
|
||||
context::Principal::Anonymous { session_id } => {
|
||||
enforce_batch_limits_anonymous(&state, &files)?;
|
||||
let remaining = anonymous_remaining_units(&state, session_id, ip).await?;
|
||||
if remaining < files.len() as i64 {
|
||||
return Err(AppError::new(
|
||||
ErrorCode::QuotaExceeded,
|
||||
"匿名试用次数已用完(每日 10 次)",
|
||||
));
|
||||
}
|
||||
Ok((
|
||||
Duration::hours(state.config.anon_retention_hours as i64),
|
||||
TaskOwner::Anonymous {
|
||||
session_id: session_id.clone(),
|
||||
},
|
||||
"web",
|
||||
))
|
||||
}
|
||||
context::Principal::User {
|
||||
user_id,
|
||||
email_verified,
|
||||
..
|
||||
} => {
|
||||
if !email_verified {
|
||||
return Err(AppError::new(ErrorCode::EmailNotVerified, "请先验证邮箱"));
|
||||
}
|
||||
let billing = billing::get_user_billing(&state, *user_id).await?;
|
||||
enforce_batch_limits_plan(&billing.plan, &files)?;
|
||||
ensure_quota_available(&state, &billing, files.len() as i32).await?;
|
||||
Ok((
|
||||
Duration::days(billing.plan.retention_days as i64),
|
||||
TaskOwner::User { user_id: *user_id },
|
||||
"web",
|
||||
))
|
||||
}
|
||||
context::Principal::ApiKey {
|
||||
user_id,
|
||||
api_key_id,
|
||||
email_verified,
|
||||
..
|
||||
} => {
|
||||
if !email_verified {
|
||||
return Err(AppError::new(ErrorCode::EmailNotVerified, "请先验证邮箱"));
|
||||
}
|
||||
let billing = billing::get_user_billing(&state, *user_id).await?;
|
||||
if !billing.plan.feature_api_enabled {
|
||||
return Err(AppError::new(ErrorCode::Forbidden, "当前套餐未开通 API"));
|
||||
}
|
||||
enforce_batch_limits_plan(&billing.plan, &files)?;
|
||||
ensure_quota_available(&state, &billing, files.len() as i32).await?;
|
||||
Ok((
|
||||
Duration::days(billing.plan.retention_days as i64),
|
||||
TaskOwner::ApiKey {
|
||||
user_id: *user_id,
|
||||
api_key_id: *api_key_id,
|
||||
},
|
||||
"api",
|
||||
))
|
||||
}
|
||||
}?;
|
||||
|
||||
tokio::fs::create_dir_all(&state.config.storage_path)
|
||||
.await
|
||||
.map_err(|err| {
|
||||
AppError::new(ErrorCode::StorageUnavailable, "创建存储目录失败").with_source(err)
|
||||
})?;
|
||||
|
||||
let expires_at = Utc::now() + retention;
|
||||
let (user_id, session_id, api_key_id) = match &task_owner {
|
||||
TaskOwner::Anonymous { session_id } => (None, Some(session_id.clone()), None),
|
||||
TaskOwner::User { user_id } => (Some(*user_id), None, None),
|
||||
TaskOwner::ApiKey { user_id, api_key_id } => (Some(*user_id), None, Some(*api_key_id)),
|
||||
};
|
||||
|
||||
let total_original_size: i64 = files.iter().map(|f| f.original_size as i64).sum();
|
||||
|
||||
let mut tx = state
|
||||
.db
|
||||
.begin()
|
||||
.await
|
||||
.map_err(|err| AppError::new(ErrorCode::Internal, "开启事务失败").with_source(err))?;
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO tasks (
|
||||
id, user_id, session_id, api_key_id, client_ip, source, status,
|
||||
compression_level, output_format, max_width, max_height, preserve_metadata,
|
||||
compression_rate,
|
||||
total_files, completed_files, failed_files,
|
||||
total_original_size, total_compressed_size,
|
||||
expires_at
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5::inet, $6::task_source, 'pending',
|
||||
$7::compression_level, $8, $9, $10, $11, $12,
|
||||
$13, 0, 0,
|
||||
$14, 0,
|
||||
$15
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.bind(task_id)
|
||||
.bind(user_id)
|
||||
.bind(session_id)
|
||||
.bind(api_key_id)
|
||||
.bind(ip.to_string())
|
||||
.bind(source)
|
||||
.bind(opts.level.as_str())
|
||||
.bind(opts.output_format.map(|f| f.as_str()))
|
||||
.bind(opts.max_width.map(|v| v as i32))
|
||||
.bind(opts.max_height.map(|v| v as i32))
|
||||
.bind(false)
|
||||
.bind(opts.compression_rate.map(|v| v as i16))
|
||||
.bind(files.len() as i32)
|
||||
.bind(total_original_size)
|
||||
.bind(expires_at)
|
||||
.execute(&mut *tx)
|
||||
.await
|
||||
.map_err(|err| AppError::new(ErrorCode::Internal, "创建任务失败").with_source(err))?;
|
||||
|
||||
for file in &files {
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO task_files (
|
||||
id, task_id,
|
||||
original_name, original_format, output_format,
|
||||
original_size,
|
||||
storage_path, status
|
||||
) VALUES (
|
||||
$1, $2,
|
||||
$3, $4, $5,
|
||||
$6,
|
||||
$7, 'pending'
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.bind(file.file_id)
|
||||
.bind(task_id)
|
||||
.bind(&file.original_name)
|
||||
.bind(file.original_format.as_str())
|
||||
.bind(file.output_format.as_str())
|
||||
.bind(file.original_size as i64)
|
||||
.bind(&file.storage_path)
|
||||
.execute(&mut *tx)
|
||||
.await
|
||||
.map_err(|err| AppError::new(ErrorCode::Internal, "创建文件记录失败").with_source(err))?;
|
||||
}
|
||||
|
||||
tx.commit()
|
||||
.await
|
||||
.map_err(|err| AppError::new(ErrorCode::Internal, "提交事务失败").with_source(err))?;
|
||||
|
||||
if let Err(err) = enqueue_task(&state, task_id).await {
|
||||
let _ = sqlx::query("UPDATE tasks SET status = 'failed', error_message = $2 WHERE id = $1")
|
||||
.bind(task_id)
|
||||
.bind("队列提交失败")
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
return Err(err);
|
||||
}
|
||||
|
||||
Ok(BatchCreateResponse {
|
||||
task_id,
|
||||
total_files: files.len() as i32,
|
||||
status: "pending".to_string(),
|
||||
status_url: format!("/api/v1/compress/tasks/{task_id}"),
|
||||
})
|
||||
})
|
||||
.await;
|
||||
|
||||
match create_result {
|
||||
Ok(resp) => {
|
||||
if let (Some(scope), Some(idem_key)) = (idempotency_scope, idempotency_key.as_deref()) {
|
||||
if idem_acquired {
|
||||
let _ = idempotency::complete(
|
||||
&state,
|
||||
scope,
|
||||
idem_key,
|
||||
&request_hash,
|
||||
200,
|
||||
serde_json::to_value(&resp).unwrap_or(serde_json::Value::Null),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
Ok((jar, Json(Envelope { success: true, data: resp })))
|
||||
}
|
||||
Err(err) => {
|
||||
if let (Some(scope), Some(idem_key)) = (idempotency_scope, idempotency_key.as_deref()) {
|
||||
if idem_acquired {
|
||||
let _ = idempotency::abort(&state, scope, idem_key, &request_hash).await;
|
||||
}
|
||||
}
|
||||
cleanup_file_paths(&files).await;
|
||||
Err(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum TaskOwner {
|
||||
Anonymous { session_id: String },
|
||||
User { user_id: Uuid },
|
||||
ApiKey { user_id: Uuid, api_key_id: Uuid },
|
||||
}
|
||||
|
||||
async fn enqueue_task(state: &AppState, task_id: Uuid) -> Result<(), AppError> {
|
||||
let mut conn = state.redis.clone();
|
||||
let now = Utc::now().to_rfc3339();
|
||||
redis::cmd("XADD")
|
||||
.arg("stream:compress_jobs")
|
||||
.arg("*")
|
||||
.arg("task_id")
|
||||
.arg(task_id.to_string())
|
||||
.arg("created_at")
|
||||
.arg(now)
|
||||
.query_async::<_, redis::Value>(&mut conn)
|
||||
.await
|
||||
.map_err(|err| AppError::new(ErrorCode::Internal, "写入队列失败").with_source(err))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn parse_batch_request(
|
||||
state: &AppState,
|
||||
task_id: Uuid,
|
||||
multipart: &mut Multipart,
|
||||
) -> Result<(Vec<BatchFileInput>, BatchOptions, String), AppError> {
|
||||
let mut files: Vec<BatchFileInput> = Vec::new();
|
||||
let mut file_digests: Vec<String> = Vec::new();
|
||||
let mut opts = BatchOptions {
|
||||
level: CompressionLevel::Medium,
|
||||
compression_rate: None,
|
||||
output_format: None,
|
||||
max_width: None,
|
||||
max_height: None,
|
||||
preserve_metadata: false,
|
||||
};
|
||||
|
||||
let base_dir = format!("{}/orig/{task_id}", state.config.storage_path);
|
||||
tokio::fs::create_dir_all(&base_dir)
|
||||
.await
|
||||
.map_err(|err| AppError::new(ErrorCode::StorageUnavailable, "创建存储目录失败").with_source(err))?;
|
||||
|
||||
loop {
|
||||
let next = multipart.next_field().await.map_err(|err| {
|
||||
AppError::new(ErrorCode::InvalidRequest, "读取上传内容失败").with_source(err)
|
||||
});
|
||||
|
||||
let field = match next {
|
||||
Ok(v) => v,
|
||||
Err(err) => {
|
||||
cleanup_file_paths(&files).await;
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
|
||||
let Some(field) = field else { break };
|
||||
|
||||
let name = field.name().unwrap_or("").to_string();
|
||||
if name == "files" || name == "files[]" {
|
||||
let file_id = Uuid::new_v4();
|
||||
let original_name = field.file_name().unwrap_or("upload").to_string();
|
||||
let bytes = match field.bytes().await {
|
||||
Ok(v) => v,
|
||||
Err(err) => {
|
||||
cleanup_file_paths(&files).await;
|
||||
return Err(
|
||||
AppError::new(ErrorCode::InvalidRequest, "读取文件失败").with_source(err)
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
let original_size = bytes.len() as u64;
|
||||
let file_digest = {
|
||||
let mut h = Sha256::new();
|
||||
h.update(&bytes);
|
||||
h.update(original_name.as_bytes());
|
||||
hex::encode(h.finalize())
|
||||
};
|
||||
let original_format = match compress::detect_format(&bytes) {
|
||||
Ok(v) => v,
|
||||
Err(err) => {
|
||||
cleanup_file_paths(&files).await;
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
|
||||
let output_format = opts.output_format.unwrap_or(original_format);
|
||||
let path = format!("{base_dir}/{file_id}.{}", original_format.extension());
|
||||
|
||||
let mut f = match tokio::fs::File::create(&path).await {
|
||||
Ok(v) => v,
|
||||
Err(err) => {
|
||||
cleanup_file_paths(&files).await;
|
||||
return Err(
|
||||
AppError::new(ErrorCode::StorageUnavailable, "写入文件失败").with_source(err),
|
||||
);
|
||||
}
|
||||
};
|
||||
if let Err(err) = f.write_all(&bytes).await {
|
||||
let _ = tokio::fs::remove_file(&path).await;
|
||||
cleanup_file_paths(&files).await;
|
||||
return Err(
|
||||
AppError::new(ErrorCode::StorageUnavailable, "写入文件失败").with_source(err),
|
||||
);
|
||||
}
|
||||
|
||||
files.push(BatchFileInput {
|
||||
file_id,
|
||||
original_name,
|
||||
original_format,
|
||||
output_format,
|
||||
original_size,
|
||||
storage_path: path,
|
||||
});
|
||||
file_digests.push(file_digest);
|
||||
continue;
|
||||
}
|
||||
|
||||
let text = match field.text().await {
|
||||
Ok(v) => v,
|
||||
Err(err) => {
|
||||
cleanup_file_paths(&files).await;
|
||||
return Err(
|
||||
AppError::new(ErrorCode::InvalidRequest, "读取字段失败").with_source(err),
|
||||
);
|
||||
}
|
||||
};
|
||||
match name.as_str() {
|
||||
"level" => {
|
||||
opts.level = match compress::parse_level(&text) {
|
||||
Ok(v) => v,
|
||||
Err(err) => {
|
||||
cleanup_file_paths(&files).await;
|
||||
return Err(err);
|
||||
}
|
||||
}
|
||||
}
|
||||
"output_format" => {
|
||||
let v = text.trim();
|
||||
if !v.is_empty() {
|
||||
opts.output_format = Some(match compress::parse_output_format(v) {
|
||||
Ok(v) => v,
|
||||
Err(err) => {
|
||||
cleanup_file_paths(&files).await;
|
||||
return Err(err);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
"compression_rate" | "quality" => {
|
||||
let v = text.trim();
|
||||
if !v.is_empty() {
|
||||
opts.compression_rate = Some(match compress::parse_compression_rate(v) {
|
||||
Ok(v) => v,
|
||||
Err(err) => {
|
||||
cleanup_file_paths(&files).await;
|
||||
return Err(err);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
"max_width" => {
|
||||
let v = text.trim();
|
||||
if !v.is_empty() {
|
||||
opts.max_width = Some(match v.parse::<u32>() {
|
||||
Ok(n) => n,
|
||||
Err(_) => {
|
||||
cleanup_file_paths(&files).await;
|
||||
return Err(AppError::new(ErrorCode::InvalidRequest, "max_width 格式错误"));
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
"max_height" => {
|
||||
let v = text.trim();
|
||||
if !v.is_empty() {
|
||||
opts.max_height = Some(match v.parse::<u32>() {
|
||||
Ok(n) => n,
|
||||
Err(_) => {
|
||||
cleanup_file_paths(&files).await;
|
||||
return Err(AppError::new(ErrorCode::InvalidRequest, "max_height 格式错误"));
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
"preserve_metadata" => {
|
||||
opts.preserve_metadata = matches!(
|
||||
text.trim().to_ascii_lowercase().as_str(),
|
||||
"1" | "true" | "yes" | "y" | "on"
|
||||
);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(rate) = opts.compression_rate {
|
||||
opts.level = compress::rate_to_level(rate);
|
||||
}
|
||||
|
||||
if opts.output_format.is_some() {
|
||||
cleanup_file_paths(&files).await;
|
||||
return Err(AppError::new(
|
||||
ErrorCode::InvalidRequest,
|
||||
"当前仅支持保持原图片格式",
|
||||
));
|
||||
}
|
||||
|
||||
let mw = opts.max_width.map(|v| v.to_string()).unwrap_or_default();
|
||||
let mh = opts.max_height.map(|v| v.to_string()).unwrap_or_default();
|
||||
let out_fmt = opts.output_format.map(|f| f.as_str()).unwrap_or("");
|
||||
let rate_key = opts
|
||||
.compression_rate
|
||||
.map(|v| v.to_string())
|
||||
.unwrap_or_default();
|
||||
let preserve = if opts.preserve_metadata { "1" } else { "0" };
|
||||
|
||||
let mut h = Sha256::new();
|
||||
h.update(b"compress_batch_v1");
|
||||
h.update(opts.level.as_str().as_bytes());
|
||||
h.update(out_fmt.as_bytes());
|
||||
h.update(rate_key.as_bytes());
|
||||
h.update(mw.as_bytes());
|
||||
h.update(mh.as_bytes());
|
||||
h.update(preserve.as_bytes());
|
||||
for d in &file_digests {
|
||||
h.update(d.as_bytes());
|
||||
}
|
||||
let request_hash = hex::encode(h.finalize());
|
||||
|
||||
Ok((files, opts, request_hash))
|
||||
}
|
||||
|
||||
fn enforce_batch_limits_anonymous(state: &AppState, files: &[BatchFileInput]) -> Result<(), AppError> {
|
||||
let max_files = state.config.anon_max_files_per_batch as usize;
|
||||
if files.len() > max_files {
|
||||
return Err(AppError::new(
|
||||
ErrorCode::InvalidRequest,
|
||||
format!("匿名试用单次最多 {} 个文件", max_files),
|
||||
));
|
||||
}
|
||||
|
||||
let max_bytes = state.config.anon_max_file_size_mb * 1024 * 1024;
|
||||
for f in files {
|
||||
if f.original_size > max_bytes {
|
||||
return Err(AppError::new(
|
||||
ErrorCode::FileTooLarge,
|
||||
format!("匿名试用单文件最大 {} MB", state.config.anon_max_file_size_mb),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn enforce_batch_limits_plan(plan: &Plan, files: &[BatchFileInput]) -> Result<(), AppError> {
|
||||
let max_files = plan.max_files_per_batch as usize;
|
||||
if files.len() > max_files {
|
||||
return Err(AppError::new(
|
||||
ErrorCode::InvalidRequest,
|
||||
format!("当前套餐单次最多 {} 个文件", plan.max_files_per_batch),
|
||||
));
|
||||
}
|
||||
|
||||
let max_bytes = (plan.max_file_size_mb as u64) * 1024 * 1024;
|
||||
for f in files {
|
||||
if f.original_size > max_bytes {
|
||||
return Err(AppError::new(
|
||||
ErrorCode::FileTooLarge,
|
||||
format!("当前套餐单文件最大 {} MB", plan.max_file_size_mb),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn ensure_quota_available(
|
||||
state: &AppState,
|
||||
ctx: &BillingContext,
|
||||
needed_units: i32,
|
||||
) -> Result<(), AppError> {
|
||||
if needed_units <= 0 {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
#[derive(Debug, FromRow)]
|
||||
struct UsageRow {
|
||||
used_units: i32,
|
||||
bonus_units: i32,
|
||||
}
|
||||
|
||||
let usage = sqlx::query_as::<_, UsageRow>(
|
||||
r#"
|
||||
SELECT used_units, bonus_units
|
||||
FROM usage_periods
|
||||
WHERE user_id = $1 AND period_start = $2 AND period_end = $3
|
||||
"#,
|
||||
)
|
||||
.bind(ctx.user_id)
|
||||
.bind(ctx.period_start)
|
||||
.bind(ctx.period_end)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|err| AppError::new(ErrorCode::Internal, "查询用量失败").with_source(err))?
|
||||
.unwrap_or(UsageRow {
|
||||
used_units: 0,
|
||||
bonus_units: 0,
|
||||
});
|
||||
|
||||
let total_units = ctx.plan.included_units_per_period + usage.bonus_units;
|
||||
let remaining = total_units - usage.used_units;
|
||||
if remaining < needed_units {
|
||||
return Err(AppError::new(ErrorCode::QuotaExceeded, "当期配额已用完"));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn anonymous_remaining_units(
|
||||
state: &AppState,
|
||||
session_id: &str,
|
||||
ip: IpAddr,
|
||||
) -> Result<i64, AppError> {
|
||||
let date = utc8_date();
|
||||
let session_key = format!("anon_quota:{session_id}:{date}");
|
||||
let ip_key = format!("anon_quota_ip:{ip}:{date}");
|
||||
|
||||
let mut conn = state.redis.clone();
|
||||
let v1: Option<i64> = redis::cmd("GET")
|
||||
.arg(session_key)
|
||||
.query_async(&mut conn)
|
||||
.await
|
||||
.unwrap_or(None);
|
||||
let v2: Option<i64> = redis::cmd("GET")
|
||||
.arg(ip_key)
|
||||
.query_async(&mut conn)
|
||||
.await
|
||||
.unwrap_or(None);
|
||||
|
||||
let limit = state.config.anon_daily_units as i64;
|
||||
Ok(std::cmp::min(limit - v1.unwrap_or(0), limit - v2.unwrap_or(0)))
|
||||
}
|
||||
|
||||
fn utc8_date() -> String {
|
||||
let now = Utc::now() + Duration::hours(8);
|
||||
now.format("%Y-%m-%d").to_string()
|
||||
}
|
||||
|
||||
#[derive(Debug, FromRow)]
|
||||
struct TaskRow {
|
||||
id: Uuid,
|
||||
status: String,
|
||||
total_files: i32,
|
||||
completed_files: i32,
|
||||
failed_files: i32,
|
||||
created_at: DateTime<Utc>,
|
||||
completed_at: Option<DateTime<Utc>>,
|
||||
expires_at: DateTime<Utc>,
|
||||
user_id: Option<Uuid>,
|
||||
session_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, FromRow)]
|
||||
struct TaskFileRow {
|
||||
id: Uuid,
|
||||
original_name: String,
|
||||
original_size: i64,
|
||||
compressed_size: Option<i64>,
|
||||
saved_percent: Option<f64>,
|
||||
status: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct TaskFileView {
|
||||
file_id: Uuid,
|
||||
original_name: String,
|
||||
original_size: i64,
|
||||
compressed_size: Option<i64>,
|
||||
saved_percent: Option<f64>,
|
||||
status: String,
|
||||
download_url: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct TaskView {
|
||||
task_id: Uuid,
|
||||
status: String,
|
||||
progress: i32,
|
||||
total_files: i32,
|
||||
completed_files: i32,
|
||||
failed_files: i32,
|
||||
files: Vec<TaskFileView>,
|
||||
download_all_url: String,
|
||||
created_at: DateTime<Utc>,
|
||||
completed_at: Option<DateTime<Utc>>,
|
||||
expires_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
async fn get_task(
|
||||
State(state): State<AppState>,
|
||||
jar: axum_extra::extract::cookie::CookieJar,
|
||||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||
headers: HeaderMap,
|
||||
Path(task_id): Path<Uuid>,
|
||||
) -> Result<(axum_extra::extract::cookie::CookieJar, Json<Envelope<TaskView>>), AppError> {
|
||||
let ip = context::client_ip(&headers, addr.ip());
|
||||
let (jar, principal) = context::authenticate(&state, jar, &headers, ip).await?;
|
||||
|
||||
let task = sqlx::query_as::<_, TaskRow>(
|
||||
r#"
|
||||
SELECT
|
||||
id,
|
||||
status::text AS status,
|
||||
total_files,
|
||||
completed_files,
|
||||
failed_files,
|
||||
created_at,
|
||||
completed_at,
|
||||
expires_at,
|
||||
user_id,
|
||||
session_id
|
||||
FROM tasks
|
||||
WHERE id = $1
|
||||
"#,
|
||||
)
|
||||
.bind(task_id)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|err| AppError::new(ErrorCode::Internal, "查询任务失败").with_source(err))?
|
||||
.ok_or_else(|| AppError::new(ErrorCode::NotFound, "任务不存在"))?;
|
||||
|
||||
if task.expires_at <= Utc::now() {
|
||||
return Err(AppError::new(ErrorCode::NotFound, "任务已过期或不存在"));
|
||||
}
|
||||
|
||||
authorize_task(&principal, task.user_id, task.session_id.as_deref().unwrap_or(""))?;
|
||||
|
||||
let files = sqlx::query_as::<_, TaskFileRow>(
|
||||
r#"
|
||||
SELECT
|
||||
id,
|
||||
original_name,
|
||||
original_size,
|
||||
compressed_size,
|
||||
saved_percent::float8 AS saved_percent,
|
||||
status::text AS status
|
||||
FROM task_files
|
||||
WHERE task_id = $1
|
||||
ORDER BY created_at ASC
|
||||
"#,
|
||||
)
|
||||
.bind(task_id)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
.map_err(|err| AppError::new(ErrorCode::Internal, "查询任务文件失败").with_source(err))?;
|
||||
|
||||
let file_views = files
|
||||
.into_iter()
|
||||
.map(|f| TaskFileView {
|
||||
file_id: f.id,
|
||||
original_name: f.original_name,
|
||||
original_size: f.original_size,
|
||||
compressed_size: f.compressed_size,
|
||||
saved_percent: f.saved_percent,
|
||||
status: f.status.clone(),
|
||||
download_url: if f.status == "completed" {
|
||||
Some(format!("/downloads/{}", f.id))
|
||||
} else {
|
||||
None
|
||||
},
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let processed = task.completed_files + task.failed_files;
|
||||
let progress = if task.total_files <= 0 {
|
||||
0
|
||||
} else {
|
||||
((processed as f64) * 100.0 / (task.total_files as f64))
|
||||
.round()
|
||||
.clamp(0.0, 100.0) as i32
|
||||
};
|
||||
|
||||
Ok((
|
||||
jar,
|
||||
Json(Envelope {
|
||||
success: true,
|
||||
data: TaskView {
|
||||
task_id,
|
||||
status: task.status,
|
||||
progress,
|
||||
total_files: task.total_files,
|
||||
completed_files: task.completed_files,
|
||||
failed_files: task.failed_files,
|
||||
files: file_views,
|
||||
download_all_url: format!("/downloads/tasks/{task_id}"),
|
||||
created_at: task.created_at,
|
||||
completed_at: task.completed_at,
|
||||
expires_at: task.expires_at,
|
||||
},
|
||||
}),
|
||||
))
|
||||
}
|
||||
|
||||
async fn cancel_task(
|
||||
State(state): State<AppState>,
|
||||
jar: axum_extra::extract::cookie::CookieJar,
|
||||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||
headers: HeaderMap,
|
||||
Path(task_id): Path<Uuid>,
|
||||
) -> Result<(axum_extra::extract::cookie::CookieJar, Json<Envelope<serde_json::Value>>), AppError> {
|
||||
let ip = context::client_ip(&headers, addr.ip());
|
||||
let (jar, principal) = context::authenticate(&state, jar, &headers, ip).await?;
|
||||
|
||||
let task = sqlx::query_as::<_, TaskRow>(
|
||||
"SELECT id, status::text AS status, total_files, completed_files, failed_files, created_at, completed_at, expires_at, user_id, session_id FROM tasks WHERE id = $1",
|
||||
)
|
||||
.bind(task_id)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|err| AppError::new(ErrorCode::Internal, "查询任务失败").with_source(err))?
|
||||
.ok_or_else(|| AppError::new(ErrorCode::NotFound, "任务不存在"))?;
|
||||
|
||||
authorize_task(&principal, task.user_id, task.session_id.as_deref().unwrap_or(""))?;
|
||||
|
||||
if matches!(task.status.as_str(), "completed" | "failed" | "cancelled") {
|
||||
return Ok((
|
||||
jar,
|
||||
Json(Envelope {
|
||||
success: true,
|
||||
data: serde_json::json!({ "message": "任务已结束" }),
|
||||
}),
|
||||
));
|
||||
}
|
||||
|
||||
let updated = sqlx::query(
|
||||
"UPDATE tasks SET status = 'cancelled', completed_at = NOW() WHERE id = $1 AND status IN ('pending', 'processing')",
|
||||
)
|
||||
.bind(task_id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|err| AppError::new(ErrorCode::Internal, "取消任务失败").with_source(err))?;
|
||||
|
||||
if updated.rows_affected() == 0 {
|
||||
return Err(AppError::new(ErrorCode::InvalidRequest, "任务状态不可取消"));
|
||||
}
|
||||
|
||||
Ok((
|
||||
jar,
|
||||
Json(Envelope {
|
||||
success: true,
|
||||
data: serde_json::json!({ "message": "已取消" }),
|
||||
}),
|
||||
))
|
||||
}
|
||||
|
||||
async fn delete_task(
|
||||
State(state): State<AppState>,
|
||||
jar: axum_extra::extract::cookie::CookieJar,
|
||||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||
headers: HeaderMap,
|
||||
Path(task_id): Path<Uuid>,
|
||||
) -> Result<(axum_extra::extract::cookie::CookieJar, Json<Envelope<serde_json::Value>>), AppError> {
|
||||
let ip = context::client_ip(&headers, addr.ip());
|
||||
let (jar, principal) = context::authenticate(&state, jar, &headers, ip).await?;
|
||||
|
||||
let task = sqlx::query_as::<_, TaskRow>(
|
||||
"SELECT id, status::text AS status, total_files, completed_files, failed_files, created_at, completed_at, expires_at, user_id, session_id FROM tasks WHERE id = $1",
|
||||
)
|
||||
.bind(task_id)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|err| AppError::new(ErrorCode::Internal, "查询任务失败").with_source(err))?
|
||||
.ok_or_else(|| AppError::new(ErrorCode::NotFound, "任务不存在"))?;
|
||||
|
||||
authorize_task(&principal, task.user_id, task.session_id.as_deref().unwrap_or(""))?;
|
||||
|
||||
if task.status == "processing" {
|
||||
return Err(AppError::new(
|
||||
ErrorCode::InvalidRequest,
|
||||
"任务处理中,请先取消后再删除",
|
||||
));
|
||||
}
|
||||
|
||||
let paths: Vec<Option<String>> =
|
||||
sqlx::query_scalar("SELECT storage_path FROM task_files WHERE task_id = $1")
|
||||
.bind(task_id)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
.map_err(|err| AppError::new(ErrorCode::Internal, "查询文件失败").with_source(err))?;
|
||||
|
||||
for p in paths.into_iter().flatten() {
|
||||
let _ = tokio::fs::remove_file(p).await;
|
||||
}
|
||||
|
||||
if state.config.storage_type.to_ascii_lowercase() == "local" {
|
||||
let zip_path = format!("{}/zips/{task_id}.zip", state.config.storage_path);
|
||||
let _ = tokio::fs::remove_file(zip_path).await;
|
||||
let orig_dir = format!("{}/orig/{task_id}", state.config.storage_path);
|
||||
let _ = tokio::fs::remove_dir_all(orig_dir).await;
|
||||
}
|
||||
|
||||
let deleted = sqlx::query("DELETE FROM tasks WHERE id = $1")
|
||||
.bind(task_id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|err| AppError::new(ErrorCode::Internal, "删除任务失败").with_source(err))?;
|
||||
|
||||
if deleted.rows_affected() == 0 {
|
||||
return Err(AppError::new(ErrorCode::NotFound, "任务不存在"));
|
||||
}
|
||||
|
||||
Ok((
|
||||
jar,
|
||||
Json(Envelope {
|
||||
success: true,
|
||||
data: serde_json::json!({ "message": "已删除" }),
|
||||
}),
|
||||
))
|
||||
}
|
||||
|
||||
fn authorize_task(
|
||||
principal: &context::Principal,
|
||||
user_id: Option<Uuid>,
|
||||
session_id: &str,
|
||||
) -> Result<(), AppError> {
|
||||
if let Some(owner) = user_id {
|
||||
match principal {
|
||||
context::Principal::User { user_id: me, .. } if *me == owner => Ok(()),
|
||||
context::Principal::ApiKey { user_id: me, .. } if *me == owner => Ok(()),
|
||||
_ => Err(AppError::new(ErrorCode::Forbidden, "无权限访问该任务")),
|
||||
}
|
||||
} else {
|
||||
match principal {
|
||||
context::Principal::Anonymous { session_id: sid } if sid == session_id => Ok(()),
|
||||
_ => Err(AppError::new(ErrorCode::Forbidden, "无权限访问该任务")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn cleanup_file_paths(files: &[BatchFileInput]) {
|
||||
for f in files {
|
||||
let _ = tokio::fs::remove_file(&f.storage_path).await;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user