"""Dependency-injection helpers for FastAPI routes: DB session, auth, and permission checks."""
from typing import AsyncGenerator, Generator, Optional

from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select

from app.db.session import async_session_maker
from app.core.security import security_manager
from app.models.user import User, Role, Permission, UserRole, RolePermission

# HTTP Bearer authentication scheme; extracts the token used by get_current_user.
security = HTTPBearer()


async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield a database session scoped to the request.

    The session is committed after the consumer finishes without raising;
    on any exception it is rolled back and the exception is re-raised.

    Yields:
        AsyncSession: Database session.
    """
    async with async_session_maker() as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise
        finally:
            # Explicit close kept from the original; harmless if the context
            # manager also closes the session on exit.
            await session.close()


async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security),
    db: AsyncSession = Depends(get_db),
) -> User:
    """Resolve the currently authenticated user from the Bearer access token.

    Args:
        credentials: Bearer credentials extracted by ``HTTPBearer``.
        db: Database session.

    Returns:
        User: The user referenced by the token's ``sub`` claim.

    Raises:
        HTTPException: 401 if the token is blacklisted, lacks a ``sub`` claim,
            or its ``sub`` is not an integer id; 404 if the user does not
            exist; 403 if the user's status is not ``"active"``. Any error
            raised by ``security_manager.verify_token`` propagates unchanged.
    """
    # Function-level import, presumably to avoid a circular import at module
    # load time — confirm.
    from app.utils.redis_client import redis_client

    token = credentials.credentials

    # A token stored under "blacklist:<token>" in Redis is treated as revoked.
    is_blacklisted = await redis_client.get(f"blacklist:{token}")
    if is_blacklisted:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Token已失效,请重新登录",
            headers={"WWW-Authenticate": "Bearer"},
        )

    payload = security_manager.verify_token(token, token_type="access")

    invalid_credentials = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="无效的认证凭据",
        headers={"WWW-Authenticate": "Bearer"},
    )

    subject = payload.get("sub")
    if subject is None:
        raise invalid_credentials

    # The JWT "sub" claim is a string (RFC 7519), so convert it to the integer
    # primary key before querying. A malformed subject is an authentication
    # failure, not a server error.
    try:
        user_id = int(subject)
    except (TypeError, ValueError):
        raise invalid_credentials from None

    # Function-level import, presumably to avoid a circular import — confirm.
    from app.crud.user import user_crud

    user = await user_crud.get(db, id=user_id)
    if user is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="用户不存在",
        )

    if user.status != "active":
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="用户已被禁用",
        )

    return user


async def get_current_active_user(
    current_user: User = Depends(get_current_user),
) -> User:
    """Return the current user, requiring an ``"active"`` status.

    Args:
        current_user: User resolved by ``get_current_user``.

    Returns:
        User: The active user.

    Raises:
        HTTPException: 403 if the user's status is not ``"active"``.
    """
    # get_current_user already rejects non-active users (403 "用户已被禁用"),
    # so this branch is unreachable via that dependency. It is kept as a
    # defensive check in case the upstream dependency changes.
    if current_user.status != "active":
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="用户账户未激活",
        )
    return current_user


async def get_current_admin_user(
    current_user: User = Depends(get_current_user),
) -> User:
    """Return the current user, requiring admin privileges.

    Args:
        current_user: User resolved by ``get_current_user``.

    Returns:
        User: The admin user.

    Raises:
        HTTPException: 403 if ``current_user.is_admin`` is falsy.
    """
    if not current_user.is_admin:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="权限不足,需要管理员权限",
        )
    return current_user


class PermissionChecker:
    """FastAPI dependency that requires the current user to hold one permission code.

    Admins bypass the check. Other users must have at least one non-deleted
    role linked to a non-deleted permission whose ``permission_code`` equals
    ``required_permission``.
    """

    def __init__(self, required_permission: str):
        # Permission code to require, e.g. "asset:asset:read".
        self.required_permission = required_permission

    async def __call__(
        self,
        current_user: User = Depends(get_current_user),
        db: AsyncSession = Depends(get_db),
    ) -> User:
        """Check that the current user holds ``self.required_permission``.

        Args:
            current_user: User resolved by ``get_current_user``.
            db: Database session.

        Returns:
            User: The current user, if authorised.

        Raises:
            HTTPException: 403 with detail "权限不足" if the user has no
                non-deleted role, or "需要权限: <code>" if none of their roles
                grants the required permission.
        """
        # Admins implicitly hold every permission.
        if current_user.is_admin:
            return current_user

        # Only the role ids are needed, so avoid loading full Role objects.
        role_result = await db.execute(
            select(Role.id)
            .join(UserRole, UserRole.role_id == Role.id)
            .where(UserRole.user_id == current_user.id)
            .where(Role.deleted_at.is_(None))
        )
        role_ids = list(role_result.scalars().all())

        if not role_ids:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="权限不足",
            )

        # Collect the permission codes granted by those roles. The membership
        # test stays in Python so matching is exact and case-sensitive,
        # regardless of DB collation.
        permission_result = await db.execute(
            select(Permission.permission_code)
            .join(RolePermission, RolePermission.permission_id == Permission.id)
            .where(RolePermission.role_id.in_(role_ids))
            .where(Permission.deleted_at.is_(None))
        )
        permissions = permission_result.scalars().all()

        if self.required_permission not in permissions:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"需要权限: {self.required_permission}",
            )

        return current_user


# Commonly used permission checkers for asset endpoints.
require_asset_read = PermissionChecker("asset:asset:read")
require_asset_create = PermissionChecker("asset:asset:create")
require_asset_update = PermissionChecker("asset:asset:update")
require_asset_delete = PermissionChecker("asset:asset:delete")