"""File storage service.

Validates uploads (MIME whitelist, size limits, leading-byte signature),
stores them under ``<base_upload_dir>/YYYY/MM/DD`` with UUID file names,
generates JPEG thumbnails for images, builds share links, and merges
chunked uploads.
"""

import asyncio
import hashlib
import io
import logging
import mimetypes
import os
import secrets
import time
import uuid
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

from fastapi import UploadFile, HTTPException, status
from sqlalchemy.orm import Session
from PIL import Image

from app.models.file_management import UploadedFile
from app.schemas.file_management import (
    UploadedFileCreate,
    FileUploadResponse,
    FileShareResponse,
    FileStatistics,
)
from app.crud.file_management import uploaded_file as crud_uploaded_file

logger = logging.getLogger(__name__)


class FileService:
    """Store uploaded files on local disk and manage their database records."""

    # Whitelist of MIME types accepted from clients (as declared by the client).
    # NOTE(review): SVG can embed script; make sure it is served with a safe
    # Content-Type/CSP, or consider removing it from this list.
    ALLOWED_MIME_TYPES = {
        # Images
        'image/jpeg',
        'image/png',
        'image/gif',
        'image/bmp',
        'image/webp',
        'image/svg+xml',
        # Documents
        'application/pdf',
        'application/msword',
        'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
        'application/vnd.ms-excel',
        'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
        'application/vnd.ms-powerpoint',
        'application/vnd.openxmlformats-officedocument.presentationml.presentation',
        'text/plain',
        'text/csv',
        # Archives
        'application/zip',
        'application/x-rar-compressed',
        'application/x-7z-compressed',
        # Other
        'application/json',
        'application/xml',
        'text/xml',
    }

    # Size limit for non-image files, in bytes (100 MB).
    MAX_FILE_SIZE = 100 * 1024 * 1024
    # Size limit for image files (content type starting with "image/"), in bytes (10 MB).
    MAX_IMAGE_SIZE = 10 * 1024 * 1024

    # Leading byte signatures used to sniff the real content type.
    MAGIC_NUMBERS = {
        b'\xFF\xD8\xFF': 'image/jpeg',
        b'\x89\x50\x4E\x47\x0D\x0A\x1A\x0A': 'image/png',
        b'GIF87a': 'image/gif',
        b'GIF89a': 'image/gif',
        b'%PDF': 'application/pdf',
        b'PK\x03\x04': 'application/zip',
    }

    # Declared types whose files are ZIP containers: their content starts with
    # the ZIP signature and therefore sniffs as 'application/zip'.
    ZIP_CONTAINER_MIME_TYPES = {
        'application/zip',
        'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
        'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
        'application/vnd.openxmlformats-officedocument.presentationml.presentation',
    }

    # Bounding box for generated thumbnails (aspect ratio is preserved).
    THUMBNAIL_SIZE = (200, 200)

    def __init__(self, base_upload_dir: str = "uploads"):
        """Create the service rooted at ``base_upload_dir`` and create its directories."""
        self.base_upload_dir = Path(base_upload_dir)
        self.ensure_upload_dirs()

    def ensure_upload_dirs(self):
        """Create the base upload directory and its fixed sub-directories if missing."""
        directories = [
            self.base_upload_dir,
            self.base_upload_dir / "images",
            self.base_upload_dir / "documents",
            self.base_upload_dir / "thumbnails",
            self.base_upload_dir / "temp",
        ]
        for directory in directories:
            directory.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _is_image(content_type: Optional[str]) -> bool:
        """Return True when the declared content type is an ``image/*`` type."""
        return bool(content_type) and content_type.startswith('image/')

    def validate_file_type(self, file: UploadFile) -> bool:
        """Check the declared MIME type against the whitelist.

        Raises:
            HTTPException(400): if the type is not whitelisted (including None).
        """
        if file.content_type not in self.ALLOWED_MIME_TYPES:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"不支持的文件类型: {file.content_type}"
            )
        return True

    def validate_file_size(self, file: UploadFile) -> bool:
        """Check the upload against the image or general size limit.

        The size is measured by seeking to the end of the underlying file
        instead of reading it into memory; the position is reset to 0 afterwards.

        Raises:
            HTTPException(400): if the file is larger than the applicable limit.
        """
        max_size = self.MAX_IMAGE_SIZE if self._is_image(file.content_type) else self.MAX_FILE_SIZE

        stream = file.file
        stream.seek(0, os.SEEK_END)
        size = stream.tell()
        stream.seek(0)  # rewind so the content can be read afterwards

        if size > max_size:
            size_mb = max_size / (1024 * 1024)
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"文件大小超过限制: {size_mb:.0f}MB"
            )
        return True

    def validate_file_content(self, content: bytes) -> Optional[str]:
        """Sniff ``content`` by its leading bytes.

        Returns:
            The MIME type of the first matching signature, or None if none match.
        """
        for magic, mime_type in self.MAGIC_NUMBERS.items():
            if content.startswith(magic):
                return mime_type
        return None

    @classmethod
    def is_content_compatible(cls, detected_mime: Optional[str], declared_mime: Optional[str]) -> bool:
        """Return whether sniffed content agrees with the declared MIME type.

        Unrecognized content (``detected_mime is None``) is accepted. ZIP content
        is accepted for declared ZIP-container types (docx/xlsx/pptx), which would
        otherwise always be rejected.
        """
        if detected_mime is None or detected_mime == declared_mime:
            return True
        return detected_mime == 'application/zip' and declared_mime in cls.ZIP_CONTAINER_MIME_TYPES

    def allocate_storage_path(self, file_ext: str) -> Tuple[str, str, Path]:
        """Choose a new UUID-based file name under today's ``YYYY/MM/DD`` directory.

        The directory is created; the file itself is not.

        Returns:
            (unique_filename, date_dir, file_path)
        """
        unique_filename = f"{uuid.uuid4()}{file_ext}"
        date_dir = datetime.utcnow().strftime("%Y/%m/%d")
        save_dir = self.base_upload_dir / date_dir
        save_dir.mkdir(parents=True, exist_ok=True)
        return unique_filename, date_dir, save_dir / unique_filename

    @staticmethod
    def _remove_quietly(*paths: Any) -> None:
        """Best-effort deletion of the given paths (None/empty entries are skipped)."""
        for path in paths:
            if not path:
                continue
            try:
                Path(path).unlink(missing_ok=True)
            except OSError:
                logger.warning("清理文件失败: %s", path, exc_info=True)

    async def upload_file(
        self,
        db: Session,
        file: UploadFile,
        uploader_id: int,
        remark: Optional[str] = None
    ) -> UploadedFile:
        """Validate, store and register an uploaded file.

        Args:
            db: database session
            file: the uploaded file
            uploader_id: id of the uploading user
            remark: free-text remark.
                NOTE(review): not persisted by this method — confirm whether
                UploadedFileCreate should receive it.

        Returns:
            UploadedFile: the created file record

        Raises:
            HTTPException(400): unsupported type, oversized file, content that
                does not match the declared type, or a failed security scan.
        """
        self.validate_file_type(file)
        self.validate_file_size(file)

        content = await file.read()

        detected_mime = self.validate_file_content(content)
        if not self.is_content_compatible(detected_mime, file.content_type):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="文件内容与扩展名不匹配"
            )

        file_ext = self.get_file_extension(file.filename)
        unique_filename, date_dir, file_path = self.allocate_storage_path(file_ext)
        file_path.write_bytes(content)

        # Scan before anything references the file. The scan is synchronous
        # (and may block), so run it off the event loop.
        loop = asyncio.get_running_loop()
        is_safe = await loop.run_in_executor(None, self._scan_virus, file_path)
        if not is_safe:
            self._remove_quietly(file_path)
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="文件未通过安全扫描"
            )

        thumbnail_path = None
        if self._is_image(file.content_type):
            thumbnail_path = self.generate_thumbnail(content, unique_filename, date_dir)

        try:
            file_create = UploadedFileCreate(
                file_name=unique_filename,
                original_name=file.filename,
                file_path=str(file_path),
                file_size=len(content),
                file_type=file.content_type,
                file_ext=file_ext.lstrip('.'),
                uploader_id=uploader_id
            )
            db_obj = crud_uploaded_file.create(db, obj_in=file_create.dict())
        except Exception:
            # Don't leave orphaned files on disk when the record cannot be written.
            self._remove_quietly(file_path, thumbnail_path)
            raise

        if thumbnail_path:
            crud_uploaded_file.update(db, db_obj=db_obj, obj_in={"thumbnail_path": thumbnail_path})

        return db_obj

    def generate_thumbnail(
        self,
        content: bytes,
        filename: str,
        date_dir: str
    ) -> Optional[str]:
        """Write a JPEG thumbnail of ``content`` to ``thumbnails/<date_dir>/thumb_<filename>``.

        Returns:
            The thumbnail path as a string, or None if the image could not be
            processed (e.g. formats Pillow cannot open, such as SVG).
        """
        try:
            image = Image.open(io.BytesIO(content))

            # JPEG cannot store alpha/palette modes (RGBA, P, LA, ...); convert
            # every mode other than RGB/L so that save() does not fail.
            if image.mode not in ('RGB', 'L'):
                image = image.convert('RGB')

            image.thumbnail(self.THUMBNAIL_SIZE, Image.Resampling.LANCZOS)

            thumbnail_dir = self.base_upload_dir / "thumbnails" / date_dir
            thumbnail_dir.mkdir(parents=True, exist_ok=True)
            thumbnail_path = thumbnail_dir / f"thumb_{filename}"
            image.save(thumbnail_path, 'JPEG', quality=85)

            return str(thumbnail_path)
        except Exception:
            # Thumbnails are best-effort: log and continue without one.
            logger.exception("生成缩略图失败: %s", filename)
            return None

    def get_file_path(self, file_obj: UploadedFile) -> Path:
        """Return the stored file path of a record as a Path."""
        return Path(file_obj.file_path)

    def file_exists(self, file_obj: UploadedFile) -> bool:
        """Return True if the record's file exists on disk as a regular file."""
        file_path = self.get_file_path(file_obj)
        return file_path.exists() and file_path.is_file()

    def delete_file_from_disk(self, file_obj: UploadedFile) -> bool:
        """Delete a record's file and its thumbnail (if any) from disk.

        Missing files are not an error.

        Returns:
            True on success, False if a deletion raised (the error is logged).
        """
        try:
            self.get_file_path(file_obj).unlink(missing_ok=True)
            if file_obj.thumbnail_path:
                Path(file_obj.thumbnail_path).unlink(missing_ok=True)
            return True
        except Exception:
            logger.exception("删除文件失败: %s", file_obj.file_path)
            return False

    def generate_share_link(
        self,
        db: Session,
        file_id: int,
        expire_days: int = 7,
        base_url: str = "http://localhost:8000"
    ) -> FileShareResponse:
        """Create a share code for a file and build its share URL.

        Args:
            db: database session
            file_id: id of the file to share
            expire_days: validity period in days
            base_url: base URL the share path is appended to (a trailing slash is ignored)

        Returns:
            FileShareResponse: share code, URL and expiry time

        Raises:
            HTTPException(404): if the file does not exist.
        """
        share_code = crud_uploaded_file.generate_share_code(
            db, file_id=file_id, expire_days=expire_days
        )
        if not share_code:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="文件不存在"
            )

        file_obj = crud_uploaded_file.get(db, file_id)
        if file_obj is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="文件不存在"
            )

        share_url = f"{base_url.rstrip('/')}/api/v1/files/share/{share_code}"

        return FileShareResponse(
            share_code=share_code,
            share_url=share_url,
            expire_time=file_obj.share_expire_time
        )

    def get_shared_file(self, db: Session, share_code: str) -> Optional[UploadedFile]:
        """Look up a file record by its share code."""
        return crud_uploaded_file.get_by_share_code(db, share_code)

    def get_statistics(
        self,
        db: Session,
        uploader_id: Optional[int] = None
    ) -> FileStatistics:
        """Return file statistics, optionally restricted to one uploader."""
        stats = crud_uploaded_file.get_statistics(db, uploader_id=uploader_id)
        return FileStatistics(**stats)

    @staticmethod
    def get_file_extension(filename: Optional[str]) -> str:
        """Return the extension of ``filename`` including the dot, or "" if there is none."""
        if not filename:
            return ""
        return os.path.splitext(filename)[1]

    @staticmethod
    def get_mime_type(filename: str) -> str:
        """Guess a MIME type from the file name, defaulting to application/octet-stream."""
        mime_type, _ = mimetypes.guess_type(filename)
        return mime_type or 'application/octet-stream'

    @staticmethod
    def _scan_virus(file_path: Path) -> bool:
        """Simulated virus scan; always reports the file as safe.

        A production deployment should integrate a real scanner, e.g.:
        - ClamAV
        - VirusTotal API
        - Windows Defender

        This call blocks (it sleeps to simulate scan time), so async callers
        should run it in an executor.
        """
        time.sleep(0.1)  # simulated scan time
        return True


class ChunkUploadManager:
    """Track chunked-upload sessions and merge their chunks into stored files.

    NOTE(review): sessions are kept in a per-instance dict, so they are not
    shared across worker processes. ``created_at`` is recorded but nothing here
    expires stale sessions.
    """

    # Read size used when concatenating chunk files.
    COPY_BUFFER_SIZE = 1024 * 1024
    # Number of leading bytes kept for content sniffing (the longest signature is 8 bytes).
    SNIFF_BYTES = 16

    def __init__(self, base_upload_dir: str = "uploads"):
        """Create an empty session table; chunks go to ``<base_upload_dir>/temp``."""
        self.uploads: Dict[str, Dict[str, Any]] = {}
        self.temp_dir = Path(base_upload_dir) / "temp"

    def init_upload(
        self,
        file_name: str,
        file_size: int,
        file_type: str,
        total_chunks: int,
        file_hash: Optional[str] = None
    ) -> str:
        """Register a new upload session and return its id (a UUID4 string)."""
        upload_id = str(uuid.uuid4())
        self.uploads[upload_id] = {
            "file_name": file_name,
            "file_size": file_size,
            "file_type": file_type,
            "total_chunks": total_chunks,
            "file_hash": file_hash,
            "uploaded_chunks": [],
            "created_at": datetime.utcnow()
        }
        return upload_id

    def _chunk_path(self, upload_id: str, chunk_index: int) -> Path:
        """Return the temp-file path for one chunk of a session."""
        return self.temp_dir / f"{upload_id}_chunk_{chunk_index}"

    def save_chunk(
        self,
        upload_id: str,
        chunk_index: int,
        chunk_data: bytes
    ) -> bool:
        """Persist one chunk to the temp directory.

        Returns:
            False if the session is unknown, True once the chunk is written.

        Raises:
            HTTPException(400): if ``chunk_index`` is not in ``[0, total_chunks)``.
        """
        upload_info = self.uploads.get(upload_id)
        if upload_info is None:
            return False

        # Out-of-range indexes would otherwise be counted by is_complete().
        total_chunks = upload_info["total_chunks"]
        if not isinstance(chunk_index, int) or not 0 <= chunk_index < total_chunks:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"分片索引超出范围: {chunk_index}"
            )

        self.temp_dir.mkdir(parents=True, exist_ok=True)
        self._chunk_path(upload_id, chunk_index).write_bytes(chunk_data)

        if chunk_index not in upload_info["uploaded_chunks"]:
            upload_info["uploaded_chunks"].append(chunk_index)

        return True

    def is_complete(self, upload_id: str) -> bool:
        """Return True when every chunk of a known session has been saved."""
        upload_info = self.uploads.get(upload_id)
        if upload_info is None:
            return False
        return len(upload_info["uploaded_chunks"]) == upload_info["total_chunks"]

    def _concatenate(
        self,
        chunk_paths: List[Path],
        dest: Path,
        compute_hash: bool
    ) -> Tuple[int, Optional[str], bytes]:
        """Stream ``chunk_paths`` into ``dest`` in order.

        Returns:
            (total bytes written, MD5 hex digest or None, first SNIFF_BYTES bytes)
        """
        hasher = hashlib.md5() if compute_hash else None
        total_size = 0
        head = b""
        with open(dest, "wb") as out:
            for chunk_path in chunk_paths:
                with open(chunk_path, "rb") as src:
                    while True:
                        block = src.read(self.COPY_BUFFER_SIZE)
                        if not block:
                            break
                        out.write(block)
                        total_size += len(block)
                        if hasher is not None:
                            hasher.update(block)
                        if len(head) < self.SNIFF_BYTES:
                            head += block[:self.SNIFF_BYTES - len(head)]
        digest = hasher.hexdigest() if hasher is not None else None
        return total_size, digest, head

    def merge_chunks(
        self,
        db: Session,
        upload_id: str,
        uploader_id: int,
        file_service: FileService
    ) -> UploadedFile:
        """Merge a complete session into one stored file and create its record.

        The merged file is written under ``file_service.base_upload_dir``. The
        session and its chunks are removed only after the record is created, so
        a failed merge can be retried.

        Raises:
            HTTPException(404): unknown session.
            HTTPException(400): missing chunks, non-whitelisted type, size/hash
                mismatch, or content that does not match the declared type.
        """
        upload_info = self.uploads.get(upload_id)
        if upload_info is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="上传会话不存在"
            )

        if not self.is_complete(upload_id):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="分片未全部上传"
            )

        # Chunked uploads must obey the same type whitelist as direct uploads.
        # NOTE(review): FileService size limits are not applied here — confirm intent.
        file_type = upload_info["file_type"]
        if file_type not in file_service.ALLOWED_MIME_TYPES:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"不支持的文件类型: {file_type}"
            )

        chunk_paths = [self._chunk_path(upload_id, i) for i in range(upload_info["total_chunks"])]
        for index, chunk_path in enumerate(chunk_paths):
            if not chunk_path.exists():
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=f"分片 {index} 不存在"
                )

        file_ext = Path(upload_info["file_name"]).suffix
        unique_filename, _date_dir, file_path = file_service.allocate_storage_path(file_ext)
        expected_hash = upload_info["file_hash"]

        try:
            merged_size, digest, head = self._concatenate(
                chunk_paths, file_path, compute_hash=bool(expected_hash)
            )

            if merged_size != upload_info["file_size"]:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="文件大小不匹配"
                )

            # Hex digests are case-insensitive; normalize the client-supplied value.
            if expected_hash and digest != expected_hash.lower():
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="文件哈希不匹配"
                )

            detected_mime = file_service.validate_file_content(head)
            if not file_service.is_content_compatible(detected_mime, file_type):
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="文件内容与扩展名不匹配"
                )

            file_create = UploadedFileCreate(
                file_name=unique_filename,
                original_name=upload_info["file_name"],
                file_path=str(file_path),
                file_size=upload_info["file_size"],
                file_type=file_type,
                file_ext=file_ext.lstrip('.'),
                uploader_id=uploader_id
            )
            db_obj = crud_uploaded_file.create(db, obj_in=file_create.dict())
        except BaseException:
            # Discard the partial/rejected merged file; chunks stay for a retry.
            file_path.unlink(missing_ok=True)
            raise

        self.cleanup_upload(upload_id)
        return db_obj

    def cleanup_upload(self, upload_id: str):
        """Forget a session and delete its chunk files from the temp directory.

        Chunk files are matched by a literal name prefix (not a glob pattern),
        so ids containing ``*``/``?``/``[`` cannot match other sessions' chunks.
        """
        self.uploads.pop(upload_id, None)

        if not self.temp_dir.is_dir():
            return
        prefix = f"{upload_id}_chunk_"
        for chunk_file in self.temp_dir.iterdir():
            if chunk_file.name.startswith(prefix):
                chunk_file.unlink(missing_ok=True)


# Module-level service instances (constructing FileService creates the upload directories).
file_service = FileService()
chunk_upload_manager = ChunkUploadManager()