#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ 配置管理模块 集中管理所有配置项,支持环境变量 """ import os from datetime import timedelta from pathlib import Path from urllib.parse import urlsplit, urlunsplit # 尝试加载.env文件(如果存在) # Bug fix: 添加警告日志,避免静默失败 try: from dotenv import load_dotenv env_path = Path(__file__).parent / '.env' if env_path.exists(): load_dotenv(dotenv_path=env_path) print(f"✓ 已加载环境变量文件: {env_path}") except ImportError: # python-dotenv未安装,记录警告 import sys print("⚠ 警告: python-dotenv未安装,将不会加载.env文件。如需使用.env文件,请运行: pip install python-dotenv", file=sys.stderr) # 常量定义 SECRET_KEY_FILE = 'data/secret_key.txt' def get_secret_key(): """获取SECRET_KEY(优先环境变量)""" # 优先从环境变量读取 secret_key = os.environ.get('SECRET_KEY') if secret_key: return secret_key # 从文件读取 if os.path.exists(SECRET_KEY_FILE): with open(SECRET_KEY_FILE, 'r') as f: return f.read().strip() # 生成新的 new_key = os.urandom(24).hex() os.makedirs('data', exist_ok=True) with open(SECRET_KEY_FILE, 'w') as f: f.write(new_key) print(f"✓ 已生成新的SECRET_KEY并保存到 {SECRET_KEY_FILE}") return new_key def _derive_base_url_from_full_url(url: str, fallback: str) -> str: """从完整 URL 推导出 base_url(scheme://netloc)。""" try: parsed = urlsplit(str(url or "").strip()) if parsed.scheme and parsed.netloc: return f"{parsed.scheme}://{parsed.netloc}" except Exception: pass return fallback def _derive_sibling_url(full_url: str, filename: str, fallback: str) -> str: """把 full_url 的最后路径段替换为 filename(忽略 query/fragment)。""" try: parsed = urlsplit(str(full_url or "").strip()) if not parsed.scheme or not parsed.netloc: return fallback path = parsed.path or "/" if path.endswith("/"): new_path = path + filename else: new_path = path.rsplit("/", 1)[0] + "/" + filename return urlunsplit((parsed.scheme, parsed.netloc, new_path, "", "")) except Exception: return fallback class Config: """应用配置基类""" # ==================== Flask核心配置 ==================== SECRET_KEY = get_secret_key() # ==================== 会话安全配置 ==================== # 安全修复: 根据环境自动选择安全配置 # 生产环境(FLASK_ENV=production)时自动启用更严格的安全设置 _is_production = os.environ.get('FLASK_ENV', 'production') == 'production' _force_secure = os.environ.get('SESSION_COOKIE_SECURE', '').lower() == 'true' SESSION_COOKIE_SECURE = _force_secure or (_is_production and os.environ.get('HTTPS_ENABLED', 'false').lower() == 'true') SESSION_COOKIE_HTTPONLY = True # 防止XSS攻击 # SameSite配置:HTTPS环境使用None,HTTP环境使用Lax SESSION_COOKIE_SAMESITE = 'None' if SESSION_COOKIE_SECURE else 'Lax' # 自定义cookie名称,避免与其他应用冲突 SESSION_COOKIE_NAME = os.environ.get('SESSION_COOKIE_NAME', 'zsglpt_session') # Cookie路径,确保整个应用都能访问 SESSION_COOKIE_PATH = '/' PERMANENT_SESSION_LIFETIME = timedelta(hours=int(os.environ.get('SESSION_LIFETIME_HOURS', '24'))) # 安全警告检查 @classmethod def check_security_warnings(cls): """检查安全配置,输出警告""" import sys warnings = [] env = os.environ.get('FLASK_ENV', 'production') if env == 'production': if not cls.SESSION_COOKIE_SECURE: warnings.append("SESSION_COOKIE_SECURE=False: 生产环境建议启用HTTPS并设置SESSION_COOKIE_SECURE=true") if warnings: print("\n⚠ 安全配置警告:", file=sys.stderr) for w in warnings: print(f" - {w}", file=sys.stderr) print("", file=sys.stderr) # ==================== 数据库配置 ==================== DB_FILE = os.environ.get('DB_FILE', 'data/app_data.db') DB_POOL_SIZE = int(os.environ.get('DB_POOL_SIZE', '5')) # ==================== 浏览器配置 ==================== SCREENSHOTS_DIR = os.environ.get('SCREENSHOTS_DIR', '截图') COOKIES_DIR = os.environ.get('COOKIES_DIR', 'data/cookies') # ==================== 并发控制配置 ==================== MAX_CONCURRENT_GLOBAL = int(os.environ.get('MAX_CONCURRENT_GLOBAL', '2')) MAX_CONCURRENT_PER_ACCOUNT = int(os.environ.get('MAX_CONCURRENT_PER_ACCOUNT', '1')) # ==================== 日志缓存配置 ==================== MAX_LOGS_PER_USER = int(os.environ.get('MAX_LOGS_PER_USER', '100')) MAX_TOTAL_LOGS = int(os.environ.get('MAX_TOTAL_LOGS', '1000')) # ==================== 内存/缓存清理配置 ==================== USER_ACCOUNTS_EXPIRE_SECONDS = int(os.environ.get('USER_ACCOUNTS_EXPIRE_SECONDS', '3600')) BATCH_TASK_EXPIRE_SECONDS = int(os.environ.get('BATCH_TASK_EXPIRE_SECONDS', '21600')) # 默认6小时 PENDING_RANDOM_EXPIRE_SECONDS = int(os.environ.get('PENDING_RANDOM_EXPIRE_SECONDS', '7200')) # 默认2小时 # ==================== 验证码配置 ==================== MAX_CAPTCHA_ATTEMPTS = int(os.environ.get('MAX_CAPTCHA_ATTEMPTS', '5')) CAPTCHA_EXPIRE_SECONDS = int(os.environ.get('CAPTCHA_EXPIRE_SECONDS', '300')) # ==================== IP限流配置 ==================== MAX_IP_ATTEMPTS_PER_HOUR = int(os.environ.get('MAX_IP_ATTEMPTS_PER_HOUR', '10')) IP_LOCK_DURATION = int(os.environ.get('IP_LOCK_DURATION', '3600')) # 秒 IP_RATE_LIMIT_LOGIN_MAX = int(os.environ.get('IP_RATE_LIMIT_LOGIN_MAX', '20')) IP_RATE_LIMIT_LOGIN_WINDOW_SECONDS = int(os.environ.get('IP_RATE_LIMIT_LOGIN_WINDOW_SECONDS', '60')) IP_RATE_LIMIT_REGISTER_MAX = int(os.environ.get('IP_RATE_LIMIT_REGISTER_MAX', '10')) IP_RATE_LIMIT_REGISTER_WINDOW_SECONDS = int(os.environ.get('IP_RATE_LIMIT_REGISTER_WINDOW_SECONDS', '3600')) IP_RATE_LIMIT_EMAIL_MAX = int(os.environ.get('IP_RATE_LIMIT_EMAIL_MAX', '20')) IP_RATE_LIMIT_EMAIL_WINDOW_SECONDS = int(os.environ.get('IP_RATE_LIMIT_EMAIL_WINDOW_SECONDS', '3600')) # ==================== 超时配置 ==================== PAGE_LOAD_TIMEOUT = int(os.environ.get('PAGE_LOAD_TIMEOUT', '60000')) # 毫秒 DEFAULT_TIMEOUT = int(os.environ.get('DEFAULT_TIMEOUT', '60000')) # 毫秒 # ==================== 知识管理平台配置 ==================== ZSGL_LOGIN_URL = os.environ.get('ZSGL_LOGIN_URL', 'https://postoa.aidunsoft.com/admin/login.aspx') ZSGL_INDEX_URL_PATTERN = os.environ.get('ZSGL_INDEX_URL_PATTERN', 'index.aspx') ZSGL_BASE_URL = os.environ.get('ZSGL_BASE_URL') or _derive_base_url_from_full_url(ZSGL_LOGIN_URL, 'https://postoa.aidunsoft.com') ZSGL_INDEX_URL = os.environ.get('ZSGL_INDEX_URL') or _derive_sibling_url( ZSGL_LOGIN_URL, ZSGL_INDEX_URL_PATTERN, f"{ZSGL_BASE_URL}/admin/{ZSGL_INDEX_URL_PATTERN}", ) MAX_CONCURRENT_CONTEXTS = int(os.environ.get('MAX_CONCURRENT_CONTEXTS', '100')) # ==================== 服务器配置 ==================== SERVER_HOST = os.environ.get('SERVER_HOST', '0.0.0.0') SERVER_PORT = int(os.environ.get('SERVER_PORT', '51233')) # ==================== SocketIO配置 ==================== SOCKETIO_CORS_ALLOWED_ORIGINS = os.environ.get('SOCKETIO_CORS_ALLOWED_ORIGINS', '*') # ==================== 网站基础URL配置 ==================== # 用于生成邮件中的验证链接等 BASE_URL = os.environ.get('BASE_URL', 'http://localhost:51233') # ==================== 日志配置 ==================== # 安全修复: 生产环境默认使用INFO级别,避免泄露敏感调试信息 LOG_LEVEL = os.environ.get('LOG_LEVEL', 'INFO') LOG_FILE = os.environ.get('LOG_FILE', 'logs/app.log') LOG_MAX_BYTES = int(os.environ.get('LOG_MAX_BYTES', '10485760')) # 10MB LOG_BACKUP_COUNT = int(os.environ.get('LOG_BACKUP_COUNT', '5')) # ==================== 安全配置 ==================== DEBUG = os.environ.get('FLASK_DEBUG', 'False').lower() == 'true' ALLOWED_SCREENSHOT_EXTENSIONS = {'.png', '.jpg', '.jpeg'} MAX_SCREENSHOT_SIZE = int(os.environ.get('MAX_SCREENSHOT_SIZE', '10485760')) # 10MB LOGIN_CAPTCHA_AFTER_FAILURES = int(os.environ.get('LOGIN_CAPTCHA_AFTER_FAILURES', '3')) LOGIN_CAPTCHA_WINDOW_SECONDS = int(os.environ.get('LOGIN_CAPTCHA_WINDOW_SECONDS', '900')) LOGIN_RATE_LIMIT_WINDOW_SECONDS = int(os.environ.get('LOGIN_RATE_LIMIT_WINDOW_SECONDS', '900')) LOGIN_IP_MAX_ATTEMPTS = int(os.environ.get('LOGIN_IP_MAX_ATTEMPTS', '60')) LOGIN_USERNAME_MAX_ATTEMPTS = int(os.environ.get('LOGIN_USERNAME_MAX_ATTEMPTS', '30')) LOGIN_IP_USERNAME_MAX_ATTEMPTS = int(os.environ.get('LOGIN_IP_USERNAME_MAX_ATTEMPTS', '12')) LOGIN_FAIL_DELAY_BASE_MS = int(os.environ.get('LOGIN_FAIL_DELAY_BASE_MS', '200')) LOGIN_FAIL_DELAY_MAX_MS = int(os.environ.get('LOGIN_FAIL_DELAY_MAX_MS', '1200')) LOGIN_ACCOUNT_LOCK_FAILURES = int(os.environ.get('LOGIN_ACCOUNT_LOCK_FAILURES', '6')) LOGIN_ACCOUNT_LOCK_WINDOW_SECONDS = int(os.environ.get('LOGIN_ACCOUNT_LOCK_WINDOW_SECONDS', '900')) LOGIN_ACCOUNT_LOCK_SECONDS = int(os.environ.get('LOGIN_ACCOUNT_LOCK_SECONDS', '600')) LOGIN_SCAN_UNIQUE_USERNAME_THRESHOLD = int(os.environ.get('LOGIN_SCAN_UNIQUE_USERNAME_THRESHOLD', '8')) LOGIN_SCAN_WINDOW_SECONDS = int(os.environ.get('LOGIN_SCAN_WINDOW_SECONDS', '600')) LOGIN_SCAN_COOLDOWN_SECONDS = int(os.environ.get('LOGIN_SCAN_COOLDOWN_SECONDS', '600')) EMAIL_RATE_LIMIT_MAX = int(os.environ.get('EMAIL_RATE_LIMIT_MAX', '6')) EMAIL_RATE_LIMIT_WINDOW_SECONDS = int(os.environ.get('EMAIL_RATE_LIMIT_WINDOW_SECONDS', '3600')) LOGIN_ALERT_ENABLED = os.environ.get('LOGIN_ALERT_ENABLED', 'true').lower() == 'true' LOGIN_ALERT_MIN_INTERVAL_SECONDS = int(os.environ.get('LOGIN_ALERT_MIN_INTERVAL_SECONDS', '3600')) ADMIN_REAUTH_WINDOW_SECONDS = int(os.environ.get('ADMIN_REAUTH_WINDOW_SECONDS', '600')) @classmethod def validate(cls): """验证配置的有效性""" errors = [] # 验证SECRET_KEY if not cls.SECRET_KEY or len(cls.SECRET_KEY) < 32: errors.append("SECRET_KEY长度必须至少32个字符") # 验证并发配置 if cls.MAX_CONCURRENT_GLOBAL < 1: errors.append("MAX_CONCURRENT_GLOBAL必须大于0") if cls.MAX_CONCURRENT_PER_ACCOUNT < 1: errors.append("MAX_CONCURRENT_PER_ACCOUNT必须大于0") # 验证数据库配置 if not cls.DB_FILE: errors.append("DB_FILE不能为空") if cls.DB_POOL_SIZE < 1: errors.append("DB_POOL_SIZE必须大于0") # 验证日志配置 if cls.LOG_LEVEL not in ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']: errors.append(f"LOG_LEVEL无效: {cls.LOG_LEVEL}") return errors @classmethod def print_config(cls): """打印当前配置(隐藏敏感信息)""" print("=" * 60) print("应用配置") print("=" * 60) print(f"DEBUG模式: {cls.DEBUG}") print(f"SECRET_KEY: {'*' * 20} (长度: {len(cls.SECRET_KEY)})") print(f"会话超时: {cls.PERMANENT_SESSION_LIFETIME}") print(f"Cookie安全: HTTPS={cls.SESSION_COOKIE_SECURE}, HttpOnly={cls.SESSION_COOKIE_HTTPONLY}") print(f"数据库文件: {cls.DB_FILE}") print(f"数据库连接池: {cls.DB_POOL_SIZE}") print(f"并发配置: 全局={cls.MAX_CONCURRENT_GLOBAL}, 单账号={cls.MAX_CONCURRENT_PER_ACCOUNT}") print(f"日志级别: {cls.LOG_LEVEL}") print(f"日志文件: {cls.LOG_FILE}") print(f"截图目录: {cls.SCREENSHOTS_DIR}") print("=" * 60) class DevelopmentConfig(Config): """开发环境配置""" DEBUG = True # 不覆盖SESSION_COOKIE_SECURE,使用父类的环境变量配置 class ProductionConfig(Config): """生产环境配置""" DEBUG = False # 不覆盖SESSION_COOKIE_SECURE,使用父类的环境变量配置 # 如需HTTPS,请在环境变量中设置 SESSION_COOKIE_SECURE=true class TestingConfig(Config): """测试环境配置""" DEBUG = True TESTING = True DB_FILE = 'data/test_app_data.db' # 根据环境变量选择配置 config_map = { 'development': DevelopmentConfig, 'production': ProductionConfig, 'testing': TestingConfig, } def get_config(): """获取当前环境的配置""" env = os.environ.get('FLASK_ENV', 'production') return config_map.get(env, ProductionConfig) if __name__ == '__main__': # 配置验证测试 config = get_config() errors = config.validate() if errors: print("配置验证失败:") for error in errors: print(f" ✗ {error}") else: print("✓ 配置验证通过") config.print_config()