#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Security utilities.

Helpers for path/filename safety, per-IP rate limiting, input validation,
session/CSRF tokens, HTML escaping and outbound-request (SSRF) checks.
"""

import os
import re
import time
import hashlib
import secrets
import ipaddress
import socket
import unicodedata
from pathlib import Path
from typing import Optional
from functools import wraps
from urllib.parse import urlparse
from flask import request, jsonify, session
from collections import defaultdict
import threading


# ==================== File path safety ====================

def is_safe_path(basedir, path, follow_symlinks=True):
    """
    Check that ``path`` (relative to ``basedir``) stays inside ``basedir``.

    Guards against path-traversal attacks.

    Args:
        basedir: Base directory the resolved path must stay within.
        path: Untrusted relative path (str or os.PathLike).
        follow_symlinks: If True, resolve symlinks (``os.path.realpath``) on
            both the base and the target before comparing; otherwise only
            normalize them (``os.path.abspath``).

    Returns:
        bool: True if the resolved path is ``basedir`` itself or lies below it.
    """
    path = os.fspath(path)
    # Reject obviously dangerous input up front: parent references, absolute
    # paths (POSIX, UNC/backslash, Windows drive letters) and NUL bytes.
    if ('..' in path or '\x00' in path or os.path.isabs(path)
            or path.startswith(('/', '\\'))):
        return False

    resolve = os.path.realpath if follow_symlinks else os.path.abspath
    # Resolve the base exactly like the target; otherwise a symlinked base
    # directory would never match its own resolved children.
    base = resolve(os.fspath(basedir))
    target = resolve(os.path.join(base, path))

    # commonpath compares whole path components. A plain str.startswith would
    # wrongly accept "/srv/data_evil" as being inside "/srv/data".
    try:
        return os.path.commonpath([base, target]) == base
    except ValueError:
        # Paths on different Windows drives, or mixed absolute/relative.
        return False


# Any character outside this whitelist (including '/' and '\\') becomes '_'.
_UNSAFE_FILENAME_CHARS_RE = re.compile(r'[^a-zA-Z0-9._-]')


def sanitize_filename(filename, max_length=255):
    """
    Replace dangerous characters in a file name.

    Args:
        filename: Original (untrusted) file name.
        max_length: Maximum length of the result. When truncating, the
            extension is kept whenever it fits.

    Returns:
        str: A name containing only ``[A-Za-z0-9._-]``. It is never ``.`` or
        ``..`` and is at most ``max_length`` characters long.
    """
    cleaned = _UNSAFE_FILENAME_CHARS_RE.sub('_', filename)

    if len(cleaned) > max_length:
        stem, ext = os.path.splitext(cleaned)
        if len(ext) <= max_length:
            cleaned = stem[:max_length - len(ext)] + ext
        else:
            # The extension alone exceeds the limit: plain truncation.
            cleaned = cleaned[:max_length]

    # "." and ".." pass the character filter but are directory references,
    # not file names; joining them onto a directory would escape it.
    if cleaned in ('.', '..'):
        cleaned = '_' * len(cleaned)
    return cleaned


# ==================== IP rate limiting and lockout ====================

class IPRateLimiter:
    """In-process, lock-protected per-IP failure counter with temporary lockout."""

    def __init__(self, max_attempts=10, window_seconds=3600, lock_duration=3600):
        """
        Initialize the limiter.

        Args:
            max_attempts: Failed attempts allowed within the window before
                the IP is locked.
            window_seconds: Sliding-window length in seconds.
            lock_duration: Lock length in seconds.
        """
        self.max_attempts = max_attempts
        self.window_seconds = window_seconds
        self.lock_duration = lock_duration
        # Attempt history: {ip: [(timestamp, success), ...]}
        self._attempts = defaultdict(list)
        # Active locks: {ip: lock_until_timestamp} (time.time() based)
        self._locked = {}
        self._lock = threading.Lock()

    def _prune(self, ip_address, now):
        """
        Drop attempts older than the window and return the remaining list.

        The caller must hold ``self._lock``. IPs with no remaining attempts
        are removed, so queries for unknown IPs do not grow ``_attempts``.
        """
        cutoff = now - self.window_seconds
        recent = [
            (ts, ok)
            for ts, ok in self._attempts.get(ip_address, ())
            if ts > cutoff
        ]
        if recent:
            self._attempts[ip_address] = recent
        else:
            self._attempts.pop(ip_address, None)
        return recent

    @staticmethod
    def _count_failures(attempts):
        """Count the unsuccessful entries in an attempt list."""
        return sum(1 for _, ok in attempts if not ok)

    def is_locked(self, ip_address):
        """
        Check whether an IP is currently locked.

        Args:
            ip_address: IP address.

        Returns:
            bool: True while the lock has not expired. Expired locks are
            removed as a side effect.
        """
        with self._lock:
            until = self._locked.get(ip_address)
            if until is None:
                return False
            if time.time() < until:
                return True
            # Lock has expired; forget it.
            del self._locked[ip_address]
            return False

    def get_lock_until(self, ip_address):
        """Return the recorded lock-expiry timestamp for an IP, or 0 if none is recorded."""
        with self._lock:
            return self._locked.get(ip_address, 0)

    def record_attempt(self, ip_address, success=True):
        """
        Record an access attempt.

        Args:
            ip_address: IP address.
            success: Whether the attempt succeeded.

        Returns:
            bool: True if the IP has reached ``max_attempts`` failures within
            the window. In that case it is (re)locked for ``lock_duration``.
        """
        with self._lock:
            now = time.time()
            recent = self._prune(ip_address, now)
            recent.append((now, success))
            self._attempts[ip_address] = recent

            if self._count_failures(recent) >= self.max_attempts:
                self._locked[ip_address] = now + self.lock_duration
                return True
            return False

    def get_remaining_attempts(self, ip_address):
        """
        Return how many more failures are allowed within the current window.

        Args:
            ip_address: IP address.

        Returns:
            int: Remaining attempts (never negative).
        """
        with self._lock:
            recent = self._prune(ip_address, time.time())
            return max(0, self.max_attempts - self._count_failures(recent))

    def cleanup(self):
        """Purge attempt records outside the window and expired locks."""
        with self._lock:
            now = time.time()
            for ip in list(self._attempts):
                self._prune(ip, now)
            for ip, until in list(self._locked.items()):
                if now >= until:
                    del self._locked[ip]


# Module-wide limiter instance (fallback for require_ip_not_locked).
ip_rate_limiter = IPRateLimiter()


def require_ip_not_locked(f):
    """Decorator: answer HTTP 429 when the client IP is currently locked."""
    @wraps(f)
    def decorated_function(*args, **kwargs):
        ip_address = get_rate_limit_ip()
        try:
            # P0 / O-01: prefer the shared, thread-safe limiter state in services.state.
            from services.state import check_ip_rate_limit, safe_get_ip_lock_until
            allowed, error_msg = check_ip_rate_limit(ip_address)
            locked_until = None if allowed else safe_get_ip_lock_until(ip_address)
        except Exception:
            # Deliberate best-effort fallback: use the in-process limiter so a
            # failure in the shared state never blocks the request outright.
            if ip_rate_limiter.is_locked(ip_address):
                return (
                    jsonify(
                        {
                            "error": "由于多次失败尝试,您的IP已被临时锁定",
                            "locked_until": ip_rate_limiter.get_lock_until(ip_address),
                        }
                    ),
                    429,
                )
        else:
            if not allowed:
                return (
                    jsonify(
                        {
                            "error": error_msg or "由于多次失败尝试,您的IP已被临时锁定",
                            "locked_until": locked_until,
                        }
                    ),
                    429,
                )
        return f(*args, **kwargs)
    return decorated_function


# ==================== Input validation ====================

# Letters/digits/underscore (Unicode-aware \w) plus CJK Unified Ideographs.
# NOTE(review): \w also matches non-Latin scripts (Cyrillic, full-width, ...),
# which allows look-alike usernames; confirm whether that is intended.
_USERNAME_RE = re.compile(r'[\w\u4e00-\u9fa5]+')

# Deliberately simple e-mail shape check; used with fullmatch so that a
# trailing newline (header injection) is not accepted as "$" would allow.
_EMAIL_RE = re.compile(r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}')


def validate_username(username):
    """
    Validate the username format.

    Args:
        username: Username.

    Returns:
        tuple: (is_valid, error_message); error_message is None when valid.
    """
    if not username:
        return False, "用户名不能为空"

    if len(username) < 3:
        return False, "用户名长度不能少于3个字符"

    if len(username) > 50:
        return False, "用户名长度不能超过50个字符"

    # fullmatch: unlike re.match(...$), does not accept a trailing "\n".
    if not _USERNAME_RE.fullmatch(username):
        return False, "用户名只能包含字母、数字、下划线和中文字符"

    # Defense in depth: reject format (Cf, e.g. zero-width) and control (Cc)
    # characters in case the pattern above is ever widened.
    for char in username:
        if unicodedata.category(char) in ('Cf', 'Cc'):
            return False, "用户名不能包含不可见字符"

    return True, None


def validate_password(password, require_complexity=True):
    """
    Validate password strength.

    Args:
        password: Password.
        require_complexity: If True (default), require at least one ASCII
            letter and one digit.

    Returns:
        tuple: (is_valid, error_message); error_message is None when valid.
    """
    if not password:
        return False, "密码不能为空"

    if len(password) < 8:  # minimum length 8
        return False, "密码长度不能少于8个字符"

    if len(password) > 128:
        return False, "密码长度不能超过128个字符"

    if require_complexity:
        has_letter = bool(re.search(r'[a-zA-Z]', password))
        has_digit = bool(re.search(r'\d', password))
        if not (has_letter and has_digit):
            return False, "密码必须包含字母和数字"

    return True, None


def validate_email(email):
    """
    Validate the e-mail address format.

    Args:
        email: E-mail address. Empty or None means "not provided" and is
            accepted.

    Returns:
        tuple: (is_valid, error_message); error_message is None when valid.
    """
    if not email:
        return True, None  # e-mail is optional

    # Check the length first so the regex never runs on oversized input.
    if len(email) > 255:
        return False, "邮箱长度不能超过255个字符"

    if not _EMAIL_RE.fullmatch(email):
        return False, "邮箱格式不正确"

    return True, None


# ==================== Session security ====================

def generate_session_token():
    """Generate a random URL-safe session token (32 random bytes)."""
    return secrets.token_urlsafe(32)


def hash_token(token):
    """Return the SHA-256 hex digest of a token (for storage instead of the raw token)."""
    return hashlib.sha256(token.encode()).hexdigest()


# ==================== CSRF protection ====================

def generate_csrf_token():
    """Return the session's CSRF token, creating and storing one if absent."""
    if 'csrf_token' not in session:
        session['csrf_token'] = secrets.token_urlsafe(32)
    return session['csrf_token']


def validate_csrf_token(token):
    """
    Check a submitted CSRF token against the one stored in the session.

    Returns False when either side is missing, empty or not a string.
    Previously, a missing token and a missing session token compared equal
    (``None == None``). The comparison is constant-time.
    """
    expected = session.get('csrf_token')
    if not (isinstance(token, str) and isinstance(expected, str)):
        return False
    if not token or not expected:
        return False
    # Compare bytes: compare_digest rejects non-ASCII str arguments.
    return secrets.compare_digest(token.encode('utf-8'), expected.encode('utf-8'))


# ==================== Content security ====================

# Single-pass translation table, so '&' inside a produced entity is never
# re-escaped.
_HTML_ESCAPE_TABLE = str.maketrans({
    '&': '&amp;',
    '<': '&lt;',
    '>': '&gt;',
    '"': '&quot;',
    "'": '&#x27;',
    '/': '&#x2F;',
})


def escape_html(text):
    """
    Escape HTML special characters (XSS mitigation).

    Escapes & < > " ' and /. Falsy input is returned unchanged.
    """
    if not text:
        return text
    return text.translate(_HTML_ESCAPE_TABLE)


def sanitize_sql_like_pattern(pattern):
    """
    Escape the special characters of a SQL LIKE pattern.

    The result uses backslash as the escape character, so the query must
    declare it (e.g. ``... LIKE ? ESCAPE '\\'``) where the database does not
    default to it.

    Args:
        pattern: Raw user text to embed in a LIKE pattern.

    Returns:
        str: Text with '\\', '%' and '_' escaped.
    """
    # Backslash must be escaped first so the escapes added below survive.
    pattern = pattern.replace('\\', '\\\\')
    pattern = pattern.replace('%', '\\%')
    pattern = pattern.replace('_', '\\_')
    return pattern


# ==================== Security configuration check ====================

def check_security_config():
    """
    Inspect the current Flask app configuration for common security problems.

    Must be called within an application context.

    Returns:
        list: Human-readable issue descriptions (empty if none were found).
    """
    issues = []

    from flask import current_app

    secret_key = current_app.config.get('SECRET_KEY')
    if not secret_key or len(secret_key) < 32:
        issues.append("SECRET_KEY过短或未设置")

    if current_app.config.get('DEBUG'):
        issues.append("DEBUG模式在生产环境应该关闭")

    if not current_app.config.get('SESSION_COOKIE_HTTPONLY'):
        issues.append("SESSION_COOKIE_HTTPONLY应该设置为True")

    if not current_app.config.get('SESSION_COOKIE_SECURE'):
        issues.append("生产环境应该启用SESSION_COOKIE_SECURE(需要HTTPS)")

    return issues


# ==================== Helpers ====================

def get_client_ip(trust_proxy=False):
    """
    Return the client IP address.

    Proxy headers are ignored by default because X-Forwarded-For / X-Real-IP
    can be forged to evade rate limiting.

    Args:
        trust_proxy: Honor proxy headers. Only set this when the app is known
            to sit behind a trusted proxy.

    Returns:
        str: IP address (``request.remote_addr`` unless a trusted header is used).
    """
    if trust_proxy:
        if request.headers.get('X-Forwarded-For'):
            return request.headers.get('X-Forwarded-For').split(',')[0].strip()
        elif request.headers.get('X-Real-IP'):
            return request.headers.get('X-Real-IP')

    # Default: the socket peer address (safe, but may be the proxy's IP).
    return request.remote_addr


def get_rate_limit_ip() -> str:
    """
    Return the IP to use for rate limiting / risk control.

    Proxy headers are consulted only when the direct peer is a private,
    loopback or link-local address (i.e. presumably a local reverse proxy).
    Within them, the first syntactically valid IP from X-Forwarded-For, then
    X-Real-IP, is used; otherwise the peer address is returned.

    NOTE(review): the left-most X-Forwarded-For entry is supplied by the
    client when the proxy appends rather than overwrites the header, so it
    can be rotated to evade limits. Confirm the proxy overwrites XFF, or
    switch to the right-most untrusted entry for this deployment.
    """
    remote_addr = request.remote_addr or ""
    try:
        remote_ip = ipaddress.ip_address(remote_addr)
    except ValueError:
        remote_ip = None

    if remote_ip and (remote_ip.is_private or remote_ip.is_loopback or remote_ip.is_link_local):
        forwarded = request.headers.get("X-Forwarded-For", "")
        if forwarded:
            candidate = forwarded.split(",")[0].strip()
            try:
                ipaddress.ip_address(candidate)
                return candidate
            except ValueError:
                pass
        real_ip = request.headers.get("X-Real-IP", "").strip()
        if real_ip:
            try:
                ipaddress.ip_address(real_ip)
                return real_ip
            except ValueError:
                pass

    return remote_addr


def _is_public_ip(ip) -> bool:
    """Return True only for globally routable, non-special unicast addresses."""
    # Normalize IPv4-mapped IPv6 (::ffff:a.b.c.d) so the IPv4 rules apply.
    mapped = getattr(ip, "ipv4_mapped", None)
    if mapped is not None:
        ip = mapped
    if (ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_multicast
            or ip.is_reserved or ip.is_unspecified):
        return False
    # is_global also excludes ranges not flagged above, e.g. the shared
    # address space 100.64.0.0/10, which hosts cloud metadata endpoints
    # such as 100.100.100.200.
    return ip.is_global


def is_safe_outbound_url(url: str) -> bool:
    """
    Refuse outbound requests to internal or reserved addresses (SSRF mitigation).

    Only http/https URLs are allowed. Every address the host resolves to
    must be globally routable.

    NOTE(review): the host is resolved here and again by whatever performs
    the request, so DNS rebinding between the two lookups is not prevented.
    """
    try:
        parsed = urlparse(str(url or "").strip())
    except Exception:
        return False

    if parsed.scheme not in ("http", "https"):
        return False

    host = parsed.hostname
    if not host:
        return False

    try:
        ips = [ipaddress.ip_address(host)]
    except ValueError:
        try:
            infos = socket.getaddrinfo(host, None)
            ips = [ipaddress.ip_address(info[4][0]) for info in infos]
        except Exception:
            return False

    if not ips:
        return False
    return all(_is_public_ip(ip) for ip in ips)


if __name__ == '__main__':
    # File path safety
    print("文件路径安全测试:")
    print(f"  安全路径: {is_safe_path('/tmp', 'test.txt')}")
    print(f"  危险路径: {is_safe_path('/tmp', '../etc/passwd')}")

    # Filename sanitization
    print(f"\n文件名清理: {sanitize_filename('../../../etc/passwd')}")

    # Input validation
    print("\n输入验证测试:")
    print(f"  用户名: {validate_username('test_user')}")
    print(f"  密码: {validate_password('Test123456')}")
    print(f"  邮箱: {validate_email('test@example.com')}")

    # IP rate limiting
    print("\nIP限流测试:")
    limiter = IPRateLimiter(max_attempts=3, window_seconds=60)
    ip = '192.168.1.1'
    for i in range(5):
        locked = limiter.record_attempt(ip, success=False)
        print(f"  尝试 {i+1}: 剩余次数={limiter.get_remaining_attempts(ip)}, 是否锁定={locked}")

    print(f"  IP被锁定: {limiter.is_locked(ip)}")