Files
zcglxt/backend/app/core/deps.py

233 lines
6.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
依赖注入模块
"""
from typing import AsyncGenerator, Generator, Optional

from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session

from app.core.security import security_manager
from app.db.session import async_session_maker, sync_session_maker
from app.models.user import User, Role, Permission, UserRole, RolePermission
# HTTP Bearer authentication scheme shared by this module;
# get_current_user receives the request credentials via Depends(security).
security: HTTPBearer = HTTPBearer()
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """
    Provide an async database session as a FastAPI dependency.

    The session is committed once the consumer resumes the generator without
    error. If an exception is propagated back into the generator, the session
    is rolled back and the exception re-raised. The session is closed in every
    case.

    Yields:
        AsyncSession: Database session for the current dependency scope.
    """
    async with async_session_maker() as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise
        finally:
            # Likely redundant with the ``async with`` exit, but closing an
            # already-closed session should be harmless; kept explicit because
            # async_session_maker's implementation is not visible here.
            await session.close()
def get_sync_db() -> Generator[Session, None, None]:
    """
    Provide a synchronous database session for legacy synchronous queries.

    Once the consumer resumes the generator without error the session is
    committed; any exception escaping that phase triggers a rollback and is
    re-raised. The session is always closed afterwards.

    Yields:
        Session: A synchronous SQLAlchemy session.
    """
    sync_session = sync_session_maker()
    try:
        yield sync_session
        sync_session.commit()
    except Exception:
        sync_session.rollback()
        raise
    finally:
        sync_session.close()
async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security),
    db: AsyncSession = Depends(get_db)
) -> User:
    """
    Resolve the logged-in user from the request's Bearer token.

    Steps: reject tokens present in the Redis blacklist, verify the token as
    an access token, read the user id from the ``sub`` claim, load the user,
    and require ``status == "active"``.

    Args:
        credentials: Bearer credentials extracted by ``security``.
        db: Async database session.

    Returns:
        User: The user referenced by the token.

    Raises:
        HTTPException: 401 if the token is blacklisted, verification yields no
            payload, or ``sub`` is missing / not an integer; 404 if the user
            does not exist; 403 if the user is not active.
            ``security_manager.verify_token`` may raise its own errors for
            invalid tokens (its contract is not visible here).
    """
    # Local import, presumably to avoid a circular import at module load.
    from app.utils.redis_client import redis_client

    token = credentials.credentials

    # Revoked tokens are recorded under ``blacklist:<token>``; any truthy
    # value means the token must no longer be accepted.
    is_blacklisted = await redis_client.get(f"blacklist:{token}")
    if is_blacklisted:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Token已失效请重新登录",
            headers={"WWW-Authenticate": "Bearer"}
        )

    payload = security_manager.verify_token(token, token_type="access")
    # If verify_token ever returns a falsy value instead of raising, answer
    # 401 rather than crashing on ``payload.get`` (AttributeError -> HTTP 500).
    if not payload:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="无效的认证凭据",
            headers={"WWW-Authenticate": "Bearer"}
        )

    raw_user_id = payload.get("sub")
    if raw_user_id is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="无效的认证凭据",
            headers={"WWW-Authenticate": "Bearer"}
        )

    try:
        user_id: int = int(raw_user_id)
    except (TypeError, ValueError):
        # ``from None``: the parse error is an expected client-input failure,
        # not something to chain into the traceback.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="无效的用户ID",
            headers={"WWW-Authenticate": "Bearer"}
        ) from None

    from app.crud.user import user_crud

    user = await user_crud.get(db, id=user_id)
    if user is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="用户不存在"
        )
    if user.status != "active":
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="用户已被禁用"
        )
    return user
async def get_current_active_user(
    current_user: User = Depends(get_current_user)
) -> User:
    """
    Return the authenticated user only if the account status is "active".

    Note that get_current_user already rejects non-active users, so this acts
    as a second guard for routes that depend on it explicitly.

    Args:
        current_user: User resolved by get_current_user.

    Returns:
        User: The same user object, unchanged.

    Raises:
        HTTPException: 403 when the account status is not "active".
    """
    if current_user.status == "active":
        return current_user
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="用户账户未激活"
    )
async def get_current_admin_user(
    current_user: User = Depends(get_current_user)
) -> User:
    """
    Return the authenticated user only if it carries the admin flag.

    Args:
        current_user: User resolved by get_current_user.

    Returns:
        User: The same user object when ``is_admin`` is truthy.

    Raises:
        HTTPException: 403 when the user is not an administrator.
    """
    if current_user.is_admin:
        return current_user
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="权限不足,需要管理员权限"
    )
class PermissionChecker:
    """
    Callable FastAPI dependency that requires a specific permission code.

    Admin users (``is_admin``) pass unconditionally. Any other user must have
    at least one non-deleted role, and one of those roles must be linked via
    RolePermission to a non-deleted Permission whose ``permission_code``
    equals ``required_permission``.
    """

    def __init__(self, required_permission: str):
        # Permission code to require, e.g. "asset:asset:read".
        self.required_permission = required_permission

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.required_permission!r})"

    async def __call__(
        self,
        current_user: User = Depends(get_current_user),
        db: AsyncSession = Depends(get_db)
    ) -> User:
        """
        Check that the current user holds ``required_permission``.

        Args:
            current_user: User resolved by get_current_user.
            db: Async database session.

        Returns:
            User: The current user when the check passes.

        Raises:
            HTTPException: 403 "权限不足" when the user has no active roles,
                or 403 naming the required permission when none of the
                user's roles grants it.
        """
        # Admins bypass the per-permission lookup entirely.
        if current_user.is_admin:
            return current_user

        # Only the role ids feed the next query, so select the id column
        # instead of hydrating full Role entities.
        result = await db.execute(
            select(Role.id)
            .join(UserRole, UserRole.role_id == Role.id)
            .where(UserRole.user_id == current_user.id)
            .where(Role.deleted_at.is_(None))
        )
        role_ids = result.scalars().all()
        if not role_ids:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="权限不足"
            )

        # Collect every permission code granted by any of those roles.
        result = await db.execute(
            select(Permission.permission_code)
            .join(RolePermission, RolePermission.permission_id == Permission.id)
            .where(RolePermission.role_id.in_(role_ids))
            .where(Permission.deleted_at.is_(None))
        )
        permission_codes = result.scalars().all()

        if self.required_permission not in permission_codes:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"需要权限: {self.required_permission}"
            )
        return current_user
# Commonly used permission checkers: pre-built PermissionChecker instances for
# the asset read/create/update/delete permission codes, usable directly as
# dependencies, e.g. Depends(require_asset_read).
require_asset_read = PermissionChecker("asset:asset:read")
require_asset_create = PermissionChecker("asset:asset:create")
require_asset_update = PermissionChecker("asset:asset:update")
require_asset_delete = PermissionChecker("asset:asset:delete")