Files
zcglxt/backend_new/app/crud/organization.py

352 lines
9.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
机构网点CRUD操作
"""
from typing import List, Optional, Tuple
from sqlalchemy import select, and_, or_, func
from sqlalchemy.orm import Session
from app.models.organization import Organization
from app.schemas.organization import OrganizationCreate, OrganizationUpdate
class OrganizationCRUD:
    """CRUD operations for organizations (institutions / branch outlets).

    The hierarchy is stored both as ``parent_id`` and as a materialized path:
    ``tree_path`` holds the ancestor ids as ``"/<root>/.../<parent>/"``
    (``None`` for root nodes) and ``tree_level`` is the depth (0 for roots).
    Read methods exclude soft-deleted rows (``deleted_at`` set).
    """

    @staticmethod
    def _subtree_prefix(org: Organization) -> str:
        """Return the ``tree_path`` carried by ``org``'s direct children.

        Every descendant's ``tree_path`` starts with this prefix.
        """
        return f"{org.tree_path or '/'}{org.id}/"

    def get(self, db: Session, id: int) -> Optional[Organization]:
        """Return the non-deleted organization with the given id.

        Args:
            db: Database session.
            id: Organization id.

        Returns:
            The Organization, or None if missing or soft-deleted.
        """
        return db.query(Organization).filter(
            and_(
                Organization.id == id,
                Organization.deleted_at.is_(None),
            )
        ).first()

    def get_by_code(self, db: Session, code: str) -> Optional[Organization]:
        """Return the non-deleted organization with the given code.

        Args:
            db: Database session.
            code: Organization code (``org_code``).

        Returns:
            The Organization, or None if missing or soft-deleted.
        """
        return db.query(Organization).filter(
            and_(
                Organization.org_code == code,
                Organization.deleted_at.is_(None),
            )
        ).first()

    def get_multi(
        self,
        db: Session,
        skip: int = 0,
        limit: int = 20,
        org_type: Optional[str] = None,
        status: Optional[str] = None,
        keyword: Optional[str] = None,
    ) -> Tuple[List[Organization], int]:
        """Return one page of organizations plus the total match count.

        Args:
            db: Database session.
            skip: Number of rows to skip.
            limit: Maximum number of rows to return.
            org_type: Optional exact filter on ``org_type``.
            status: Optional exact filter on ``status``.
            keyword: Optional case-insensitive substring matched against
                ``org_code`` and ``org_name``; matched literally.

        Returns:
            ``(items, total)``. Items are ordered by tree_level, sort_order, id.
        """
        query = db.query(Organization).filter(Organization.deleted_at.is_(None))
        if org_type:
            query = query.filter(Organization.org_type == org_type)
        if status:
            query = query.filter(Organization.status == status)
        if keyword:
            # Escape LIKE metacharacters so user input such as "%" or "_"
            # is matched literally instead of acting as a wildcard.
            escaped = (
                keyword.replace("\\", "\\\\")
                .replace("%", "\\%")
                .replace("_", "\\_")
            )
            pattern = f"%{escaped}%"
            query = query.filter(
                or_(
                    Organization.org_code.ilike(pattern, escape="\\"),
                    Organization.org_name.ilike(pattern, escape="\\"),
                )
            )
        # The total does not depend on ordering, so count before ORDER BY.
        total = query.count()
        items = (
            query.order_by(
                Organization.tree_level.asc(),
                Organization.sort_order.asc(),
                Organization.id.asc(),
            )
            .offset(skip)
            .limit(limit)
            .all()
        )
        return items, total

    def get_tree(self, db: Session, status: Optional[str] = None) -> List[Organization]:
        """Return root organizations with a ``children`` list populated on every node.

        Args:
            db: Database session.
            status: Optional exact filter on ``status``.

        Returns:
            Root nodes (``parent_id is None``). Each node's ``children`` holds
            its direct children, ordered by tree_level, sort_order, id. A node
            whose parent is not among the loaded rows (e.g. filtered out by
            ``status``) is omitted from the result.
        """
        query = db.query(Organization).filter(Organization.deleted_at.is_(None))
        if status:
            query = query.filter(Organization.status == status)
        all_orgs = query.order_by(
            Organization.tree_level.asc(),
            Organization.sort_order.asc(),
            Organization.id.asc(),
        ).all()

        # Reset every children list in a separate pass. When lists are reset
        # during linking, a child visited before its parent (possible when
        # stored tree_level values are stale) gets wiped once the parent is
        # reached.
        # NOTE(review): if Organization maps `children` as an ORM relationship,
        # this assignment marks it dirty; confirm the session is not committed
        # afterwards with these transient lists.
        for org in all_orgs:
            org.children = []

        org_map = {org.id: org for org in all_orgs}
        tree: List[Organization] = []
        for org in all_orgs:
            if org.parent_id is None:
                tree.append(org)
                continue
            parent = org_map.get(org.parent_id)
            if parent is not None:
                parent.children.append(org)
        return tree

    def get_children(self, db: Session, parent_id: int) -> List[Organization]:
        """Return the direct, non-deleted children of an organization.

        Args:
            db: Database session.
            parent_id: Parent organization id.

        Returns:
            Children ordered by sort_order, then id.
        """
        return db.query(Organization).filter(
            and_(
                Organization.parent_id == parent_id,
                Organization.deleted_at.is_(None),
            )
        ).order_by(Organization.sort_order.asc(), Organization.id.asc()).all()

    def get_all_children(self, db: Session, parent_id: int) -> List[Organization]:
        """Return all non-deleted descendants of an organization, at any depth.

        Uses a prefix match on ``tree_path`` instead of recursive queries.

        Args:
            db: Database session.
            parent_id: Ancestor organization id.

        Returns:
            Descendants ordered by tree_level, sort_order, id, or an empty
            list if the ancestor does not exist.
        """
        parent = self.get(db, parent_id)
        if not parent:
            return []
        # The prefix contains only digits and "/", so no LIKE escaping is needed.
        prefix = self._subtree_prefix(parent)
        return db.query(Organization).filter(
            and_(
                Organization.tree_path.like(f"{prefix}%"),
                Organization.deleted_at.is_(None),
            )
        ).order_by(
            Organization.tree_level.asc(),
            Organization.sort_order.asc(),
            Organization.id.asc(),
        ).all()

    def get_parents(self, db: Session, child_id: int) -> List[Organization]:
        """Return the non-deleted ancestors of an organization, root first.

        Args:
            db: Database session.
            child_id: Organization whose ancestors are requested.

        Returns:
            Ancestors ordered from the root down to the direct parent, or an
            empty list for a missing organization or a root node.
        """
        child = self.get(db, child_id)
        if not child or not child.tree_path:
            return []
        path_ids = [int(part) for part in child.tree_path.split("/") if part]
        if not path_ids:
            return []
        parents = db.query(Organization).filter(
            and_(
                Organization.id.in_(path_ids),
                Organization.deleted_at.is_(None),
            )
        ).all()
        # Order by position in the path itself. This stays correct even if
        # stored tree_level values are out of date.
        position = {pid: index for index, pid in enumerate(path_ids)}
        parents.sort(key=lambda org: position[org.id])
        return parents

    def create(
        self,
        db: Session,
        obj_in: OrganizationCreate,
        creator_id: Optional[int] = None,
    ) -> Organization:
        """Create an organization and commit it.

        Args:
            db: Database session.
            obj_in: Creation payload.
            creator_id: Id of the creating user, stored in ``created_by``.

        Returns:
            The created, refreshed Organization.

        Raises:
            ValueError: If ``org_code`` is already in use or the parent does
                not exist.
        """
        if self.get_by_code(db, obj_in.org_code):
            raise ValueError(f"机构代码 '{obj_in.org_code}' 已存在")

        tree_path = None
        tree_level = 0
        if obj_in.parent_id:
            parent = self.get(db, obj_in.parent_id)
            if not parent:
                raise ValueError(f"父机构ID {obj_in.parent_id} 不存在")
            tree_path = self._subtree_prefix(parent)
            tree_level = parent.tree_level + 1

        db_obj = Organization(
            **obj_in.model_dump(),
            tree_path=tree_path,
            tree_level=tree_level,
            created_by=creator_id,
        )
        db.add(db_obj)
        db.commit()
        db.refresh(db_obj)
        return db_obj

    def update(
        self,
        db: Session,
        db_obj: Organization,
        obj_in: OrganizationUpdate,
        updater_id: Optional[int] = None,
    ) -> Organization:
        """Apply the explicitly-set fields of ``obj_in`` to ``db_obj`` and commit.

        When ``parent_id`` changes, the node is re-parented and the
        ``tree_path``/``tree_level`` of its entire subtree are rewritten.

        Args:
            db: Database session.
            db_obj: Persistent organization to update.
            obj_in: Update payload; only fields that were set are applied.
            updater_id: Id of the updating user, stored in ``updated_by``.

        Returns:
            The updated, refreshed Organization.

        Raises:
            ValueError: If the new ``org_code`` is taken, the new parent does
                not exist, or the move would create a cycle.
        """
        obj_data = obj_in.model_dump(exclude_unset=True)

        # Keep org_code unique on update, consistent with create().
        new_code = obj_data.get("org_code")
        if new_code and new_code != db_obj.org_code:
            existing = self.get_by_code(db, new_code)
            if existing is not None and existing.id != db_obj.id:
                raise ValueError(f"机构代码 '{new_code}' 已存在")

        if "parent_id" in obj_data and obj_data["parent_id"] != db_obj.parent_id:
            self._move(db, db_obj, obj_data["parent_id"])

        for field, value in obj_data.items():
            if field != "parent_id":  # applied by _move together with the path
                setattr(db_obj, field, value)
        db_obj.updated_by = updater_id
        db.add(db_obj)
        db.commit()
        db.refresh(db_obj)
        return db_obj

    def _move(
        self,
        db: Session,
        db_obj: Organization,
        new_parent_id: Optional[int],
    ) -> None:
        """Re-parent ``db_obj`` and rewrite tree_path/tree_level for its subtree.

        Changes are only staged on the session; the caller commits.

        Raises:
            ValueError: If the new parent does not exist, or is ``db_obj``
                itself or one of its descendants.
        """
        old_prefix = self._subtree_prefix(db_obj)
        old_level = db_obj.tree_level

        if new_parent_id:
            new_parent = self.get(db, new_parent_id)
            if not new_parent:
                raise ValueError(f"父机构ID {new_parent_id} 不存在")
            # A node cannot become a child of itself or of its own descendant.
            if new_parent.id == db_obj.id or (new_parent.tree_path or "").startswith(old_prefix):
                raise ValueError("不能将机构移动到其自身或其子机构下")
            new_parent_ref = new_parent.id
            new_path = self._subtree_prefix(new_parent)
            new_level = new_parent.tree_level + 1
        else:
            # Becomes a root node.
            new_parent_ref = None
            new_path = None
            new_level = 0

        # Load descendants before touching db_obj, so the query still matches
        # the old prefix. Soft-deleted rows are included so their stored
        # paths stay consistent too.
        descendants = db.query(Organization).filter(
            Organization.tree_path.like(f"{old_prefix}%")
        ).all()

        db_obj.parent_id = new_parent_ref
        db_obj.tree_path = new_path
        db_obj.tree_level = new_level

        new_prefix = self._subtree_prefix(db_obj)
        level_delta = new_level - old_level
        for node in descendants:
            node.tree_path = new_prefix + node.tree_path[len(old_prefix):]
            node.tree_level = node.tree_level + level_delta

    def delete(self, db: Session, id: int, deleter_id: Optional[int] = None) -> bool:
        """Soft-delete an organization by setting ``deleted_at``/``deleted_by``.

        Args:
            db: Database session.
            id: Organization id.
            deleter_id: Id of the deleting user, stored in ``deleted_by``.

        Returns:
            True if the organization was deleted, False if it was not found.

        Raises:
            ValueError: If the organization still has non-deleted children.
        """
        obj = self.get(db, id)
        if not obj:
            return False
        if self.get_children(db, id):
            raise ValueError("该机构下存在子机构,无法删除")
        # func.now() is a SQL expression, evaluated by the database at flush.
        obj.deleted_at = func.now()
        obj.deleted_by = deleter_id
        db.add(obj)
        db.commit()
        return True


# Module-level instance; the class keeps no per-instance state.
organization = OrganizationCRUD()