233 lines
6.1 KiB
Python
233 lines
6.1 KiB
Python
"""
|
||
依赖注入模块
|
||
"""
|
||
from typing import AsyncGenerator, Generator, Optional

from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session

from app.core.security import security_manager
from app.db.session import async_session_maker, sync_session_maker
from app.models.user import User, Role, Permission, UserRole, RolePermission
|
||
|
||
# HTTP Bearer authentication scheme shared by the dependencies below.
# auto_error=True is FastAPI's default; it is spelled out here for clarity.
security: HTTPBearer = HTTPBearer(auto_error=True)
|
||
|
||
|
||
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """
    Provide an async database session to a FastAPI dependency consumer.

    The session is committed once the consuming code finishes without
    raising. If an exception propagates back into this generator (or the
    commit itself fails), the session is rolled back and the exception is
    re-raised. The session is closed in every case.

    Yields:
        AsyncSession: the database session.
    """
    async with async_session_maker() as session:
        try:
            yield session
            # Reached only when the consumer finished without raising.
            await session.commit()
        except Exception:
            # Undo partial work (from the request or a failed commit) before
            # propagating the error.
            await session.rollback()
            raise
        finally:
            # Explicit close keeps cleanup independent of what the
            # ``async with`` exit does — presumably it closes as well.
            await session.close()
|
||
|
||
|
||
def get_sync_db() -> Generator[Session, None, None]:
    """
    Provide a synchronous database session (for legacy synchronous queries).

    Commits after the consumer finishes cleanly. Rolls back and re-raises on
    any ``Exception``, including a failed commit. Always closes the session.
    """
    db_session = sync_session_maker()
    try:
        try:
            yield db_session
            db_session.commit()
        except Exception:
            db_session.rollback()
            raise
    finally:
        db_session.close()
|
||
|
||
|
||
def _credentials_error(detail: str) -> HTTPException:
    """Build a 401 error carrying the ``WWW-Authenticate: Bearer`` challenge header."""
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=detail,
        headers={"WWW-Authenticate": "Bearer"},
    )


async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security),
    db: AsyncSession = Depends(get_db),
) -> User:
    """
    Resolve the currently authenticated user from the Bearer token.

    Steps:
        1. Reject tokens present in the Redis blacklist.
        2. Verify the token as an access token.
        3. Read the user id from the ``sub`` claim.
        4. Load the user and require ``status == "active"``.

    Args:
        credentials: Bearer credentials extracted by ``security``.
        db: async database session.

    Returns:
        User: the authenticated, active user.

    Raises:
        HTTPException: 401 for a blacklisted token, an empty/missing payload,
            a missing ``sub`` claim or a non-integer user id; 404 when the
            user does not exist; 403 when the user is not active. Errors
            raised by ``security_manager.verify_token`` propagate unchanged.
    """
    # Function-level import; NOTE(review): presumably avoids a circular
    # import at module load — confirm before hoisting.
    from app.utils.redis_client import redis_client

    token = credentials.credentials

    # Revoked tokens are stored under "blacklist:<token>".
    if await redis_client.get(f"blacklist:{token}"):
        raise _credentials_error("Token已失效,请重新登录")

    payload = security_manager.verify_token(token, token_type="access")

    # A falsy payload (e.g. None) would otherwise crash on ``.get`` and
    # surface as a 500; treat it like a token without a subject.
    raw_user_id = payload.get("sub") if payload else None
    if raw_user_id is None:
        raise _credentials_error("无效的认证凭据")

    try:
        user_id = int(raw_user_id)
    except (TypeError, ValueError):
        # The conversion traceback is noise for a client-side error.
        raise _credentials_error("无效的用户ID") from None

    from app.crud.user import user_crud

    user = await user_crud.get(db, id=user_id)
    if user is None:
        # NOTE(review): 401 is the more conventional status for a token whose
        # subject no longer exists; 404 kept for backward compatibility.
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="用户不存在",
        )

    if user.status != "active":
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="用户已被禁用",
        )

    return user
|
||
|
||
|
||
async def get_current_active_user(
    current_user: User = Depends(get_current_user),
) -> User:
    """
    Return the authenticated user, requiring an active account.

    Args:
        current_user: user resolved by ``get_current_user``.

    Returns:
        User: the same user object when its status is ``"active"``.

    Raises:
        HTTPException: 403 when the account status is not ``"active"``.
    """
    # NOTE(review): get_current_user already rejects users whose status is
    # not "active", so this guard duplicates that check.
    if current_user.status == "active":
        return current_user
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="用户账户未激活",
    )
|
||
|
||
|
||
async def get_current_admin_user(
    current_user: User = Depends(get_current_user),
) -> User:
    """
    Return the authenticated user, requiring administrator rights.

    Args:
        current_user: user resolved by ``get_current_user``.

    Returns:
        User: the same user object when ``is_admin`` is truthy.

    Raises:
        HTTPException: 403 when the user is not an administrator.
    """
    if current_user.is_admin:
        return current_user
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="权限不足,需要管理员权限",
    )
|
||
|
||
|
||
class PermissionChecker:
    """
    Callable FastAPI dependency that requires a specific permission code.

    Use as ``Depends(PermissionChecker("<code>"))``. Users with a truthy
    ``is_admin`` bypass the check. Otherwise the user's non-deleted roles
    must grant a non-deleted permission whose code equals the required one.
    """

    def __init__(self, required_permission: str):
        # Code compared against ``Permission.permission_code``.
        self.required_permission = required_permission

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.required_permission!r})"

    async def __call__(
        self,
        current_user: User = Depends(get_current_user),
        db: AsyncSession = Depends(get_db),
    ) -> User:
        """
        Check that the current user holds ``self.required_permission``.

        Args:
            current_user: user resolved by ``get_current_user``.
            db: async database session.

        Returns:
            User: the current user when the check passes.

        Raises:
            HTTPException: 403 with detail "权限不足" when the user has no
                non-deleted role, or 403 naming the missing permission code
                when none of the roles grants it.
        """
        # Administrators are granted every permission.
        if current_user.is_admin:
            return current_user

        # Only the ids are needed; avoid materialising full Role entities.
        role_result = await db.execute(
            select(Role.id)
            .join(UserRole, UserRole.role_id == Role.id)
            .where(UserRole.user_id == current_user.id)
            .where(Role.deleted_at.is_(None))
        )
        role_ids = role_result.scalars().all()

        if not role_ids:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="权限不足",
            )

        # Narrow in SQL to the required code instead of fetching every
        # permission of every role.
        perm_result = await db.execute(
            select(Permission.permission_code)
            .join(RolePermission, RolePermission.permission_id == Permission.id)
            .where(RolePermission.role_id.in_(role_ids))
            .where(Permission.deleted_at.is_(None))
            .where(Permission.permission_code == self.required_permission)
        )
        granted_codes = perm_result.scalars().all()

        # Keep the exact (case-sensitive) Python comparison. SQL equality may
        # be collation-dependent (e.g. case-insensitive), and this must not
        # loosen the authorization check.
        if self.required_permission not in granted_codes:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"需要权限: {self.required_permission}",
            )

        return current_user
|
||
|
||
|
||
# Commonly used permission checkers for the "asset:asset:*" permission codes.
require_asset_read: PermissionChecker = PermissionChecker("asset:asset:read")
require_asset_create: PermissionChecker = PermissionChecker("asset:asset:create")
require_asset_update: PermissionChecker = PermissionChecker("asset:asset:update")
require_asset_delete: PermissionChecker = PermissionChecker("asset:asset:delete")
|