""" 消息通知CRUD操作 """ from typing import Optional, List, Dict, Any from datetime import datetime from sqlalchemy import select, and_, or_, func, desc, update from sqlalchemy.ext.asyncio import AsyncSession from app.models.notification import Notification, NotificationTemplate class NotificationCRUD: """消息通知CRUD类""" async def get(self, db: AsyncSession, notification_id: int) -> Optional[Notification]: """ 根据ID获取消息通知 Args: db: 数据库会话 notification_id: 通知ID Returns: Notification对象或None """ result = await db.execute( select(Notification).where(Notification.id == notification_id) ) return result.scalar_one_or_none() async def get_multi( self, db: AsyncSession, *, skip: int = 0, limit: int = 100, recipient_id: Optional[int] = None, notification_type: Optional[str] = None, priority: Optional[str] = None, is_read: Optional[bool] = None, start_time: Optional[datetime] = None, end_time: Optional[datetime] = None, keyword: Optional[str] = None ) -> tuple[List[Notification], int]: """ 获取消息通知列表 Args: db: 数据库会话 skip: 跳过条数 limit: 返回条数 recipient_id: 接收人ID notification_type: 通知类型 priority: 优先级 is_read: 是否已读 start_time: 开始时间 end_time: 结束时间 keyword: 关键词 Returns: (通知列表, 总数) """ # 构建查询条件 conditions = [] if recipient_id: conditions.append(Notification.recipient_id == recipient_id) if notification_type: conditions.append(Notification.notification_type == notification_type) if priority: conditions.append(Notification.priority == priority) if is_read is not None: conditions.append(Notification.is_read == is_read) if start_time: conditions.append(Notification.created_at >= start_time) if end_time: conditions.append(Notification.created_at <= end_time) if keyword: conditions.append( or_( Notification.title.ilike(f"%{keyword}%"), Notification.content.ilike(f"%{keyword}%") ) ) # 查询总数 count_query = select(func.count(Notification.id)) if conditions: count_query = count_query.where(and_(*conditions)) total_result = await db.execute(count_query) total = total_result.scalar() or 0 # 查询数据 query = select(Notification) if conditions: query = query.where(and_(*conditions)) query = query.order_by( Notification.is_read.asc(), # 未读优先 desc(Notification.created_at) # 按时间倒序 ) query = query.offset(skip).limit(limit) result = await db.execute(query) items = result.scalars().all() return list(items), total async def create( self, db: AsyncSession, *, obj_in: Dict[str, Any] ) -> Notification: """ 创建消息通知 Args: db: 数据库会话 obj_in: 创建数据 Returns: Notification对象 """ db_obj = Notification(**obj_in) db.add(db_obj) await db.flush() await db.refresh(db_obj) return db_obj async def batch_create( self, db: AsyncSession, *, recipient_ids: List[int], notification_data: Dict[str, Any] ) -> List[Notification]: """ 批量创建消息通知 Args: db: 数据库会话 recipient_ids: 接收人ID列表 notification_data: 通知数据 Returns: Notification对象列表 """ notifications = [] for recipient_id in recipient_ids: obj_data = notification_data.copy() obj_data["recipient_id"] = recipient_id db_obj = Notification(**obj_data) db.add(db_obj) notifications.append(db_obj) await db.flush() return notifications async def update( self, db: AsyncSession, *, db_obj: Notification, obj_in: Dict[str, Any] ) -> Notification: """ 更新消息通知 Args: db: 数据库会话 db_obj: 数据库对象 obj_in: 更新数据 Returns: Notification对象 """ for field, value in obj_in.items(): if hasattr(db_obj, field): setattr(db_obj, field, value) await db.flush() await db.refresh(db_obj) return db_obj async def mark_as_read( self, db: AsyncSession, *, notification_id: int, read_at: Optional[datetime] = None ) -> Optional[Notification]: """ 标记为已读 Args: db: 数据库会话 notification_id: 通知ID read_at: 已读时间 Returns: Notification对象或None """ db_obj = await self.get(db, notification_id) if not db_obj: return None if not db_obj.is_read: db_obj.is_read = True db_obj.read_at = read_at or datetime.utcnow() await db.flush() return db_obj async def mark_all_as_read( self, db: AsyncSession, *, recipient_id: int, read_at: Optional[datetime] = None ) -> int: """ 标记所有未读为已读 Args: db: 数据库会话 recipient_id: 接收人ID read_at: 已读时间 Returns: 更新数量 """ stmt = ( update(Notification) .where( and_( Notification.recipient_id == recipient_id, Notification.is_read == False ) ) .values( is_read=True, read_at=read_at or datetime.utcnow() ) ) result = await db.execute(stmt) await db.flush() return result.rowcount async def delete(self, db: AsyncSession, *, notification_id: int) -> Optional[Notification]: """ 删除消息通知 Args: db: 数据库会话 notification_id: 通知ID Returns: 删除的Notification对象或None """ obj = await self.get(db, notification_id) if obj: await db.delete(obj) await db.flush() return obj async def batch_delete( self, db: AsyncSession, *, notification_ids: List[int] ) -> int: """ 批量删除通知 Args: db: 数据库会话 notification_ids: 通知ID列表 Returns: 删除数量 """ from sqlalchemy import delete stmt = delete(Notification).where(Notification.id.in_(notification_ids)) result = await db.execute(stmt) await db.flush() return result.rowcount async def get_unread_count( self, db: AsyncSession, recipient_id: int ) -> int: """ 获取未读通知数量 Args: db: 数据库会话 recipient_id: 接收人ID Returns: 未读数量 """ result = await db.execute( select(func.count(Notification.id)).where( and_( Notification.recipient_id == recipient_id, Notification.is_read == False ) ) ) return result.scalar() or 0 async def get_statistics( self, db: AsyncSession, recipient_id: int ) -> Dict[str, Any]: """ 获取通知统计信息 Args: db: 数据库会话 recipient_id: 接收人ID Returns: 统计信息 """ # 总数 total_result = await db.execute( select(func.count(Notification.id)).where(Notification.recipient_id == recipient_id) ) total_count = total_result.scalar() or 0 # 未读数 unread_result = await db.execute( select(func.count(Notification.id)).where( and_( Notification.recipient_id == recipient_id, Notification.is_read == False ) ) ) unread_count = unread_result.scalar() or 0 # 已读数 read_count = total_count - unread_count # 高优先级数 high_priority_result = await db.execute( select(func.count(Notification.id)).where( and_( Notification.recipient_id == recipient_id, Notification.priority.in_(["high", "urgent"]), Notification.is_read == False ) ) ) high_priority_count = high_priority_result.scalar() or 0 # 紧急通知数 urgent_result = await db.execute( select(func.count(Notification.id)).where( and_( Notification.recipient_id == recipient_id, Notification.priority == "urgent", Notification.is_read == False ) ) ) urgent_count = urgent_result.scalar() or 0 # 类型分布 type_result = await db.execute( select( Notification.notification_type, func.count(Notification.id).label('count') ) .where(Notification.recipient_id == recipient_id) .group_by(Notification.notification_type) ) type_distribution = [ {"type": row[0], "count": row[1]} for row in type_result ] return { "total_count": total_count, "unread_count": unread_count, "read_count": read_count, "high_priority_count": high_priority_count, "urgent_count": urgent_count, "type_distribution": type_distribution, } # 创建全局实例 notification_crud = NotificationCRUD()