""" 数据库会话管理 """ from typing import AsyncGenerator from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker, Session from app.core.config import settings from app.db.base import Base # 创建异步引擎 engine = create_async_engine( settings.DATABASE_URL, echo=settings.DATABASE_ECHO, pool_pre_ping=True, pool_size=50, # 从20增加到50,提高并发性能 max_overflow=10, # 从0增加到10,允许峰值时的额外连接 ) # 创建同步引擎(用于遗留同步查询) def _get_sync_database_url() -> str: url = settings.DATABASE_URL if url.startswith("postgresql+asyncpg://"): return url.replace("postgresql+asyncpg://", "postgresql+psycopg2://", 1) if "+asyncpg" in url: return url.replace("+asyncpg", "+psycopg2") return url sync_engine = create_engine( _get_sync_database_url(), echo=settings.DATABASE_ECHO, pool_pre_ping=True, pool_size=50, max_overflow=10, ) # 创建异步会话工厂 async_session_maker = async_sessionmaker( engine, class_=AsyncSession, expire_on_commit=False, autocommit=False, autoflush=False, ) # 创建同步会话工厂 sync_session_maker = sessionmaker( bind=sync_engine, autocommit=False, autoflush=False, ) async def get_db() -> AsyncGenerator[AsyncSession, None]: """ 获取数据库会话 Yields: AsyncSession: 数据库会话 """ async with async_session_maker() as session: try: yield session await session.commit() except Exception: await session.rollback() raise finally: await session.close() async def init_db() -> None: """ 初始化数据库(创建所有表) 注意:生产环境应使用Alembic迁移 """ async with engine.begin() as conn: # 导入所有模型以确保它们被注册 from app.models import user, asset, device_type, organization # 创建所有表 await conn.run_sync(Base.metadata.create_all) async def close_db() -> None: """关闭数据库连接""" await engine.dispose() sync_engine.dispose() __all__ = [ "engine", "sync_engine", "async_session_maker", "sync_session_maker", "get_db", "init_db", "close_db", ]