"""File management CRUD operations."""
import secrets
import string
from datetime import datetime, timedelta
from typing import List, Optional, Dict, Any, Tuple

from sqlalchemy import and_, or_, func, desc
from sqlalchemy.orm import Session

from app.models.file_management import UploadedFile


class CRUDUploadedFile:
    """CRUD operations for ``UploadedFile`` records.

    Every read method excludes soft-deleted rows (``is_deleted == 0``), and
    deletion is a soft delete that stamps ``deleted_at`` / ``deleted_by``.
    Write methods commit the session they are given.
    """

    # Characters a generated share code is drawn from (A-Z, a-z, 0-9).
    _SHARE_CODE_ALPHABET = string.ascii_uppercase + string.ascii_lowercase + string.digits
    # Number of characters in a generated share code.
    _SHARE_CODE_LENGTH = 16
    # Escape character used when embedding user-supplied text in LIKE patterns.
    _LIKE_ESCAPE = "/"

    def create(self, db: Session, *, obj_in: Dict[str, Any]) -> UploadedFile:
        """Create, commit and return a new file record built from ``obj_in``."""
        db_obj = UploadedFile(**obj_in)
        db.add(db_obj)
        db.commit()
        db.refresh(db_obj)
        return db_obj

    def get(self, db: Session, id: int) -> Optional[UploadedFile]:
        """Return the non-deleted file whose primary key is ``id``, or ``None``."""
        return db.query(UploadedFile).filter(
            and_(
                UploadedFile.id == id,
                UploadedFile.is_deleted == 0,
            )
        ).first()

    def get_by_share_code(self, db: Session, share_code: str) -> Optional[UploadedFile]:
        """Return the non-deleted file shared under ``share_code``, or ``None``.

        A share is valid when it has no expiry time or its expiry time is later
        than the current naive-UTC time (``datetime.utcnow()``).
        """
        # An empty/None code must never match anything: ``share_code == None``
        # compiles to ``IS NULL`` and would return a file that was never shared.
        if not share_code:
            return None
        now = datetime.utcnow()
        return db.query(UploadedFile).filter(
            and_(
                UploadedFile.share_code == share_code,
                UploadedFile.is_deleted == 0,
                or_(
                    UploadedFile.share_expire_time.is_(None),
                    UploadedFile.share_expire_time > now,
                ),
            )
        ).first()

    @classmethod
    def _escape_like(cls, text: str) -> str:
        """Escape LIKE metacharacters in ``text`` so they match literally."""
        esc = cls._LIKE_ESCAPE
        # The escape character itself must be doubled first.
        return (
            text.replace(esc, esc + esc)
            .replace("%", esc + "%")
            .replace("_", esc + "_")
        )

    def get_multi(
        self,
        db: Session,
        *,
        skip: int = 0,
        limit: int = 20,
        keyword: Optional[str] = None,
        file_type: Optional[str] = None,
        uploader_id: Optional[int] = None,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
    ) -> Tuple[List[UploadedFile], int]:
        """Return one page of non-deleted files plus the total match count.

        Args:
            db: Active database session.
            skip: Number of matching rows to skip (offset).
            limit: Maximum number of rows to return.
            keyword: Substring matched against ``original_name`` or
                ``file_name``; ``%`` and ``_`` are matched literally.
            file_type: Exact ``file_type`` to filter by.
            uploader_id: Only return files uploaded by this user.
            start_date: Inclusive lower bound on ``upload_time``, ``YYYY-MM-DD``.
            end_date: Inclusive upper bound (whole day), ``YYYY-MM-DD``.

        Returns:
            ``(items, total)`` where ``items`` is newest-first and ``total``
            counts all matches before pagination.

        Raises:
            ValueError: If ``start_date`` or ``end_date`` is not ``YYYY-MM-DD``.
        """
        query = db.query(UploadedFile).filter(UploadedFile.is_deleted == 0)

        # Keyword search; escape so user input cannot act as LIKE wildcards.
        if keyword:
            pattern = f"%{self._escape_like(keyword)}%"
            query = query.filter(
                or_(
                    UploadedFile.original_name.like(pattern, escape=self._LIKE_ESCAPE),
                    UploadedFile.file_name.like(pattern, escape=self._LIKE_ESCAPE),
                )
            )

        if file_type:
            query = query.filter(UploadedFile.file_type == file_type)

        # ``is not None`` so that a legitimate id of 0 still filters.
        if uploader_id is not None:
            query = query.filter(UploadedFile.uploader_id == uploader_id)

        if start_date:
            start = datetime.strptime(start_date, "%Y-%m-%d")
            query = query.filter(UploadedFile.upload_time >= start)
        if end_date:
            # Exclusive bound at midnight of the following day covers all of end_date.
            end = datetime.strptime(end_date, "%Y-%m-%d") + timedelta(days=1)
            query = query.filter(UploadedFile.upload_time < end)

        total = query.count()

        # ``id`` breaks ties between equal upload times so pages are stable.
        items = (
            query.order_by(desc(UploadedFile.upload_time), desc(UploadedFile.id))
            .offset(skip)
            .limit(limit)
            .all()
        )
        return items, total

    def update(self, db: Session, *, db_obj: UploadedFile, obj_in: Dict[str, Any]) -> UploadedFile:
        """Set every key of ``obj_in`` as an attribute on ``db_obj``, commit, and return it."""
        for field, value in obj_in.items():
            setattr(db_obj, field, value)
        db.add(db_obj)
        db.commit()
        db.refresh(db_obj)
        return db_obj

    def delete(self, db: Session, *, db_obj: UploadedFile, deleter_id: int) -> UploadedFile:
        """Soft-delete ``db_obj``, recording the deletion time and ``deleter_id``."""
        db_obj.is_deleted = 1
        db_obj.deleted_at = datetime.utcnow()
        db_obj.deleted_by = deleter_id
        db.add(db_obj)
        db.commit()
        db.refresh(db_obj)
        return db_obj

    def delete_batch(self, db: Session, *, file_ids: List[int], deleter_id: int) -> int:
        """Soft-delete every non-deleted file in ``file_ids``.

        Returns:
            Number of rows the UPDATE reported as affected (0 for an empty list).
        """
        # Nothing to do; avoids an empty ``IN ()`` query and a pointless commit.
        if not file_ids:
            return 0
        now = datetime.utcnow()
        count = db.query(UploadedFile).filter(
            and_(
                UploadedFile.id.in_(file_ids),
                UploadedFile.is_deleted == 0,
            )
        ).update(
            {
                "is_deleted": 1,
                "deleted_at": now,
                "deleted_by": deleter_id,
            },
            synchronize_session=False,
        )
        db.commit()
        return count

    def increment_download_count(self, db: Session, *, file_id: int) -> int:
        """Increment a file's download counter.

        Returns:
            The new download count, or 0 if no non-deleted file has ``file_id``.
        """
        file_obj = self.get(db, file_id)
        if not file_obj:
            return 0
        # Assign a SQL expression so the database performs the increment
        # (UPDATE ... SET download_count = coalesce(download_count, 0) + 1);
        # a Python-side read-modify-write loses counts under concurrent downloads.
        file_obj.download_count = func.coalesce(UploadedFile.download_count, 0) + 1
        db.add(file_obj)
        db.commit()
        # Reload the value the database actually stored.
        db.refresh(file_obj)
        return file_obj.download_count

    def generate_share_code(self, db: Session, *, file_id: int, expire_days: int = 7) -> Optional[str]:
        """Create a new random share code for a file and persist it.

        Args:
            db: Active database session.
            file_id: Id of the file to share.
            expire_days: Days from now (naive UTC) until the share expires.

        Returns:
            The new share code, or ``None`` if no non-deleted file has ``file_id``.
        """
        file_obj = self.get(db, file_id)
        if not file_obj:
            return None

        # ``secrets`` provides a cryptographically secure random source.
        share_code = "".join(
            secrets.choice(self._SHARE_CODE_ALPHABET)
            for _ in range(self._SHARE_CODE_LENGTH)
        )
        expire_time = datetime.utcnow() + timedelta(days=expire_days)

        self.update(db, db_obj=file_obj, obj_in={
            "share_code": share_code,
            "share_expire_time": expire_time,
        })
        return share_code

    def get_statistics(
        self,
        db: Session,
        *,
        uploader_id: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Return aggregate statistics over non-deleted files.

        Args:
            db: Active database session.
            uploader_id: When given, restrict all statistics to this uploader.

        Returns:
            Dict with total count/size (raw bytes and human-readable), a
            ``file_type -> count`` mapping, upload counts for today / this
            week (from Monday) / this month in naive UTC, and the top 10
            uploaders by file count.
        """
        query = db.query(UploadedFile).filter(UploadedFile.is_deleted == 0)
        if uploader_id is not None:
            query = query.filter(UploadedFile.uploader_id == uploader_id)

        # Unpack positionally: on SQLAlchemy 1.4+ the result Row is a Sequence,
        # so ``row.count`` is the ``count()`` method, not the labelled column.
        total_count, total_size_raw = query.with_entities(
            func.count(UploadedFile.id).label("count"),
            func.sum(UploadedFile.file_size).label("size"),
        ).first()

        # Distribution of file types.
        type_dist = query.with_entities(
            UploadedFile.file_type,
            func.count(UploadedFile.id).label("count"),
        ).group_by(UploadedFile.file_type).all()
        type_distribution = {file_type: count for file_type, count in type_dist}

        today_start = datetime.utcnow().replace(hour=0, minute=0, second=0, microsecond=0)
        upload_today = query.filter(UploadedFile.upload_time >= today_start).count()

        # weekday() is 0 for Monday, so this is the start of the current week.
        week_start = today_start - timedelta(days=today_start.weekday())
        upload_this_week = query.filter(UploadedFile.upload_time >= week_start).count()

        month_start = today_start.replace(day=1)
        upload_this_month = query.filter(UploadedFile.upload_time >= month_start).count()

        # Top 10 uploaders by number of files.
        upload_count = func.count(UploadedFile.id).label("count")
        uploader_ranking = query.with_entities(
            UploadedFile.uploader_id,
            upload_count,
        ).group_by(UploadedFile.uploader_id).order_by(desc(upload_count)).limit(10).all()

        # SUM() is NULL when nothing matches and some drivers return it as
        # Decimal; normalize to int so the value serializes cleanly and can be
        # divided by _format_size.
        total_size = int(total_size_raw or 0)

        return {
            "total_files": total_count or 0,
            "total_size": total_size,
            "total_size_human": self._format_size(total_size),
            "type_distribution": type_distribution,
            "upload_today": upload_today,
            "upload_this_week": upload_this_week,
            "upload_this_month": upload_this_month,
            "top_uploaders": [
                {"uploader_id": uid, "count": count} for uid, count in uploader_ranking
            ],
        }

    @staticmethod
    def _format_size(size_bytes: float) -> str:
        """Format a byte count with two decimals and a binary unit (e.g. 1536 -> "1.50 KB")."""
        # Coerce to float so int/Decimal inputs can all be divided by 1024.0.
        size = float(size_bytes)
        for unit in ("B", "KB", "MB", "GB", "TB"):
            if size < 1024.0:
                return f"{size:.2f} {unit}"
            size /= 1024.0
        return f"{size:.2f} PB"


# Shared module-level CRUD instance.
uploaded_file = CRUDUploadedFile()