Fix API compatibility and add user/role/permission and asset import/export
This commit is contained in:
6
backend_new/app/core/__init__.py
Normal file
6
backend_new/app/core/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""
|
||||
核心模块初始化
|
||||
"""
|
||||
from app.core.config import settings
|
||||
|
||||
__all__ = ["settings"]
|
||||
109
backend_new/app/core/config.py
Normal file
109
backend_new/app/core/config.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""
|
||||
应用配置模块
|
||||
"""
|
||||
from typing import List, Optional
|
||||
from pydantic import Field, field_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""应用配置类"""
|
||||
|
||||
# 应用基本信息
|
||||
APP_NAME: str = Field(default="资产管理系统", description="应用名称")
|
||||
APP_VERSION: str = Field(default="1.0.0", description="应用版本")
|
||||
APP_ENVIRONMENT: str = Field(default="development", description="运行环境")
|
||||
DEBUG: bool = Field(default=False, description="调试模式")
|
||||
API_V1_PREFIX: str = Field(default="/api/v1", description="API V1 前缀")
|
||||
|
||||
# 服务器配置
|
||||
HOST: str = Field(default="0.0.0.0", description="服务器地址")
|
||||
PORT: int = Field(default=8000, description="服务器端口")
|
||||
|
||||
# 数据库配置
|
||||
DATABASE_URL: str = Field(
|
||||
default="postgresql+asyncpg://postgres:postgres@localhost:5432/asset_management",
|
||||
description="数据库连接URL"
|
||||
)
|
||||
DATABASE_ECHO: bool = Field(default=False, description="是否打印SQL语句")
|
||||
|
||||
# Redis配置
|
||||
REDIS_URL: str = Field(default="redis://localhost:6379/0", description="Redis连接URL")
|
||||
REDIS_MAX_CONNECTIONS: int = Field(default=50, description="Redis最大连接数")
|
||||
|
||||
# JWT配置
|
||||
SECRET_KEY: str = Field(default="your-secret-key-change-in-production", description="JWT密钥")
|
||||
ALGORITHM: str = Field(default="HS256", description="JWT算法")
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = Field(default=15, description="访问令牌过期时间(分钟)")
|
||||
REFRESH_TOKEN_EXPIRE_DAYS: int = Field(default=7, description="刷新令牌过期时间(天)")
|
||||
|
||||
# CORS配置
|
||||
CORS_ORIGINS: List[str] = Field(
|
||||
default=["http://localhost:5173", "http://localhost:3000"],
|
||||
description="允许的跨域来源"
|
||||
)
|
||||
CORS_ALLOW_CREDENTIALS: bool = Field(default=True, description="允许携带凭证")
|
||||
CORS_ALLOW_METHODS: List[str] = Field(default=["*"], description="允许的HTTP方法")
|
||||
CORS_ALLOW_HEADERS: List[str] = Field(default=["*"], description="允许的请求头")
|
||||
|
||||
# 文件上传配置
|
||||
UPLOAD_DIR: str = Field(default="uploads", description="上传文件目录")
|
||||
MAX_UPLOAD_SIZE: int = Field(default=10485760, description="最大上传大小(字节)")
|
||||
ALLOWED_EXTENSIONS: List[str] = Field(
|
||||
default=["png", "jpg", "jpeg", "gif", "pdf", "xlsx", "xls"],
|
||||
description="允许的文件扩展名"
|
||||
)
|
||||
|
||||
# 验证码配置
|
||||
CAPTCHA_EXPIRE_SECONDS: int = Field(default=300, description="验证码过期时间(秒)")
|
||||
CAPTCHA_LENGTH: int = Field(default=4, description="验证码长度")
|
||||
|
||||
# 日志配置
|
||||
LOG_LEVEL: str = Field(default="INFO", description="日志级别")
|
||||
LOG_FILE: str = Field(default="logs/app.log", description="日志文件路径")
|
||||
LOG_ROTATION: str = Field(default="500 MB", description="日志轮转大小")
|
||||
LOG_RETENTION: str = Field(default="10 days", description="日志保留时间")
|
||||
|
||||
# 分页配置
|
||||
DEFAULT_PAGE_SIZE: int = Field(default=20, description="默认每页数量")
|
||||
MAX_PAGE_SIZE: int = Field(default=100, description="最大每页数量")
|
||||
|
||||
# 二维码配置
|
||||
QR_CODE_DIR: str = Field(default="uploads/qrcodes", description="二维码保存目录")
|
||||
QR_CODE_SIZE: int = Field(default=300, description="二维码尺寸")
|
||||
QR_CODE_BORDER: int = Field(default=2, description="二维码边框")
|
||||
|
||||
@field_validator("CORS_ORIGINS", mode="before")
|
||||
@classmethod
|
||||
def parse_cors_origins(cls, v: str) -> List[str]:
|
||||
"""解析CORS来源"""
|
||||
if isinstance(v, str):
|
||||
return [origin.strip() for origin in v.split(",")]
|
||||
return v
|
||||
|
||||
@field_validator("ALLOWED_EXTENSIONS", mode="before")
|
||||
@classmethod
|
||||
def parse_allowed_extensions(cls, v: str) -> List[str]:
|
||||
"""解析允许的文件扩展名"""
|
||||
if isinstance(v, str):
|
||||
return [ext.strip() for ext in v.split(",")]
|
||||
return v
|
||||
|
||||
@property
|
||||
def is_development(self) -> bool:
|
||||
"""是否为开发环境"""
|
||||
return self.APP_ENVIRONMENT == "development"
|
||||
|
||||
@property
|
||||
def is_production(self) -> bool:
|
||||
"""是否为生产环境"""
|
||||
return self.APP_ENVIRONMENT == "production"
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
env_file_encoding = "utf-8"
|
||||
case_sensitive = True
|
||||
|
||||
|
||||
# 创建全局配置实例
|
||||
settings = Settings()
|
||||
208
backend_new/app/core/deps.py
Normal file
208
backend_new/app/core/deps.py
Normal file
@@ -0,0 +1,208 @@
|
||||
"""
|
||||
依赖注入模块
|
||||
"""
|
||||
from typing import Generator, Optional
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from app.db.session import async_session_maker
|
||||
from app.core.security import security_manager
|
||||
from app.models.user import User, Role, Permission, UserRole, RolePermission
|
||||
|
||||
# HTTP Bearer认证
|
||||
security = HTTPBearer()
|
||||
|
||||
|
||||
async def get_db() -> Generator:
|
||||
"""
|
||||
获取数据库会话
|
||||
|
||||
Yields:
|
||||
AsyncSession: 数据库会话
|
||||
"""
|
||||
async with async_session_maker() as session:
|
||||
try:
|
||||
yield session
|
||||
await session.commit()
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> User:
|
||||
"""
|
||||
获取当前登录用户
|
||||
|
||||
Args:
|
||||
credentials: HTTP认证凭据
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
User: 当前用户对象
|
||||
|
||||
Raises:
|
||||
HTTPException: 认证失败或用户不存在
|
||||
"""
|
||||
from app.utils.redis_client import redis_client
|
||||
|
||||
token = credentials.credentials
|
||||
|
||||
# 检查Token是否在黑名单中
|
||||
is_blacklisted = await redis_client.get(f"blacklist:{token}")
|
||||
if is_blacklisted:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Token已失效,请重新登录",
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
)
|
||||
|
||||
payload = security_manager.verify_token(token, token_type="access")
|
||||
|
||||
user_id: int = payload.get("sub")
|
||||
if user_id is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="无效的认证凭据",
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
)
|
||||
|
||||
from app.crud.user import user_crud
|
||||
user = await user_crud.get(db, id=user_id)
|
||||
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="用户不存在"
|
||||
)
|
||||
|
||||
if user.status != "active":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="用户已被禁用"
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def get_current_active_user(
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> User:
|
||||
"""
|
||||
获取当前活跃用户
|
||||
|
||||
Args:
|
||||
current_user: 当前用户
|
||||
|
||||
Returns:
|
||||
User: 活跃用户对象
|
||||
|
||||
Raises:
|
||||
HTTPException: 用户未激活
|
||||
"""
|
||||
if current_user.status != "active":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="用户账户未激活"
|
||||
)
|
||||
return current_user
|
||||
|
||||
|
||||
async def get_current_admin_user(
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> User:
|
||||
"""
|
||||
获取当前管理员用户
|
||||
|
||||
Args:
|
||||
current_user: 当前用户
|
||||
|
||||
Returns:
|
||||
User: 管理员用户对象
|
||||
|
||||
Raises:
|
||||
HTTPException: 用户不是管理员
|
||||
"""
|
||||
if not current_user.is_admin:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="权限不足,需要管理员权限"
|
||||
)
|
||||
return current_user
|
||||
|
||||
|
||||
class PermissionChecker:
|
||||
"""
|
||||
权限检查器
|
||||
"""
|
||||
def __init__(self, required_permission: str):
|
||||
self.required_permission = required_permission
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> User:
|
||||
"""
|
||||
检查用户是否有指定权限
|
||||
|
||||
Args:
|
||||
current_user: 当前用户
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
用户对象
|
||||
|
||||
Raises:
|
||||
HTTPException: 权限不足
|
||||
"""
|
||||
# 管理员拥有所有权限
|
||||
if current_user.is_admin:
|
||||
return current_user
|
||||
|
||||
# 查询用户的所有权限
|
||||
# 获取用户的角色
|
||||
result = await db.execute(
|
||||
select(Role)
|
||||
.join(UserRole, UserRole.role_id == Role.id)
|
||||
.where(UserRole.user_id == current_user.id)
|
||||
.where(Role.deleted_at.is_(None))
|
||||
)
|
||||
roles = result.scalars().all()
|
||||
|
||||
# 获取角色对应的所有权限编码
|
||||
if not roles:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="权限不足"
|
||||
)
|
||||
|
||||
role_ids = [role.id for role in roles]
|
||||
result = await db.execute(
|
||||
select(Permission.permission_code)
|
||||
.join(RolePermission, RolePermission.permission_id == Permission.id)
|
||||
.where(RolePermission.role_id.in_(role_ids))
|
||||
.where(Permission.deleted_at.is_(None))
|
||||
)
|
||||
permissions = result.scalars().all()
|
||||
|
||||
# 检查是否有必需的权限
|
||||
if self.required_permission not in permissions:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"需要权限: {self.required_permission}"
|
||||
)
|
||||
|
||||
return current_user
|
||||
|
||||
|
||||
# 常用权限检查器
|
||||
require_asset_read = PermissionChecker("asset:asset:read")
|
||||
require_asset_create = PermissionChecker("asset:asset:create")
|
||||
require_asset_update = PermissionChecker("asset:asset:update")
|
||||
require_asset_delete = PermissionChecker("asset:asset:delete")
|
||||
155
backend_new/app/core/exceptions.py
Normal file
155
backend_new/app/core/exceptions.py
Normal file
@@ -0,0 +1,155 @@
|
||||
"""
|
||||
自定义异常类
|
||||
"""
|
||||
from typing import Any, Dict, Optional
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
|
||||
class BusinessException(Exception):
|
||||
"""业务逻辑异常基类"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
code: int = status.HTTP_400_BAD_REQUEST,
|
||||
error_code: Optional[str] = None,
|
||||
data: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
"""
|
||||
初始化业务异常
|
||||
|
||||
Args:
|
||||
message: 错误消息
|
||||
code: HTTP状态码
|
||||
error_code: 业务错误码
|
||||
data: 附加数据
|
||||
"""
|
||||
self.message = message
|
||||
self.code = code
|
||||
self.error_code = error_code
|
||||
self.data = data
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class NotFoundException(BusinessException):
|
||||
"""资源不存在异常"""
|
||||
|
||||
def __init__(self, resource: str = "资源"):
|
||||
super().__init__(
|
||||
message=f"{resource}不存在",
|
||||
code=status.HTTP_404_NOT_FOUND,
|
||||
error_code="RESOURCE_NOT_FOUND"
|
||||
)
|
||||
|
||||
|
||||
class AlreadyExistsException(BusinessException):
|
||||
"""资源已存在异常"""
|
||||
|
||||
def __init__(self, resource: str = "资源"):
|
||||
super().__init__(
|
||||
message=f"{resource}已存在",
|
||||
code=status.HTTP_409_CONFLICT,
|
||||
error_code="RESOURCE_ALREADY_EXISTS"
|
||||
)
|
||||
|
||||
|
||||
class PermissionDeniedException(BusinessException):
|
||||
"""权限不足异常"""
|
||||
|
||||
def __init__(self, message: str = "权限不足"):
|
||||
super().__init__(
|
||||
message=message,
|
||||
code=status.HTTP_403_FORBIDDEN,
|
||||
error_code="PERMISSION_DENIED"
|
||||
)
|
||||
|
||||
|
||||
class AuthenticationFailedException(BusinessException):
|
||||
"""认证失败异常"""
|
||||
|
||||
def __init__(self, message: str = "认证失败"):
|
||||
super().__init__(
|
||||
message=message,
|
||||
code=status.HTTP_401_UNAUTHORIZED,
|
||||
error_code="AUTHENTICATION_FAILED"
|
||||
)
|
||||
|
||||
|
||||
class ValidationFailedException(BusinessException):
|
||||
"""验证失败异常"""
|
||||
|
||||
def __init__(self, message: str = "数据验证失败", errors: Optional[Dict] = None):
|
||||
super().__init__(
|
||||
message=message,
|
||||
code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
error_code="VALIDATION_FAILED",
|
||||
data=errors
|
||||
)
|
||||
|
||||
|
||||
class InvalidCredentialsException(AuthenticationFailedException):
|
||||
"""无效凭据异常"""
|
||||
|
||||
def __init__(self, message: str = "用户名或密码错误"):
|
||||
super().__init__(message)
|
||||
self.error_code = "INVALID_CREDENTIALS"
|
||||
|
||||
|
||||
class TokenExpiredException(AuthenticationFailedException):
|
||||
"""令牌过期异常"""
|
||||
|
||||
def __init__(self, message: str = "令牌已过期,请重新登录"):
|
||||
super().__init__(message)
|
||||
self.error_code = "TOKEN_EXPIRED"
|
||||
|
||||
|
||||
class InvalidTokenException(AuthenticationFailedException):
|
||||
"""无效令牌异常"""
|
||||
|
||||
def __init__(self, message: str = "无效的令牌"):
|
||||
super().__init__(message)
|
||||
self.error_code = "INVALID_TOKEN"
|
||||
|
||||
|
||||
class CaptchaException(BusinessException):
|
||||
"""验证码异常"""
|
||||
|
||||
def __init__(self, message: str = "验证码错误"):
|
||||
super().__init__(
|
||||
message=message,
|
||||
code=status.HTTP_400_BAD_REQUEST,
|
||||
error_code="CAPTCHA_ERROR"
|
||||
)
|
||||
|
||||
|
||||
class UserLockedException(BusinessException):
|
||||
"""用户被锁定异常"""
|
||||
|
||||
def __init__(self, message: str = "用户已被锁定,请联系管理员"):
|
||||
super().__init__(
|
||||
message=message,
|
||||
code=status.HTTP_403_FORBIDDEN,
|
||||
error_code="USER_LOCKED"
|
||||
)
|
||||
|
||||
|
||||
class UserDisabledException(BusinessException):
|
||||
"""用户被禁用异常"""
|
||||
|
||||
def __init__(self, message: str = "用户已被禁用"):
|
||||
super().__init__(
|
||||
message=message,
|
||||
code=status.HTTP_403_FORBIDDEN,
|
||||
error_code="USER_DISABLED"
|
||||
)
|
||||
|
||||
|
||||
class StateTransitionException(BusinessException):
|
||||
"""状态转换异常"""
|
||||
|
||||
def __init__(self, current_state: str, target_state: str):
|
||||
super().__init__(
|
||||
message=f"无法从状态 '{current_state}' 转换到 '{target_state}'",
|
||||
code=status.HTTP_400_BAD_REQUEST,
|
||||
error_code="INVALID_STATE_TRANSITION"
|
||||
)
|
||||
152
backend_new/app/core/response.py
Normal file
152
backend_new/app/core/response.py
Normal file
@@ -0,0 +1,152 @@
|
||||
"""
|
||||
统一响应封装模块
|
||||
"""
|
||||
from typing import Any, Generic, TypeVar, Optional, List
|
||||
from pydantic import BaseModel, Field
|
||||
from datetime import datetime
|
||||
|
||||
# 泛型类型变量
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class ResponseModel(BaseModel, Generic[T]):
|
||||
"""统一响应模型"""
|
||||
|
||||
code: int = Field(default=200, description="响应状态码")
|
||||
message: str = Field(default="success", description="响应消息")
|
||||
data: Optional[T] = Field(default=None, description="响应数据")
|
||||
timestamp: int = Field(default_factory=lambda: int(datetime.now().timestamp()), description="时间戳")
|
||||
|
||||
@classmethod
|
||||
def success(cls, data: Optional[T] = None, message: str = "success") -> "ResponseModel[T]":
|
||||
"""
|
||||
成功响应
|
||||
|
||||
Args:
|
||||
data: 响应数据
|
||||
message: 响应消息
|
||||
|
||||
Returns:
|
||||
ResponseModel: 响应对象
|
||||
"""
|
||||
return cls(code=200, message=message, data=data)
|
||||
|
||||
@classmethod
|
||||
def error(
|
||||
cls,
|
||||
code: int,
|
||||
message: str,
|
||||
data: Optional[T] = None
|
||||
) -> "ResponseModel[T]":
|
||||
"""
|
||||
错误响应
|
||||
|
||||
Args:
|
||||
code: 错误码
|
||||
message: 错误消息
|
||||
data: 附加数据
|
||||
|
||||
Returns:
|
||||
ResponseModel: 响应对象
|
||||
"""
|
||||
return cls(code=code, message=message, data=data)
|
||||
|
||||
|
||||
class PaginationMeta(BaseModel):
|
||||
"""分页元数据"""
|
||||
|
||||
total: int = Field(..., description="总记录数")
|
||||
page: int = Field(..., ge=1, description="当前页码")
|
||||
page_size: int = Field(..., ge=1, le=100, description="每页记录数")
|
||||
total_pages: int = Field(..., ge=0, description="总页数")
|
||||
|
||||
|
||||
class PaginatedResponse(BaseModel, Generic[T]):
|
||||
"""分页响应模型"""
|
||||
|
||||
total: int = Field(..., description="总记录数")
|
||||
page: int = Field(..., ge=1, description="当前页码")
|
||||
page_size: int = Field(..., ge=1, description="每页记录数")
|
||||
total_pages: int = Field(..., ge=0, description="总页数")
|
||||
items: List[T] = Field(default_factory=list, description="数据列表")
|
||||
|
||||
|
||||
class ValidationError(BaseModel):
|
||||
"""验证错误详情"""
|
||||
|
||||
field: str = Field(..., description="字段名")
|
||||
message: str = Field(..., description="错误消息")
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
|
||||
"""错误响应模型"""
|
||||
|
||||
code: int = Field(..., description="错误码")
|
||||
message: str = Field(..., description="错误消息")
|
||||
errors: Optional[List[ValidationError]] = Field(default=None, description="错误详情列表")
|
||||
timestamp: int = Field(default_factory=lambda: int(datetime.now().timestamp()), description="时间戳")
|
||||
|
||||
|
||||
def success_response(data: Any = None, message: str = "success") -> dict:
|
||||
"""
|
||||
生成成功响应
|
||||
|
||||
Args:
|
||||
data: 响应数据
|
||||
message: 响应消息
|
||||
|
||||
Returns:
|
||||
dict: 响应字典
|
||||
"""
|
||||
return ResponseModel.success(data=data, message=message).model_dump()
|
||||
|
||||
|
||||
def error_response(code: int, message: str, errors: Optional[List[dict]] = None) -> dict:
|
||||
"""
|
||||
生成错误响应
|
||||
|
||||
Args:
|
||||
code: 错误码
|
||||
message: 错误消息
|
||||
errors: 错误详情列表
|
||||
|
||||
Returns:
|
||||
dict: 响应字典
|
||||
"""
|
||||
error_data = ErrorResponse(
|
||||
code=code,
|
||||
message=message,
|
||||
errors=[ValidationError(**e) for e in errors] if errors else None
|
||||
)
|
||||
return error_data.model_dump()
|
||||
|
||||
|
||||
def paginated_response(
|
||||
items: List[Any],
|
||||
total: int,
|
||||
page: int,
|
||||
page_size: int
|
||||
) -> dict:
|
||||
"""
|
||||
生成分页响应
|
||||
|
||||
Args:
|
||||
items: 数据列表
|
||||
total: 总记录数
|
||||
page: 当前页码
|
||||
page_size: 每页记录数
|
||||
|
||||
Returns:
|
||||
dict: 响应字典
|
||||
"""
|
||||
total_pages = (total + page_size - 1) // page_size if page_size > 0 else 0
|
||||
|
||||
response = PaginatedResponse(
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
total_pages=total_pages,
|
||||
items=items
|
||||
)
|
||||
|
||||
return success_response(data=response.model_dump())
|
||||
178
backend_new/app/core/security.py
Normal file
178
backend_new/app/core/security.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""
|
||||
安全相关工具模块
|
||||
"""
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, Any
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
from fastapi import HTTPException, status
|
||||
from app.core.config import settings
|
||||
|
||||
# 密码加密上下文
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
|
||||
class SecurityManager:
|
||||
"""安全管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self.secret_key = settings.SECRET_KEY
|
||||
self.algorithm = settings.ALGORITHM
|
||||
self.access_token_expire_minutes = settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
self.refresh_token_expire_days = settings.REFRESH_TOKEN_EXPIRE_DAYS
|
||||
|
||||
def verify_password(self, plain_password: str, hashed_password: str) -> bool:
|
||||
"""
|
||||
验证密码
|
||||
|
||||
Args:
|
||||
plain_password: 明文密码
|
||||
hashed_password: 哈希密码
|
||||
|
||||
Returns:
|
||||
bool: 密码是否匹配
|
||||
"""
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
def get_password_hash(self, password: str) -> str:
|
||||
"""
|
||||
获取密码哈希值
|
||||
|
||||
Args:
|
||||
password: 明文密码
|
||||
|
||||
Returns:
|
||||
str: 哈希后的密码
|
||||
"""
|
||||
return pwd_context.hash(password)
|
||||
|
||||
def create_access_token(
|
||||
self,
|
||||
data: Dict[str, Any],
|
||||
expires_delta: Optional[timedelta] = None
|
||||
) -> str:
|
||||
"""
|
||||
创建访问令牌
|
||||
|
||||
Args:
|
||||
data: 要编码的数据
|
||||
expires_delta: 过期时间增量
|
||||
|
||||
Returns:
|
||||
str: JWT令牌
|
||||
"""
|
||||
to_encode = data.copy()
|
||||
|
||||
if expires_delta:
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
else:
|
||||
expire = datetime.utcnow() + timedelta(minutes=self.access_token_expire_minutes)
|
||||
|
||||
to_encode.update({
|
||||
"exp": expire,
|
||||
"type": "access"
|
||||
})
|
||||
|
||||
encoded_jwt = jwt.encode(to_encode, self.secret_key, algorithm=self.algorithm)
|
||||
return encoded_jwt
|
||||
|
||||
def create_refresh_token(
|
||||
self,
|
||||
data: Dict[str, Any],
|
||||
expires_delta: Optional[timedelta] = None
|
||||
) -> str:
|
||||
"""
|
||||
创建刷新令牌
|
||||
|
||||
Args:
|
||||
data: 要编码的数据
|
||||
expires_delta: 过期时间增量
|
||||
|
||||
Returns:
|
||||
str: JWT令牌
|
||||
"""
|
||||
to_encode = data.copy()
|
||||
|
||||
if expires_delta:
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
else:
|
||||
expire = datetime.utcnow() + timedelta(days=self.refresh_token_expire_days)
|
||||
|
||||
to_encode.update({
|
||||
"exp": expire,
|
||||
"type": "refresh"
|
||||
})
|
||||
|
||||
encoded_jwt = jwt.encode(to_encode, self.secret_key, algorithm=self.algorithm)
|
||||
return encoded_jwt
|
||||
|
||||
def decode_token(self, token: str) -> Dict[str, Any]:
|
||||
"""
|
||||
解码令牌
|
||||
|
||||
Args:
|
||||
token: JWT令牌
|
||||
|
||||
Returns:
|
||||
Dict: 解码后的数据
|
||||
|
||||
Raises:
|
||||
HTTPException: 令牌无效或过期
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
|
||||
return payload
|
||||
except JWTError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="无效的认证凭据",
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
)
|
||||
|
||||
def verify_token(self, token: str, token_type: str = "access") -> Dict[str, Any]:
|
||||
"""
|
||||
验证令牌
|
||||
|
||||
Args:
|
||||
token: JWT令牌
|
||||
token_type: 令牌类型(access/refresh)
|
||||
|
||||
Returns:
|
||||
Dict: 解码后的数据
|
||||
|
||||
Raises:
|
||||
HTTPException: 令牌无效或类型不匹配
|
||||
"""
|
||||
payload = self.decode_token(token)
|
||||
|
||||
if payload.get("type") != token_type:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=f"令牌类型不匹配,期望{token_type}"
|
||||
)
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
# 创建全局安全管理器实例
|
||||
security_manager = SecurityManager()
|
||||
|
||||
|
||||
def get_password_hash(password: str) -> str:
|
||||
"""获取密码哈希值(便捷函数)"""
|
||||
return security_manager.get_password_hash(password)
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""验证密码(便捷函数)"""
|
||||
return security_manager.verify_password(plain_password, hashed_password)
|
||||
|
||||
|
||||
def create_access_token(data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
|
||||
"""创建访问令牌(便捷函数)"""
|
||||
return security_manager.create_access_token(data, expires_delta)
|
||||
|
||||
|
||||
def create_refresh_token(data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
|
||||
"""创建刷新令牌(便捷函数)"""
|
||||
return security_manager.create_refresh_token(data, expires_delta)
|
||||
Reference in New Issue
Block a user