"""User CRUD operations, matching the actual database schema."""
from typing import Optional, List, Tuple
from datetime import datetime, timezone

from sqlalchemy import select, and_, or_, func
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.user import User
from app.core.security import get_password_hash


# Escape character used when matching user-supplied search text with ILIKE.
# "/" (rather than a backslash) sidesteps dialect-specific backslash handling
# in SQL string literals; it is the same character SQLAlchemy's own
# ``autoescape`` option uses.
_LIKE_ESCAPE = "/"


def _escape_like(value: str) -> str:
    """Escape LIKE/ILIKE wildcards in ``value`` so it is matched literally.

    The escape character itself is escaped first, so the ``%`` / ``_``
    replacements that follow are not double-escaped.
    """
    return (
        value.replace(_LIKE_ESCAPE, _LIKE_ESCAPE * 2)
        .replace("%", _LIKE_ESCAPE + "%")
        .replace("_", _LIKE_ESCAPE + "_")
    )


class UserCRUD:
    """Async CRUD helpers for the ``User`` model."""

    async def get(self, db: AsyncSession, id: int) -> Optional[User]:
        """
        Fetch a user by primary key.

        Args:
            db: Database session.
            id: User ID (parameter name kept for caller compatibility even
                though it shadows the builtin).

        Returns:
            User: The matching user, or None if no row matches.
        """
        result = await db.execute(select(User).where(User.id == id))
        return result.scalar_one_or_none()

    async def get_by_username(self, db: AsyncSession, username: str) -> Optional[User]:
        """
        Fetch a user by username.

        Args:
            db: Database session.
            username: Username to look up (exact match).

        Returns:
            User: The matching user, or None if no row matches.
        """
        result = await db.execute(select(User).where(User.username == username))
        return result.scalar_one_or_none()

    async def get_by_email(self, db: AsyncSession, email: str) -> Optional[User]:
        """
        Fetch a user by email address.

        Args:
            db: Database session.
            email: Email address to look up (exact match).

        Returns:
            User: The matching user, or None if no row matches.
        """
        result = await db.execute(select(User).where(User.email == email))
        return result.scalar_one_or_none()

    async def get_multi(
        self,
        db: AsyncSession,
        skip: int = 0,
        limit: int = 20,
        keyword: Optional[str] = None,
        is_active: Optional[bool] = None
    ) -> Tuple[List[User], int]:
        """
        List users with optional filtering and pagination.

        Args:
            db: Database session.
            skip: Number of rows to skip (offset).
            limit: Maximum number of rows to return.
            keyword: Case-insensitive substring searched in username,
                full_name, phone and email. ``%`` and ``_`` in the keyword
                are matched literally, not as wildcards.
            is_active: If given, only return users with this active flag.

        Returns:
            Tuple[List[User], int]: The requested page of users (newest ID
            first) and the total number of users matching the filters,
            ignoring pagination.
        """
        conditions = []
        if keyword:
            # Escape wildcards so e.g. a search for "50%" or "a_b" does not
            # turn into a pattern that matches unrelated (or all) rows.
            keyword_pattern = f"%{_escape_like(keyword)}%"
            conditions.append(
                or_(
                    User.username.ilike(keyword_pattern, escape=_LIKE_ESCAPE),
                    User.full_name.ilike(keyword_pattern, escape=_LIKE_ESCAPE),
                    User.phone.ilike(keyword_pattern, escape=_LIKE_ESCAPE),
                    User.email.ilike(keyword_pattern, escape=_LIKE_ESCAPE),
                )
            )
        if is_active is not None:
            conditions.append(User.is_active == is_active)

        # Build the page query and the total-count query from the same filters.
        query = select(User)
        count_query = select(func.count(User.id))
        if conditions:
            query = query.where(*conditions)
            count_query = count_query.where(*conditions)

        # COUNT(...) always yields exactly one row, so scalar_one() is safe
        # and returns an int rather than Optional[int].
        count_result = await db.execute(count_query)
        total = count_result.scalar_one()

        result = await db.execute(
            query.order_by(User.id.desc()).offset(skip).limit(limit)
        )
        users = result.scalars().all()
        return list(users), total

    async def create(self, db: AsyncSession, username: str, email: str, password: str,
                     full_name: Optional[str] = None) -> User:
        """
        Create and persist a new user.

        The plain-text password is hashed with ``get_password_hash`` before
        being stored; only the hash is written to ``hashed_password``.

        Args:
            db: Database session (committed by this method).
            username: Username.
            email: Email address.
            password: Plain-text password.
            full_name: Optional display name.

        Returns:
            User: The newly created user, refreshed from the database.
        """
        db_obj = User(
            username=username,
            email=email,
            hashed_password=get_password_hash(password),
            full_name=full_name
        )
        db.add(db_obj)
        await db.commit()
        # Reload server-generated values (e.g. the primary key) after commit.
        await db.refresh(db_obj)
        return db_obj

    async def update_password(
        self,
        db: AsyncSession,
        user: User,
        new_password: str
    ) -> bool:
        """
        Replace a user's password.

        Args:
            db: Database session (committed by this method).
            user: User whose password is changed.
            new_password: New plain-text password; it is hashed before storing.

        Returns:
            bool: Always True once the commit succeeds; errors propagate as
            exceptions.
        """
        user.hashed_password = get_password_hash(new_password)
        await db.commit()
        return True

    async def update_last_login(self, db: AsyncSession, user: User) -> bool:
        """
        Set a user's last-login timestamp to the current UTC time.

        Args:
            db: Database session (committed by this method).
            user: User who just logged in.

        Returns:
            bool: Always True once the commit succeeds; errors propagate as
            exceptions.
        """
        # Same naive-UTC value the deprecated datetime.utcnow() produced, so a
        # TIMESTAMP WITHOUT TIME ZONE column keeps working.
        # NOTE(review): if last_login_at is timezone-aware, drop .replace().
        user.last_login_at = datetime.now(timezone.utc).replace(tzinfo=None)
        await db.commit()
        return True


# Shared module-level CRUD instance.
user_crud = UserCRUD()