260 lines
6.9 KiB
Python
260 lines
6.9 KiB
Python
"""
|
||
Tests for the file management module.
|
||
"""
|
||
import pytest
|
||
import os
|
||
from io import BytesIO
|
||
from fastapi.testclient import TestClient
|
||
from PIL import Image
|
||
|
||
|
||
def test_upload_file(client: TestClient, auth_headers: dict):
    """Upload a small JPEG and check the metadata returned by the API."""
    # Render a 100x100 red image into an in-memory JPEG.
    buffer = BytesIO()
    Image.new('RGB', (100, 100), color='red').save(buffer, 'JPEG')
    buffer.seek(0)

    upload = {"file": ("test.jpg", buffer, "image/jpeg")}
    form_fields = {"remark": "测试文件"}
    response = client.post(
        "/api/v1/files/upload",
        headers=auth_headers,
        files=upload,
        data=form_fields,
    )

    assert response.status_code == 200
    body = response.json()
    assert body["original_name"] == "test.jpg"
    assert body["file_type"] == "image/jpeg"
    assert body["message"] == "上传成功"
    # The response must also expose an identifier and a download link.
    for key in ("id", "download_url"):
        assert key in body
|
||
|
||
|
||
def test_upload_large_file(client: TestClient, auth_headers: dict):
    """Uploading an 11 MiB payload is expected to be rejected with HTTP 400."""
    # 11 MiB of filler bytes, intended to exceed the upload size limit.
    payload_size = 11 * 1024 * 1024
    oversized = BytesIO(b"x" * payload_size)

    upload = {"file": ("large.jpg", oversized, "image/jpeg")}
    response = client.post(
        "/api/v1/files/upload",
        headers=auth_headers,
        files=upload,
    )

    assert response.status_code == 400
|
||
|
||
|
||
def test_upload_invalid_type(client: TestClient, auth_headers: dict):
    """Uploading a file declared as a Windows executable must yield HTTP 400."""
    executable = BytesIO(b"test")
    upload = {"file": ("test.exe", executable, "application/x-msdownload")}

    response = client.post(
        "/api/v1/files/upload",
        headers=auth_headers,
        files=upload,
    )

    assert response.status_code == 400
|
||
|
||
|
||
def test_get_file_list(client: TestClient, auth_headers: dict):
    """Request the first page of files and check that a JSON list comes back."""
    paging = {"page": 1, "page_size": 20}
    response = client.get("/api/v1/files/", headers=auth_headers, params=paging)

    assert response.status_code == 200
    assert isinstance(response.json(), list)
|
||
|
||
|
||
def test_get_file_detail(client: TestClient, auth_headers: dict):
    """Upload a PNG, then fetch its detail record by id.

    The detail response must echo the uploaded file's id and contain a
    ``download_url`` field.
    """
    # Arrange: upload a 100x100 blue PNG to obtain a file id.
    img = Image.new('RGB', (100, 100), color='blue')
    img_io = BytesIO()
    img.save(img_io, 'PNG')
    img_io.seek(0)

    upload_response = client.post(
        "/api/v1/files/upload",
        headers=auth_headers,
        files={"file": ("test.png", img_io, "image/png")},
    )
    # Fail here with a clear message rather than with a KeyError on "id"
    # when the setup upload itself did not succeed.
    assert upload_response.status_code == 200, upload_response.text
    file_id = upload_response.json()["id"]

    # Act: fetch the detail record of the uploaded file.
    response = client.get(
        f"/api/v1/files/{file_id}",
        headers=auth_headers,
    )

    assert response.status_code == 200
    data = response.json()
    assert data["id"] == file_id
    assert "download_url" in data
|
||
|
||
|
||
def test_get_file_statistics(client: TestClient, auth_headers: dict):
    """Fetch file statistics and check that the expected keys are present."""
    response = client.get("/api/v1/files/statistics", headers=auth_headers)

    assert response.status_code == 200
    stats = response.json()
    for field in ("total_files", "total_size", "type_distribution"):
        assert field in stats
|
||
|
||
|
||
def test_create_share_link(client: TestClient, auth_headers: dict):
    """Upload a JPEG, then create a 7-day share link for it.

    The share response must contain ``share_code``, ``share_url`` and
    ``expire_time``.
    """
    # Arrange: upload a 100x100 green JPEG to obtain a file id.
    img = Image.new('RGB', (100, 100), color='green')
    img_io = BytesIO()
    img.save(img_io, 'JPEG')
    img_io.seek(0)

    upload_response = client.post(
        "/api/v1/files/upload",
        headers=auth_headers,
        files={"file": ("share.jpg", img_io, "image/jpeg")},
    )
    # Fail here with a clear message rather than with a KeyError on "id"
    # when the setup upload itself did not succeed.
    assert upload_response.status_code == 200, upload_response.text
    file_id = upload_response.json()["id"]

    # Act: request a share link for the uploaded file.
    response = client.post(
        f"/api/v1/files/{file_id}/share",
        headers=auth_headers,
        json={"expire_days": 7},
    )

    assert response.status_code == 200
    data = response.json()
    assert "share_code" in data
    assert "share_url" in data
    assert "expire_time" in data
|
||
|
||
|
||
def test_delete_file(client: TestClient, auth_headers: dict):
    """Upload a JPEG, delete it by id, and expect HTTP 204."""
    # Arrange: upload a 100x100 yellow JPEG to obtain a file id.
    img = Image.new('RGB', (100, 100), color='yellow')
    img_io = BytesIO()
    img.save(img_io, 'JPEG')
    img_io.seek(0)

    upload_response = client.post(
        "/api/v1/files/upload",
        headers=auth_headers,
        files={"file": ("delete.jpg", img_io, "image/jpeg")},
    )
    # Fail here with a clear message rather than with a KeyError on "id"
    # when the setup upload itself did not succeed.
    assert upload_response.status_code == 200, upload_response.text
    file_id = upload_response.json()["id"]

    # Act: delete the uploaded file.
    response = client.delete(
        f"/api/v1/files/{file_id}",
        headers=auth_headers,
    )

    assert response.status_code == 204
|
||
|
||
|
||
def test_batch_delete_files(client: TestClient, auth_headers: dict):
    """Upload three JPEGs, delete them in one batch request, expect HTTP 204."""
    # Encode the test image once; every upload reuses the same JPEG bytes
    # through a fresh BytesIO so each request reads from position 0.
    encoded = BytesIO()
    Image.new('RGB', (100, 100), color='red').save(encoded, 'JPEG')
    jpeg_bytes = encoded.getvalue()

    # Arrange: upload the files and collect their ids.
    file_ids = []
    for i in range(3):
        upload_response = client.post(
            "/api/v1/files/upload",
            headers=auth_headers,
            files={"file": (f"batch_{i}.jpg", BytesIO(jpeg_bytes), "image/jpeg")},
        )
        # Fail here with a clear message rather than with a KeyError on
        # "id" when a setup upload did not succeed.
        assert upload_response.status_code == 200, upload_response.text
        file_ids.append(upload_response.json()["id"])

    # Act: batch delete. The generic request() method is used because the
    # httpx-backed TestClient's delete() helper does not accept a `json`
    # body argument; request() accepts it on both old and new clients.
    response = client.request(
        "DELETE",
        "/api/v1/files/batch",
        headers=auth_headers,
        json={"file_ids": file_ids},
    )

    assert response.status_code == 204
|
||
|
||
|
||
def test_chunk_upload_init(client: TestClient, auth_headers: dict):
    """Initialise a chunked upload and check that an ``upload_id`` is returned."""
    init_request = {
        "file_name": "large_file.zip",
        "file_size": 10485760,  # 10 MiB
        "file_type": "application/zip",
        "total_chunks": 2,
    }

    response = client.post(
        "/api/v1/files/chunks/init",
        headers=auth_headers,
        json=init_request,
    )

    assert response.status_code == 200
    assert "upload_id" in response.json()
|
||
|
||
|
||
def test_access_shared_file(client: TestClient, auth_headers: dict):
    """Upload a JPEG, share it, then fetch it via the share code.

    The final request is sent without the auth headers and must return
    HTTP 200.
    """
    # Arrange: upload a 100x100 purple JPEG to obtain a file id.
    img = Image.new('RGB', (100, 100), color='purple')
    img_io = BytesIO()
    img.save(img_io, 'JPEG')
    img_io.seek(0)

    upload_response = client.post(
        "/api/v1/files/upload",
        headers=auth_headers,
        files={"file": ("shared.jpg", img_io, "image/jpeg")},
    )
    # Fail here with a clear message rather than with a KeyError on "id"
    # when the setup upload itself did not succeed.
    assert upload_response.status_code == 200, upload_response.text
    file_id = upload_response.json()["id"]

    # Arrange: create a 7-day share link for the uploaded file.
    share_response = client.post(
        f"/api/v1/files/{file_id}/share",
        headers=auth_headers,
        json={"expire_days": 7},
    )
    # Same reasoning: surface a failed share creation directly instead of
    # a KeyError on "share_code".
    assert share_response.status_code == 200, share_response.text
    share_code = share_response.json()["share_code"]

    # Act: access the shared file without sending authentication headers.
    response = client.get(f"/api/v1/files/share/{share_code}")

    assert response.status_code == 200
|
||
|
||
|
||
# Fixtures needed to run the tests
|
||
@pytest.fixture
def auth_headers(client: TestClient):
    """Return an ``Authorization`` header dict for authenticated requests.

    Logs in through ``/api/v1/auth/login`` with the admin credentials and
    builds a bearer header from the returned ``access_token``. When the
    login does not answer with HTTP 200, a fixed placeholder token is used
    instead.
    """
    credentials = {
        "username": "admin",
        "password": "admin123",
    }
    login = client.post("/api/v1/auth/login", json=credentials)

    if login.status_code != 200:
        # Login failed: fall back to the fixed test token.
        return {"Authorization": "Bearer test_token"}

    access_token = login.json()["access_token"]
    return {"Authorization": f"Bearer {access_token}"}
|