# Files
# zcglxt/backend_new/app/crud/notification.py
#
# 404 lines
# 11 KiB
# Python
#
"""
消息通知CRUD操作
"""
from typing import Optional, List, Dict, Any
from datetime import datetime
from sqlalchemy import select, and_, or_, func, desc, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.notification import Notification, NotificationTemplate
class NotificationCRUD:
    """CRUD operations for the ``Notification`` model.

    Every method works on an ``AsyncSession`` and only ``flush``es it; none of
    them commits, so transaction control stays with the caller.
    """

    @staticmethod
    def _unread_of(recipient_id: int):
        """Build the WHERE clause matching one recipient's unread notifications."""
        return and_(
            Notification.recipient_id == recipient_id,
            # ``== False`` (not ``is False``) so SQLAlchemy emits a SQL comparison.
            Notification.is_read == False,  # noqa: E712
        )

    async def get(self, db: AsyncSession, notification_id: int) -> Optional[Notification]:
        """Fetch a single notification by ID.

        Args:
            db: Database session.
            notification_id: Notification ID.

        Returns:
            The matching ``Notification``, or ``None`` if no row has that ID.
        """
        result = await db.execute(
            select(Notification).where(Notification.id == notification_id)
        )
        return result.scalar_one_or_none()

    async def get_multi(
        self,
        db: AsyncSession,
        *,
        skip: int = 0,
        limit: int = 100,
        recipient_id: Optional[int] = None,
        notification_type: Optional[str] = None,
        priority: Optional[str] = None,
        is_read: Optional[bool] = None,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None,
        keyword: Optional[str] = None,
    ) -> tuple[List[Notification], int]:
        """List notifications with optional filters and offset/limit paging.

        Results are ordered unread-first, then newest-first.

        Args:
            db: Database session.
            skip: Number of matching rows to skip (offset).
            limit: Maximum number of rows to return.
            recipient_id: Restrict to this recipient.
            notification_type: Restrict to this type (ignored when empty).
            priority: Restrict to this priority (ignored when empty).
            is_read: Restrict to read (True) or unread (False) notifications.
            start_time: Inclusive lower bound on ``created_at``.
            end_time: Inclusive upper bound on ``created_at``.
            keyword: Case-insensitive substring matched against title or
                content (ignored when empty).

        Returns:
            ``(items, total)``, where ``total`` counts every row matching the
            filters regardless of ``skip``/``limit``.
        """
        conditions = []
        # ID and time filters are compared against None explicitly so that
        # falsy-but-valid values (e.g. recipient_id=0) are not silently dropped.
        if recipient_id is not None:
            conditions.append(Notification.recipient_id == recipient_id)
        if notification_type:
            conditions.append(Notification.notification_type == notification_type)
        if priority:
            conditions.append(Notification.priority == priority)
        if is_read is not None:
            conditions.append(Notification.is_read == is_read)
        if start_time is not None:
            conditions.append(Notification.created_at >= start_time)
        if end_time is not None:
            conditions.append(Notification.created_at <= end_time)
        if keyword:
            pattern = f"%{keyword}%"
            conditions.append(
                or_(
                    Notification.title.ilike(pattern),
                    Notification.content.ilike(pattern),
                )
            )

        # Total number of matches, independent of paging.
        count_query = select(func.count(Notification.id)).where(*conditions)
        total = (await db.execute(count_query)).scalar() or 0

        query = (
            select(Notification)
            .where(*conditions)
            .order_by(
                Notification.is_read.asc(),  # unread (False) sorts first
                desc(Notification.created_at),  # newest first
                # Unique tie-breaker: rows sharing a created_at would otherwise
                # have no defined order and could repeat or vanish across pages.
                desc(Notification.id),
            )
            .offset(skip)
            .limit(limit)
        )
        result = await db.execute(query)
        return list(result.scalars().all()), total

    async def create(
        self,
        db: AsyncSession,
        *,
        obj_in: Dict[str, Any],
    ) -> Notification:
        """Insert one notification.

        Args:
            db: Database session.
            obj_in: Column values passed as keyword arguments to ``Notification``.

        Returns:
            The new ``Notification``, refreshed after flush.
        """
        db_obj = Notification(**obj_in)
        db.add(db_obj)
        await db.flush()
        await db.refresh(db_obj)
        return db_obj

    async def batch_create(
        self,
        db: AsyncSession,
        *,
        recipient_ids: List[int],
        notification_data: Dict[str, Any],
    ) -> List[Notification]:
        """Insert one copy of the same notification per recipient.

        Args:
            db: Database session.
            recipient_ids: Recipients to notify. One row is created per entry.
            notification_data: Shared column values. A ``recipient_id`` key in
                it is overridden per recipient.

        Returns:
            The new ``Notification`` objects, in ``recipient_ids`` order. Unlike
            ``create`` they are flushed but not refreshed.
        """
        notifications = [
            Notification(**{**notification_data, "recipient_id": recipient_id})
            for recipient_id in recipient_ids
        ]
        db.add_all(notifications)
        await db.flush()
        return notifications

    async def update(
        self,
        db: AsyncSession,
        *,
        db_obj: Notification,
        obj_in: Dict[str, Any],
    ) -> Notification:
        """Apply ``obj_in`` to an existing notification.

        Args:
            db: Database session.
            db_obj: The notification to modify.
            obj_in: Attribute values to set. Keys that are not attributes of
                ``db_obj`` are silently ignored.

        Returns:
            ``db_obj``, refreshed after flush.
        """
        for field, value in obj_in.items():
            if hasattr(db_obj, field):
                setattr(db_obj, field, value)
        await db.flush()
        await db.refresh(db_obj)
        return db_obj

    async def mark_as_read(
        self,
        db: AsyncSession,
        *,
        notification_id: int,
        read_at: Optional[datetime] = None,
    ) -> Optional[Notification]:
        """Mark one notification as read.

        An already-read notification is returned unchanged, so its original
        ``read_at`` is preserved.

        Args:
            db: Database session.
            notification_id: Notification ID.
            read_at: Read timestamp. Defaults to ``datetime.utcnow()`` (naive UTC).

        Returns:
            The notification, or ``None`` if it does not exist.
        """
        db_obj = await self.get(db, notification_id)
        if db_obj is None:
            return None
        if not db_obj.is_read:
            db_obj.is_read = True
            db_obj.read_at = read_at or datetime.utcnow()
            await db.flush()
        return db_obj

    async def mark_all_as_read(
        self,
        db: AsyncSession,
        *,
        recipient_id: int,
        read_at: Optional[datetime] = None,
    ) -> int:
        """Mark every unread notification of a recipient as read with one UPDATE.

        Args:
            db: Database session.
            recipient_id: Recipient whose notifications are updated.
            read_at: Read timestamp. Defaults to ``datetime.utcnow()`` (naive UTC).

        Returns:
            Number of rows updated, as reported by the driver's ``rowcount``.
        """
        stmt = (
            update(Notification)
            .where(self._unread_of(recipient_id))
            .values(is_read=True, read_at=read_at or datetime.utcnow())
        )
        result = await db.execute(stmt)
        await db.flush()
        return result.rowcount

    async def delete(self, db: AsyncSession, *, notification_id: int) -> Optional[Notification]:
        """Delete one notification by ID.

        Args:
            db: Database session.
            notification_id: Notification ID.

        Returns:
            The deleted ``Notification``, or ``None`` if it did not exist.
        """
        obj = await self.get(db, notification_id)
        if obj:
            await db.delete(obj)
            await db.flush()
        return obj

    async def batch_delete(
        self,
        db: AsyncSession,
        *,
        notification_ids: List[int],
    ) -> int:
        """Delete all notifications whose ID is in ``notification_ids`` with one DELETE.

        Args:
            db: Database session.
            notification_ids: IDs to delete.

        Returns:
            Number of rows deleted, as reported by the driver's ``rowcount``.
        """
        # Aliased so it is not confused with this class's ``delete`` method.
        from sqlalchemy import delete as sa_delete

        stmt = sa_delete(Notification).where(Notification.id.in_(notification_ids))
        result = await db.execute(stmt)
        await db.flush()
        return result.rowcount

    async def get_unread_count(
        self,
        db: AsyncSession,
        recipient_id: int,
    ) -> int:
        """Count a recipient's unread notifications.

        Args:
            db: Database session.
            recipient_id: Recipient ID.

        Returns:
            Number of unread notifications.
        """
        result = await db.execute(
            select(func.count(Notification.id)).where(self._unread_of(recipient_id))
        )
        return result.scalar() or 0

    async def get_statistics(
        self,
        db: AsyncSession,
        recipient_id: int,
    ) -> Dict[str, Any]:
        """Summarise a recipient's notifications.

        Args:
            db: Database session.
            recipient_id: Recipient ID.

        Returns:
            A dict with:
              - ``total_count``: all notifications
              - ``unread_count`` / ``read_count``
              - ``high_priority_count``: unread with priority "high" or "urgent"
              - ``urgent_count``: unread with priority "urgent"
              - ``type_distribution``: ``[{"type": ..., "count": ...}, ...]`` over all
                notifications
        """
        # Local import, matching the style already used by batch_delete.
        from sqlalchemy import case

        is_unread = Notification.is_read == False  # noqa: E712 - SQL comparison
        # COUNT(CASE WHEN cond THEN id END) counts only rows where cond holds
        # (CASE yields NULL otherwise). All counters come from one statement,
        # so they describe the same snapshot and stay integer-typed on every backend.
        counts = (
            await db.execute(
                select(
                    func.count(Notification.id),
                    func.count(case((is_unread, Notification.id))),
                    func.count(
                        case(
                            (
                                and_(
                                    is_unread,
                                    Notification.priority.in_(["high", "urgent"]),
                                ),
                                Notification.id,
                            )
                        )
                    ),
                    func.count(
                        case(
                            (
                                and_(is_unread, Notification.priority == "urgent"),
                                Notification.id,
                            )
                        )
                    ),
                ).where(Notification.recipient_id == recipient_id)
            )
        ).one()
        total_count, unread_count, high_priority_count, urgent_count = (
            int(value or 0) for value in counts
        )

        type_result = await db.execute(
            select(
                Notification.notification_type,
                func.count(Notification.id).label('count')
            )
            .where(Notification.recipient_id == recipient_id)
            .group_by(Notification.notification_type)
        )
        type_distribution = [
            {"type": row[0], "count": row[1]}
            for row in type_result
        ]

        return {
            "total_count": total_count,
            "unread_count": unread_count,
            "read_count": total_count - unread_count,
            "high_priority_count": high_priority_count,
            "urgent_count": urgent_count,
            "type_distribution": type_distribution,
        }
# Shared module-level instance; NotificationCRUD holds no per-instance state.
notification_crud: NotificationCRUD = NotificationCRUD()