""" 用户CRUD操作 """ from typing import Optional, List, Tuple from sqlalchemy import select, and_, or_ from sqlalchemy.orm import selectinload from sqlalchemy.ext.asyncio import AsyncSession from app.models.user import User, Role, UserRole, Permission, RolePermission from app.schemas.user import UserCreate, UserUpdate, RoleCreate, RoleUpdate from app.core.security import get_password_hash class UserCRUD: """用户CRUD类""" async def get(self, db: AsyncSession, id: int) -> Optional[User]: """ 根据ID获取用户 Args: db: 数据库会话 id: 用户ID Returns: User: 用户对象或None """ result = await db.execute( select(User) .options(selectinload(User.roles).selectinload(Role.permissions)) .where(User.id == id, User.deleted_at.is_(None)) ) return result.scalar_one_or_none() async def get_by_username(self, db: AsyncSession, username: str) -> Optional[User]: """ 根据用户名获取用户 Args: db: 数据库会话 username: 用户名 Returns: User: 用户对象或None """ result = await db.execute( select(User) .options(selectinload(User.roles).selectinload(Role.permissions)) .where(User.username == username, User.deleted_at.is_(None)) ) return result.scalar_one_or_none() async def get_by_email(self, db: AsyncSession, email: str) -> Optional[User]: """ 根据邮箱获取用户 Args: db: 数据库会话 email: 邮箱 Returns: User: 用户对象或None """ result = await db.execute( select(User) .where(User.email == email, User.deleted_at.is_(None)) ) return result.scalar_one_or_none() async def get_multi( self, db: AsyncSession, skip: int = 0, limit: int = 20, keyword: Optional[str] = None, status: Optional[str] = None, role_id: Optional[int] = None ) -> Tuple[List[User], int]: """ 获取用户列表 Args: db: 数据库会话 skip: 跳过条数 limit: 返回条数 keyword: 搜索关键词 status: 状态筛选 role_id: 角色ID筛选 Returns: Tuple[List[User], int]: 用户列表和总数 """ # 构建查询条件 conditions = [User.deleted_at.is_(None)] if keyword: keyword_pattern = f"%{keyword}%" conditions.append( or_( User.username.ilike(keyword_pattern), User.real_name.ilike(keyword_pattern), User.phone.ilike(keyword_pattern) ) ) if status: conditions.append(User.status == status) # 构建基础查询 query = select(User).options(selectinload(User.roles)).where(*conditions) # 如果需要按角色筛选 if role_id: query = query.join(UserRole).where(UserRole.role_id == role_id) # 按ID降序排序 query = query.order_by(User.id.desc()) # 获取总数 count_query = select(User.id).where(*conditions) if role_id: count_query = count_query.join(UserRole).where(UserRole.role_id == role_id) result = await db.execute(select(User.id).where(*conditions)) total = len(result.all()) # 分页查询 result = await db.execute(query.offset(skip).limit(limit)) users = result.scalars().all() return list(users), total async def create(self, db: AsyncSession, obj_in: UserCreate, creator_id: int) -> User: """ 创建用户 Args: db: 数据库会话 obj_in: 创建数据 creator_id: 创建人ID Returns: User: 创建的用户对象 """ # 检查用户名是否已存在 existing_user = await self.get_by_username(db, obj_in.username) if existing_user: raise ValueError("用户名已存在") # 检查邮箱是否已存在 if obj_in.email: existing_email = await self.get_by_email(db, obj_in.email) if existing_email: raise ValueError("邮箱已存在") # 创建用户对象 db_obj = User( username=obj_in.username, password_hash=get_password_hash(obj_in.password), real_name=obj_in.real_name, email=obj_in.email, phone=obj_in.phone, created_by=creator_id ) db.add(db_obj) await db.flush() await db.refresh(db_obj) # 分配角色 for role_id in obj_in.role_ids: user_role = UserRole( user_id=db_obj.id, role_id=role_id, created_by=creator_id ) db.add(user_role) await db.commit() await db.refresh(db_obj) return await self.get(db, db_obj.id) async def update( self, db: AsyncSession, db_obj: User, obj_in: UserUpdate, updater_id: int ) -> User: """ 更新用户 Args: db: 数据库会话 db_obj: 数据库中的用户对象 obj_in: 更新数据 updater_id: 更新人ID Returns: User: 更新后的用户对象 """ update_data = obj_in.model_dump(exclude_unset=True) # 检查邮箱是否已被其他用户使用 if "email" in update_data and update_data["email"]: existing_user = await db.execute( select(User).where( User.email == update_data["email"], User.id != db_obj.id, User.deleted_at.is_(None) ) ) if existing_user.scalar_one_or_none(): raise ValueError("邮箱已被使用") # 更新字段 for field, value in update_data.items(): if field == "role_ids": continue setattr(db_obj, field, value) db_obj.updated_by = updater_id # 更新角色 if "role_ids" in update_data: # 删除旧角色 await db.execute( select(UserRole).where(UserRole.user_id == db_obj.id) ) old_roles = await db.execute( select(UserRole).where(UserRole.user_id == db_obj.id) ) for old_role in old_roles.scalars().all(): await db.delete(old_role) # 添加新角色 for role_id in update_data["role_ids"]: user_role = UserRole( user_id=db_obj.id, role_id=role_id, created_by=updater_id ) db.add(user_role) await db.commit() await db.refresh(db_obj) return await self.get(db, db_obj.id) async def delete(self, db: AsyncSession, id: int, deleter_id: int) -> bool: """ 删除用户(软删除) Args: db: 数据库会话 id: 用户ID deleter_id: 删除人ID Returns: bool: 是否删除成功 """ db_obj = await self.get(db, id) if not db_obj: return False db_obj.deleted_at = datetime.utcnow() db_obj.deleted_by = deleter_id await db.commit() return True async def update_password( self, db: AsyncSession, user: User, new_password: str ) -> bool: """ 更新用户密码 Args: db: 数据库会话 user: 用户对象 new_password: 新密码 Returns: bool: 是否更新成功 """ user.password_hash = get_password_hash(new_password) user.login_fail_count = 0 user.locked_until = None await db.commit() return True async def update_last_login(self, db: AsyncSession, user: User) -> bool: """ 更新用户最后登录时间 Args: db: 数据库会话 user: 用户对象 Returns: bool: 是否更新成功 """ from datetime import datetime user.last_login_at = datetime.utcnow() user.login_fail_count = 0 await db.commit() return True class RoleCRUD: """角色CRUD类""" async def get(self, db: AsyncSession, id: int) -> Optional[Role]: """根据ID获取角色""" result = await db.execute( select(Role) .options(selectinload(Role.permissions)) .where(Role.id == id, Role.deleted_at.is_(None)) ) return result.scalar_one_or_none() async def get_by_code(self, db: AsyncSession, role_code: str) -> Optional[Role]: """根据代码获取角色""" result = await db.execute( select(Role).where(Role.role_code == role_code, Role.deleted_at.is_(None)) ) return result.scalar_one_or_none() async def get_multi( self, db: AsyncSession, status: Optional[str] = None ) -> List[Role]: """获取角色列表""" conditions = [Role.deleted_at.is_(None)] if status: conditions.append(Role.status == status) result = await db.execute( select(Role) .options(selectinload(Role.permissions)) .where(*conditions) .order_by(Role.sort_order, Role.id) ) return list(result.scalars().all()) async def create(self, db: AsyncSession, obj_in: RoleCreate, creator_id: int) -> Role: """创建角色""" # 检查代码是否已存在 existing_role = await self.get_by_code(db, obj_in.role_code) if existing_role: raise ValueError("角色代码已存在") db_obj = Role( role_name=obj_in.role_name, role_code=obj_in.role_code, description=obj_in.description, created_by=creator_id ) db.add(db_obj) await db.flush() await db.refresh(db_obj) # 分配权限 for permission_id in obj_in.permission_ids: role_permission = RolePermission( role_id=db_obj.id, permission_id=permission_id, created_by=creator_id ) db.add(role_permission) await db.commit() return await self.get(db, db_obj.id) async def update( self, db: AsyncSession, db_obj: Role, obj_in: RoleUpdate, updater_id: int ) -> Role: """更新角色""" update_data = obj_in.model_dump(exclude_unset=True) for field, value in update_data.items(): if field == "permission_ids": continue setattr(db_obj, field, value) db_obj.updated_by = updater_id # 更新权限 if "permission_ids" in update_data: # 删除旧权限 old_permissions = await db.execute( select(RolePermission).where(RolePermission.role_id == db_obj.id) ) for old_perm in old_permissions.scalars().all(): await db.delete(old_perm) # 添加新权限 for permission_id in update_data["permission_ids"]: role_permission = RolePermission( role_id=db_obj.id, permission_id=permission_id, created_by=updater_id ) db.add(role_permission) await db.commit() return await self.get(db, db_obj.id) async def delete(self, db: AsyncSession, id: int) -> bool: """删除角色(软删除)""" db_obj = await self.get(db, id) if not db_obj: return False db_obj.deleted_at = datetime.utcnow() await db.commit() return True # 创建CRUD实例 user_crud = UserCRUD() role_crud = RoleCRUD()