447 lines
12 KiB
Python
447 lines
12 KiB
Python
"""
消息通知CRUD操作 (Notification CRUD operations).
"""
|
|
from typing import Optional, List, Dict, Any
|
|
from datetime import datetime
|
|
from sqlalchemy import select, and_, or_, func, desc, update
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from app.models.notification import Notification, NotificationTemplate
|
|
|
|
|
|
class NotificationCRUD:
    """CRUD operations for ``Notification`` rows (消息通知CRUD类).

    Every method receives the ``AsyncSession`` explicitly. Write methods call
    ``flush`` but never ``commit``; committing the transaction is left to the
    caller.
    """

    # Escape character used for LIKE/ILIKE keyword searches. "/" is used
    # instead of a backslash because a backslash is itself special inside
    # string literals on some backends; SQLAlchemy's ``autoescape`` option
    # also uses "/".
    _LIKE_ESCAPE = "/"

    @classmethod
    def _escape_like(cls, term: str) -> str:
        """Return *term* with LIKE wildcards escaped so it matches literally."""
        esc = cls._LIKE_ESCAPE
        return (
            term.replace(esc, esc + esc)  # escape the escape char itself first
            .replace("%", esc + "%")
            .replace("_", esc + "_")
        )

    async def _count(self, db: AsyncSession, *conditions: Any) -> int:
        """Return ``COUNT(Notification.id)`` over rows matching all *conditions*.

        With no conditions the whole table is counted.
        """
        query = select(func.count(Notification.id))
        if conditions:
            query = query.where(and_(*conditions))
        result = await db.execute(query)
        return result.scalar() or 0

    async def get(self, db: AsyncSession, notification_id: int) -> Optional[Notification]:
        """Fetch a single notification by its ID.

        Args:
            db: Database session.
            notification_id: Notification ID.

        Returns:
            The matching ``Notification``, or ``None`` if no row has that ID.
        """
        result = await db.execute(
            select(Notification).where(Notification.id == notification_id)
        )
        return result.scalar_one_or_none()

    async def get_multi(
        self,
        db: AsyncSession,
        *,
        skip: int = 0,
        limit: int = 100,
        recipient_id: Optional[int] = None,
        notification_type: Optional[str] = None,
        priority: Optional[str] = None,
        is_read: Optional[bool] = None,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None,
        keyword: Optional[str] = None,
    ) -> tuple[List[Notification], int]:
        """List notifications with optional filters and pagination.

        All supplied filters are AND-ed together. Results are ordered unread
        first, then by ``created_at`` descending.

        Args:
            db: Database session.
            skip: Number of rows to skip (offset).
            limit: Maximum number of rows to return.
            recipient_id: Only notifications for this recipient.
            notification_type: Only this notification type (ignored if empty).
            priority: Only this priority (ignored if empty).
            is_read: Only read (``True``) or unread (``False``) notifications.
            start_time: Inclusive lower bound on ``created_at``.
            end_time: Inclusive upper bound on ``created_at``.
            keyword: Case-insensitive substring matched against ``title`` or
                ``content`` (ignored if empty); ``%`` and ``_`` in the keyword
                are matched literally, not as wildcards.

        Returns:
            ``(items, total)``. ``total`` counts all rows matching the
            filters and ignores ``skip``/``limit``.
        """
        conditions = []

        if recipient_id is not None:
            conditions.append(Notification.recipient_id == recipient_id)
        if notification_type:
            conditions.append(Notification.notification_type == notification_type)
        if priority:
            conditions.append(Notification.priority == priority)
        if is_read is not None:
            conditions.append(Notification.is_read == is_read)
        if start_time is not None:
            conditions.append(Notification.created_at >= start_time)
        if end_time is not None:
            conditions.append(Notification.created_at <= end_time)
        if keyword:
            # Escape user-supplied wildcards so "50%" or "a_b" search literally.
            pattern = f"%{self._escape_like(keyword)}%"
            conditions.append(
                or_(
                    Notification.title.ilike(pattern, escape=self._LIKE_ESCAPE),
                    Notification.content.ilike(pattern, escape=self._LIKE_ESCAPE),
                )
            )

        total = await self._count(db, *conditions)

        query = select(Notification)
        if conditions:
            query = query.where(and_(*conditions))
        query = (
            query.order_by(
                Notification.is_read.asc(),  # unread first
                desc(Notification.created_at),  # then newest first
            )
            .offset(skip)
            .limit(limit)
        )

        result = await db.execute(query)
        return list(result.scalars().all()), total

    async def create(
        self,
        db: AsyncSession,
        *,
        obj_in: Dict[str, Any]
    ) -> Notification:
        """Create one notification.

        Args:
            db: Database session.
            obj_in: Column values passed as keyword arguments to ``Notification``.

        Returns:
            The new ``Notification``, flushed and refreshed from the database.
        """
        db_obj = Notification(**obj_in)
        db.add(db_obj)
        await db.flush()
        await db.refresh(db_obj)
        return db_obj

    async def batch_create(
        self,
        db: AsyncSession,
        *,
        recipient_ids: List[int],
        notification_data: Dict[str, Any]
    ) -> List[Notification]:
        """Create one copy of the same notification for each recipient.

        Args:
            db: Database session.
            recipient_ids: Recipient IDs; one notification is created per ID.
            notification_data: Shared column values. Any ``recipient_id`` key
                in it is overridden per recipient; the dict is not mutated.

        Returns:
            The new ``Notification`` objects, in ``recipient_ids`` order.
        """
        notifications = [
            Notification(**{**notification_data, "recipient_id": recipient_id})
            for recipient_id in recipient_ids
        ]
        db.add_all(notifications)
        # NOTE(review): unlike ``create``, these objects are flushed but not
        # refreshed. Server-generated columns may therefore be unloaded;
        # confirm against the model before reading them on an AsyncSession.
        await db.flush()
        return notifications

    async def update(
        self,
        db: AsyncSession,
        *,
        db_obj: Notification,
        obj_in: Dict[str, Any]
    ) -> Notification:
        """Apply field updates to an existing notification.

        Args:
            db: Database session.
            db_obj: The notification to modify.
            obj_in: Mapping of attribute name to new value. Keys that are not
                attributes of ``db_obj`` are silently ignored.

        Returns:
            The same ``db_obj``, flushed and refreshed.
        """
        for field, value in obj_in.items():
            if hasattr(db_obj, field):
                setattr(db_obj, field, value)

        await db.flush()
        await db.refresh(db_obj)
        return db_obj

    async def mark_as_read(
        self,
        db: AsyncSession,
        *,
        notification_id: int,
        read_at: Optional[datetime] = None
    ) -> Optional[Notification]:
        """Mark a single notification as read.

        A notification that is already read is left unchanged, so its
        original ``read_at`` is preserved.

        Args:
            db: Database session.
            notification_id: Notification ID.
            read_at: Read timestamp; defaults to ``datetime.utcnow()`` (naive UTC).

        Returns:
            The notification, or ``None`` if no row has that ID.
        """
        db_obj = await self.get(db, notification_id)
        if not db_obj:
            return None

        if not db_obj.is_read:
            db_obj.is_read = True
            db_obj.read_at = read_at or datetime.utcnow()
            await db.flush()

        return db_obj

    async def mark_all_as_read(
        self,
        db: AsyncSession,
        *,
        recipient_id: int,
        read_at: Optional[datetime] = None
    ) -> int:
        """Mark every unread notification of a recipient as read.

        Args:
            db: Database session.
            recipient_id: Recipient ID.
            read_at: Read timestamp; defaults to ``datetime.utcnow()`` (naive UTC).

        Returns:
            Number of rows updated (previously unread notifications only).
        """
        stmt = (
            update(Notification)
            .where(
                and_(
                    Notification.recipient_id == recipient_id,
                    # SQL expression: must be ``==``; ``is False`` would not
                    # build a WHERE clause.
                    Notification.is_read == False,  # noqa: E712
                )
            )
            .values(
                is_read=True,
                read_at=read_at or datetime.utcnow()
            )
        )
        result = await db.execute(stmt)
        await db.flush()
        return result.rowcount

    async def delete(self, db: AsyncSession, *, notification_id: int) -> Optional[Notification]:
        """Delete one notification by ID.

        Args:
            db: Database session.
            notification_id: Notification ID.

        Returns:
            The deleted ``Notification``, or ``None`` if no row had that ID.
        """
        obj = await self.get(db, notification_id)
        if obj:
            await db.delete(obj)
            await db.flush()
        return obj

    async def batch_delete(
        self,
        db: AsyncSession,
        *,
        notification_ids: List[int]
    ) -> int:
        """Delete all notifications whose ID is in *notification_ids*.

        Args:
            db: Database session.
            notification_ids: IDs to delete.

        Returns:
            Number of rows deleted.
        """
        # Local import: the class defines its own ``delete`` method, and the
        # module only imports ``update`` from sqlalchemy at the top.
        from sqlalchemy import delete

        stmt = delete(Notification).where(Notification.id.in_(notification_ids))
        result = await db.execute(stmt)
        await db.flush()
        return result.rowcount

    async def batch_mark_as_read(
        self,
        db: AsyncSession,
        *,
        notification_ids: List[int],
        read_at: Optional[datetime] = None,
        recipient_id: Optional[int] = None
    ) -> int:
        """Mark the given notifications as read.

        Only currently-unread rows are updated, so already-read notifications
        keep their original ``read_at``. This matches ``mark_as_read`` and
        ``mark_all_as_read``.

        Args:
            db: Database session.
            notification_ids: IDs to mark as read.
            read_at: Read timestamp; defaults to ``datetime.utcnow()`` (naive UTC).
            recipient_id: If given, restrict the update to this recipient's
                notifications (ownership guard).

        Returns:
            Number of notifications that were newly marked as read.
        """
        conditions = [
            Notification.id.in_(notification_ids),
            Notification.is_read == False,  # noqa: E712 - SQL expression
        ]
        # ``is not None`` so a falsy-but-valid ID still applies the guard.
        if recipient_id is not None:
            conditions.append(Notification.recipient_id == recipient_id)

        stmt = (
            update(Notification)
            .where(and_(*conditions))
            .values(is_read=True, read_at=read_at or datetime.utcnow())
        )
        result = await db.execute(stmt)
        await db.flush()
        return result.rowcount

    async def batch_mark_as_unread(
        self,
        db: AsyncSession,
        *,
        notification_ids: List[int],
        recipient_id: Optional[int] = None
    ) -> int:
        """Mark the given notifications as unread and clear ``read_at``.

        Args:
            db: Database session.
            notification_ids: IDs to mark as unread.
            recipient_id: If given, restrict the update to this recipient's
                notifications (ownership guard).

        Returns:
            Number of rows matched by the update.
        """
        stmt = update(Notification).where(Notification.id.in_(notification_ids))
        # ``is not None`` so a falsy-but-valid ID still applies the guard.
        if recipient_id is not None:
            stmt = stmt.where(Notification.recipient_id == recipient_id)
        stmt = stmt.values(is_read=False, read_at=None)

        result = await db.execute(stmt)
        await db.flush()
        return result.rowcount

    async def get_unread_count(
        self,
        db: AsyncSession,
        recipient_id: int
    ) -> int:
        """Return the number of unread notifications for a recipient.

        Args:
            db: Database session.
            recipient_id: Recipient ID.

        Returns:
            Unread count (0 if none).
        """
        return await self._count(
            db,
            Notification.recipient_id == recipient_id,
            Notification.is_read == False,  # noqa: E712 - SQL expression
        )

    async def get_statistics(
        self,
        db: AsyncSession,
        recipient_id: int
    ) -> Dict[str, Any]:
        """Return notification statistics for a recipient.

        Args:
            db: Database session.
            recipient_id: Recipient ID.

        Returns:
            A dict with the following keys:

            - ``total_count``: all notifications.
            - ``unread_count``: unread notifications.
            - ``read_count``: ``total_count - unread_count``.
            - ``high_priority_count``: unread notifications with priority
              ``"high"`` or ``"urgent"``. This includes the urgent ones.
            - ``urgent_count``: unread notifications with priority ``"urgent"``.
            - ``type_distribution``: a list of
              ``{"type": notification_type, "count": n}`` covering all
              notifications, read or unread.
        """
        of_recipient = Notification.recipient_id == recipient_id
        unread = Notification.is_read == False  # noqa: E712 - SQL expression

        total_count = await self._count(db, of_recipient)
        unread_count = await self.get_unread_count(db, recipient_id)
        read_count = total_count - unread_count

        high_priority_count = await self._count(
            db,
            of_recipient,
            Notification.priority.in_(["high", "urgent"]),
            unread,
        )
        urgent_count = await self._count(
            db,
            of_recipient,
            Notification.priority == "urgent",
            unread,
        )

        type_result = await db.execute(
            select(
                Notification.notification_type,
                func.count(Notification.id).label('count')
            )
            .where(of_recipient)
            .group_by(Notification.notification_type)
        )
        type_distribution = [
            {"type": notification_type, "count": count}
            for notification_type, count in type_result
        ]

        return {
            "total_count": total_count,
            "unread_count": unread_count,
            "read_count": read_count,
            "high_priority_count": high_priority_count,
            "urgent_count": urgent_count,
            "type_distribution": type_distribution,
        }
|
|
|
|
|
|
# Shared module-level instance. NotificationCRUD keeps no per-instance state,
# so a single object can be imported and reused by all callers.
notification_crud: NotificationCRUD = NotificationCRUD()
|