""" 依赖注入模块 """ from typing import Generator, Optional from fastapi import Depends, HTTPException, status from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session from sqlalchemy import select from app.db.session import async_session_maker, sync_session_maker from app.core.security import security_manager from app.models.user import User, Role, Permission, UserRole, RolePermission # HTTP Bearer认证 security = HTTPBearer() async def get_db() -> Generator: """ 获取数据库会话 Yields: AsyncSession: 数据库会话 """ async with async_session_maker() as session: try: yield session await session.commit() except Exception: await session.rollback() raise finally: await session.close() def get_sync_db() -> Generator[Session, None, None]: """ 获取同步数据库会话(用于遗留同步查询) """ session = sync_session_maker() try: yield session session.commit() except Exception: session.rollback() raise finally: session.close() async def get_current_user( credentials: HTTPAuthorizationCredentials = Depends(security), db: AsyncSession = Depends(get_db) ) -> User: """ 获取当前登录用户 Args: credentials: HTTP认证凭据 db: 数据库会话 Returns: User: 当前用户对象 Raises: HTTPException: 认证失败或用户不存在 """ from app.utils.redis_client import redis_client token = credentials.credentials # 检查Token是否在黑名单中 is_blacklisted = await redis_client.get(f"blacklist:{token}") if is_blacklisted: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Token已失效,请重新登录", headers={"WWW-Authenticate": "Bearer"} ) payload = security_manager.verify_token(token, token_type="access") raw_user_id = payload.get("sub") if raw_user_id is None: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="无效的认证凭据", headers={"WWW-Authenticate": "Bearer"} ) try: user_id: int = int(raw_user_id) except (TypeError, ValueError): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="无效的用户ID", headers={"WWW-Authenticate": "Bearer"} ) from app.crud.user import user_crud user = await user_crud.get(db, id=user_id) if user is None: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="用户不存在" ) if user.status != "active": raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="用户已被禁用" ) return user async def get_current_active_user( current_user: User = Depends(get_current_user) ) -> User: """ 获取当前活跃用户 Args: current_user: 当前用户 Returns: User: 活跃用户对象 Raises: HTTPException: 用户未激活 """ if current_user.status != "active": raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="用户账户未激活" ) return current_user async def get_current_admin_user( current_user: User = Depends(get_current_user) ) -> User: """ 获取当前管理员用户 Args: current_user: 当前用户 Returns: User: 管理员用户对象 Raises: HTTPException: 用户不是管理员 """ if not current_user.is_admin: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="权限不足,需要管理员权限" ) return current_user class PermissionChecker: """ 权限检查器 """ def __init__(self, required_permission: str): self.required_permission = required_permission async def __call__( self, current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db) ) -> User: """ 检查用户是否有指定权限 Args: current_user: 当前用户 db: 数据库会话 Returns: 用户对象 Raises: HTTPException: 权限不足 """ # 管理员拥有所有权限 if current_user.is_admin: return current_user # 查询用户的所有权限 # 获取用户的角色 result = await db.execute( select(Role) .join(UserRole, UserRole.role_id == Role.id) .where(UserRole.user_id == current_user.id) .where(Role.deleted_at.is_(None)) ) roles = result.scalars().all() # 获取角色对应的所有权限编码 if not roles: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="权限不足" ) role_ids = [role.id for role in roles] result = await db.execute( select(Permission.permission_code) .join(RolePermission, RolePermission.permission_id == Permission.id) .where(RolePermission.role_id.in_(role_ids)) .where(Permission.deleted_at.is_(None)) ) permissions = result.scalars().all() # 检查是否有必需的权限 if self.required_permission not in permissions: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=f"需要权限: {self.required_permission}" ) return current_user # 常用权限检查器 require_asset_read = PermissionChecker("asset:asset:read") require_asset_create = PermissionChecker("asset:asset:create") require_asset_update = PermissionChecker("asset:asset:update") require_asset_delete = PermissionChecker("asset:asset:delete")