Files
zsglpt/app_security.py

665 lines
18 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
安全工具模块
提供各种安全相关的功能
"""
import os
import re
import time
import hashlib
import hmac
import secrets
import ipaddress
import socket
from pathlib import Path
from typing import Optional
from functools import wraps
from urllib.parse import urlparse
from flask import request, jsonify, session
from collections import defaultdict
import threading
# ==================== 文件路径安全 ====================
def is_safe_path(basedir, path, follow_symlinks=True):
    """Check that ``path`` stays inside ``basedir`` (path-traversal guard).

    Args:
        basedir: Directory the resolved path must remain inside.
        path: Relative path to check; it is joined onto ``basedir``.
        follow_symlinks: If True, symlinks are resolved in both ``basedir``
            and the joined path before comparing; otherwise the paths are
            only normalised.

    Returns:
        bool: True if the resolved path is ``basedir`` itself or lies below it.
    """
    # Cheap rejection first: any ".." sequence, or a rooted/absolute path.
    if '..' in path or path.startswith('/') or path.startswith('\\'):
        return False

    # Resolve base and target the same way so they are comparable even when
    # basedir itself sits behind a symlink.
    resolve = os.path.realpath if follow_symlinks else os.path.abspath
    base = resolve(basedir)
    target = resolve(os.path.join(basedir, path))

    # Compare whole path components: a plain string prefix test would accept
    # "/data_evil/x" for base "/data".
    try:
        return os.path.commonpath([base, target]) == base
    except ValueError:
        # Raised e.g. for paths on different drives: cannot be inside base.
        return False
def sanitize_filename(filename):
    """Replace unsafe characters in a file name and cap its length.

    Args:
        filename: Original file name.

    Returns:
        str: Name containing only ``[a-zA-Z0-9._-]`` (everything else becomes
        ``_``), at most 255 characters long.
    """
    max_length = 255

    # Path separators are replaced explicitly, then anything outside the
    # allow-list is replaced as well.
    filename = filename.replace('/', '_').replace('\\', '_')
    filename = re.sub(r'[^a-zA-Z0-9._-]', '_', filename)

    if len(filename) > max_length:
        name, ext = os.path.splitext(filename)
        if len(ext) > max_length:
            # The extension alone is too long; a negative slice on the stem
            # would leave the result over the limit, so hard-truncate.
            filename = filename[:max_length]
        else:
            # Keep the extension and shorten the stem.
            filename = name[:max_length - len(ext)] + ext
    return filename
# ==================== IP限流和黑名单 ====================
class IPRateLimiter:
    """In-memory limiter that locks an IP after too many failed attempts.

    Attempts are kept per IP as ``(timestamp, success)`` tuples within a
    sliding window. All public methods take ``self._lock`` while touching
    the internal maps.
    """

    def __init__(
        self,
        max_attempts=10,
        window_seconds=3600,
        lock_duration=3600,
        max_tracked_ips=20000,
    ):
        """Initialise the limiter.

        Args:
            max_attempts: Failed attempts allowed inside the window before
                the IP is locked.
            window_seconds: Size of the sliding window, in seconds.
            lock_duration: How long a lock lasts, in seconds.
            max_tracked_ips: Soft cap on tracked entries (attempt records
                plus locks); values below 1000 are raised to 1000.
        """
        self.max_attempts = max_attempts
        self.window_seconds = window_seconds
        self.lock_duration = lock_duration
        self.max_tracked_ips = max(1000, int(max_tracked_ips or 0))
        # Attempt history: {ip: [(timestamp, success), ...]}
        self._attempts = defaultdict(list)
        # Active locks: {ip: lock_until_timestamp}
        self._locked = {}
        self._lock = threading.Lock()

    def _recent_attempts(self, ip_address, cutoff_time):
        """Return the attempts newer than ``cutoff_time`` for one IP.

        Uses ``dict.get`` so that looking up an unknown IP does not create
        an empty entry in the defaultdict.
        """
        return [
            (ts, succ)
            for ts, succ in self._attempts.get(ip_address, ())
            if ts > cutoff_time
        ]

    def _drop_expired(self, now_ts):
        """Remove attempts outside the window and locks that have ended."""
        cutoff_time = now_ts - self.window_seconds
        for ip in list(self._attempts):
            recent = self._recent_attempts(ip, cutoff_time)
            if recent:
                self._attempts[ip] = recent
            else:
                del self._attempts[ip]
        expired = [ip for ip, until in self._locked.items() if now_ts >= until]
        for ip in expired:
            del self._locked[ip]

    def _prune_if_oversized(self, now_ts: float) -> None:
        """Keep the internal maps near ``max_tracked_ips`` entries.

        Called with ``self._lock`` held. Expired data is dropped first; if
        that is not enough, the IPs whose latest attempt is oldest are
        evicted (their lock, if any, goes with them).
        """
        if len(self._attempts) + len(self._locked) <= self.max_tracked_ips:
            return

        self._drop_expired(now_ts)
        overflow = len(self._attempts) + len(self._locked) - self.max_tracked_ips
        if overflow <= 0:
            return

        # Every remaining list is non-empty, so [-1] is the latest attempt.
        by_last_seen = sorted(
            self._attempts, key=lambda ip: self._attempts[ip][-1][0]
        )
        for ip in by_last_seen[:overflow]:
            self._attempts.pop(ip, None)
            self._locked.pop(ip, None)

        # Locks can outlive their attempt records when lock_duration exceeds
        # window_seconds; evict the locks closest to expiry so the cap still
        # holds in that case.
        excess = len(self._attempts) + len(self._locked) - self.max_tracked_ips
        if excess > 0:
            for ip in sorted(self._locked, key=self._locked.get)[:excess]:
                del self._locked[ip]

    def is_locked(self, ip_address):
        """Report whether an IP is currently locked.

        Args:
            ip_address: IP address to check.

        Returns:
            bool: True while the lock is active. An expired lock is removed
            and False is returned.
        """
        with self._lock:
            lock_until = self._locked.get(ip_address)
            if lock_until is None:
                return False
            if time.time() < lock_until:
                return True
            # Lock has run out: forget it.
            del self._locked[ip_address]
            return False

    def record_attempt(self, ip_address, success=True):
        """Record one attempt and lock the IP if it has failed too often.

        Args:
            ip_address: IP address making the attempt.
            success: Whether the attempt succeeded.

        Returns:
            bool: True if the IP is (re)locked as a result of this call.
        """
        with self._lock:
            now = time.time()
            self._prune_if_oversized(now)

            attempts = self._recent_attempts(ip_address, now - self.window_seconds)
            attempts.append((now, success))
            self._attempts[ip_address] = attempts

            failed_attempts = sum(1 for _, succ in attempts if not succ)
            if failed_attempts >= self.max_attempts:
                self._locked[ip_address] = now + self.lock_duration
                return True
            return False

    def get_remaining_attempts(self, ip_address):
        """Return how many more failures the IP may make inside the window.

        Args:
            ip_address: IP address to query.

        Returns:
            int: Remaining failed attempts, never below 0.
        """
        with self._lock:
            cutoff_time = time.time() - self.window_seconds
            recent = self._recent_attempts(ip_address, cutoff_time)
            # Store the filtered list, but never leave an empty entry behind:
            # a read-only query must not grow the map.
            if recent:
                self._attempts[ip_address] = recent
            else:
                self._attempts.pop(ip_address, None)
            failed_attempts = sum(1 for _, succ in recent if not succ)
            return max(0, self.max_attempts - failed_attempts)

    def cleanup(self):
        """Drop expired attempt records and expired locks."""
        with self._lock:
            self._drop_expired(time.time())
# Module-wide limiter instance, created with the default thresholds.
ip_rate_limiter = IPRateLimiter()

# Environment values (compared after strip + lower) that count as "enabled".
_TRUTHY_VALUES = {"1", "true", "yes", "on"}

# Master switch for honouring proxy forwarding headers; stays off unless the
# TRUST_PROXY_HEADERS environment variable holds a truthy value.
_TRUST_PROXY_HEADERS = (
    str(os.getenv("TRUST_PROXY_HEADERS", "false") or "").strip().lower()
    in _TRUTHY_VALUES
)
def require_ip_not_locked(f):
    """Decorator that rejects the request with HTTP 429 while the caller's IP is locked.

    The primary check is delegated to ``services.state``; if that import or
    call raises, the module-level ``ip_rate_limiter`` is consulted instead.
    When neither reports a lock, the wrapped view runs normally.
    """
    @wraps(f)
    def decorated_function(*args, **kwargs):
        ip_address = get_rate_limit_ip()
        try:
            # P0 / O-01: preferred path, the shared rate-limit state kept in
            # services.state (imported lazily inside the request).
            from services.state import check_ip_rate_limit, safe_get_ip_lock_until

            allowed, error_msg = check_ip_rate_limit(ip_address)
            if not allowed:
                payload = {
                    "error": error_msg or "由于多次失败尝试您的IP已被临时锁定",
                    "locked_until": safe_get_ip_lock_until(ip_address),
                }
                return jsonify(payload), 429
        except Exception:
            # Fallback: keep serving via the legacy in-module limiter rather
            # than blocking the request because the shared state failed.
            if ip_rate_limiter.is_locked(ip_address):
                payload = {
                    "error": "由于多次失败尝试您的IP已被临时锁定",
                    "locked_until": ip_rate_limiter._locked.get(ip_address, 0),
                }
                return jsonify(payload), 429
        return f(*args, **kwargs)

    return decorated_function
# ==================== 输入验证 ====================
def validate_username(username):
    """Validate the format of a user name.

    Args:
        username: User name to check.

    Returns:
        tuple: ``(is_valid, error_message)``; ``error_message`` is None when
        the name is accepted.
    """
    import unicodedata

    if not username:
        return False, "用户名不能为空"

    length = len(username)
    if length < 3:
        return False, "用户名长度不能少于3个字符"
    if length > 50:
        return False, "用户名长度不能超过50个字符"

    # Word characters plus the CJK range U+4E00..U+9FA5.
    if re.match(r'^[\w\u4e00-\u9fa5]+$', username) is None:
        return False, "用户名只能包含字母、数字、下划线和中文字符"

    # Reject invisible characters: Cf = format (includes zero-width),
    # Cc = control.
    if any(unicodedata.category(ch) in ('Cf', 'Cc') for ch in username):
        return False, "用户名不能包含不可见字符"

    return True, None
def validate_password(password, require_complexity=True):
    """Validate password length and, optionally, complexity.

    Args:
        password: Password to check.
        require_complexity: When True (the default), the password must
            contain at least one ASCII letter and at least one digit.

    Returns:
        tuple: ``(is_valid, error_message)``; ``error_message`` is None when
        the password is accepted.
    """
    if not password:
        return False, "密码不能为空"

    size = len(password)
    if size < 8:  # minimum of 8 characters
        return False, "密码长度不能少于8个字符"
    if size > 128:
        return False, "密码长度不能超过128个字符"

    if require_complexity:
        lacks_letter = re.search(r'[a-zA-Z]', password) is None
        lacks_digit = re.search(r'\d', password) is None
        if lacks_letter or lacks_digit:
            return False, "密码必须包含字母和数字"

    return True, None
def validate_email(email):
    """Validate the format of an (optional) e-mail address.

    Args:
        email: E-mail address; an empty or None value is accepted because
            the field is optional.

    Returns:
        tuple: ``(is_valid, error_message)``; ``error_message`` is None when
        the address is accepted.
    """
    if not email:
        return True, None  # e-mail is optional

    # Deliberately simple pattern. fullmatch is used instead of match + "$"
    # because "$" also matches just before a trailing newline, which would
    # let "user@example.com\n" through.
    pattern = r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}'
    if not re.fullmatch(pattern, email):
        return False, "邮箱格式不正确"
    if len(email) > 255:
        return False, "邮箱长度不能超过255个字符"
    return True, None
# ==================== 会话安全 ====================
def generate_session_token():
    """Return a new URL-safe random session token built from 32 random bytes."""
    token = secrets.token_urlsafe(32)
    return token
def hash_token(token):
    """Return the hex SHA-256 digest of ``token`` (the form meant for storage)."""
    digest = hashlib.sha256()
    digest.update(token.encode())
    return digest.hexdigest()
# ==================== CSRF保护 ====================
def generate_csrf_token():
    """Return the session's CSRF token, creating and storing one on first use."""
    if 'csrf_token' in session:
        return session['csrf_token']
    session['csrf_token'] = secrets.token_urlsafe(32)
    return session['csrf_token']
def validate_csrf_token(token):
    """Check a submitted CSRF token against the one stored in the session.

    Args:
        token: Token supplied by the client.

    Returns:
        bool: True only when both tokens are present, non-empty and equal.
    """
    expected = session.get("csrf_token")
    if token is None or expected is None:
        return False

    supplied = str(token or "")
    stored = str(expected or "")
    if not (supplied and stored):
        return False

    # Constant-time comparison on the UTF-8 bytes.
    return hmac.compare_digest(supplied.encode("utf-8"), stored.encode("utf-8"))
# ==================== 内容安全 ====================
def escape_html(text):
    """Escape HTML-special characters in ``text`` to guard against XSS.

    Falsy input (empty string, None) is returned unchanged.
    """
    if not text:
        return text
    # Single pass over the string; each special character maps to its entity.
    entity_table = str.maketrans({
        '&': '&amp;',
        '<': '&lt;',
        '>': '&gt;',
        '"': '&quot;',
        "'": '&#x27;',
        '/': '&#x2F;',
    })
    return text.translate(entity_table)
def sanitize_sql_like_pattern(pattern):
    """Backslash-escape the characters that are special in a SQL LIKE pattern.

    Args:
        pattern: Raw LIKE pattern text.

    Returns:
        str: The text with ``\\``, ``%`` and ``_`` each prefixed by a backslash.
    """
    special = ('\\', '%', '_')
    return ''.join('\\' + ch if ch in special else ch for ch in pattern)
# ==================== 安全配置检查 ====================
def check_security_config():
    """Inspect the current Flask app's config for weak security settings.

    Returns:
        list: Human-readable descriptions of the problems found (empty when
        none of the checks fire).
    """
    from flask import current_app

    config = current_app.config
    issues = []

    # SECRET_KEY must exist and be at least 32 characters long.
    secret_key = config.get('SECRET_KEY')
    if not secret_key or len(secret_key) < 32:
        issues.append("SECRET_KEY过短或未设置")

    # DEBUG should be off.
    if config.get('DEBUG'):
        issues.append("DEBUG模式在生产环境应该关闭")

    # Session cookie flags.
    if not config.get('SESSION_COOKIE_HTTPONLY'):
        issues.append("SESSION_COOKIE_HTTPONLY应该设置为True")
    if not config.get('SESSION_COOKIE_SECURE'):
        issues.append("生产环境应该启用SESSION_COOKIE_SECURE需要HTTPS")

    return issues
# ==================== 辅助函数 ====================
def get_client_ip(trust_proxy=False):
    """Return the client's IP address for the current request.

    Proxy headers can be forged, so they are consulted only when the caller
    passes ``trust_proxy=True`` and the module-level ``_TRUST_PROXY_HEADERS``
    switch is on.

    Args:
        trust_proxy: Whether forwarding headers may be used; set it to True
            only behind a proxy that is known to be trusted.

    Returns:
        str: The first ``X-Forwarded-For`` entry, else ``X-Real-IP``, else
        ``request.remote_addr``.
    """
    if not (trust_proxy and _TRUST_PROXY_HEADERS):
        # Default: the socket peer address (may be the proxy's own address).
        return request.remote_addr

    forwarded_for = request.headers.get('X-Forwarded-For')
    if forwarded_for:
        return forwarded_for.split(',')[0].strip()

    real_ip = request.headers.get('X-Real-IP')
    if real_ip:
        return real_ip

    return request.remote_addr
def _load_trusted_proxy_networks():
    """Parse the trusted-proxy CIDR list from ``TRUSTED_PROXY_CIDRS``.

    The variable is a comma-separated list; when unset, the two loopback
    networks are used. Blank and unparsable segments are skipped.

    Returns:
        list: ``ipaddress`` network objects, in the order they were listed.
    """
    configured = os.environ.get("TRUSTED_PROXY_CIDRS", "127.0.0.1/32,::1/128")
    raw = str(configured or "").strip()
    if not raw:
        return []

    networks = []
    for cidr_text in (piece.strip() for piece in raw.split(",")):
        if not cidr_text:
            continue
        try:
            # strict=False tolerates host bits set in the CIDR.
            parsed = ipaddress.ip_network(cidr_text, strict=False)
        except ValueError:
            continue
        networks.append(parsed)
    return networks


# Evaluated once at import time.
_TRUSTED_PROXY_NETWORKS = _load_trusted_proxy_networks()
def _parse_ip_address(candidate: str):
    """Parse ``candidate`` into an ``ipaddress`` object, or return None if invalid.

    Surrounding whitespace is ignored; falsy input is treated as empty text.
    """
    text = str(candidate or "").strip()
    try:
        parsed = ipaddress.ip_address(text)
    except ValueError:
        return None
    return parsed
def _is_trusted_proxy_ip(ip_obj) -> bool:
    """Return True if ``ip_obj`` lies inside one of ``_TRUSTED_PROXY_NETWORKS``.

    ``None`` is never trusted. Networks of the other IP version are skipped,
    and any error while testing one network just moves on to the next.
    """
    if ip_obj is None:
        return False
    for network in _TRUSTED_PROXY_NETWORKS:
        try:
            same_family = ip_obj.version == network.version
            if same_family and ip_obj in network:
                return True
        except Exception:
            continue
    return False
def _extract_real_ip_from_forwarded_chain() -> str | None:
    """Pick the client IP out of the request's forwarding headers.

    The ``X-Forwarded-For`` chain is walked from right to left, skipping
    trusted proxies, and the first other address is returned. If every entry
    is a trusted proxy, the leftmost entry is returned. Without any valid
    ``X-Forwarded-For`` entry, ``X-Real-IP`` is used when it parses.

    Returns:
        The normalised IP text, or None when no header yields a valid address.
    """
    header_value = str(request.headers.get("X-Forwarded-For", "") or "")
    parsed_hops = (
        _parse_ip_address(str(part or "").strip())
        for part in header_value.split(",")
    )
    # Unparsable entries are dropped from the chain.
    hops = [hop for hop in parsed_hops if hop is not None]

    if hops:
        for hop in reversed(hops):
            if not _is_trusted_proxy_ip(hop):
                return str(hop)
        return str(hops[0])

    real_ip_header = str(request.headers.get("X-Real-IP", "") or "").strip()
    real_ip_obj = _parse_ip_address(real_ip_header)
    return None if real_ip_obj is None else str(real_ip_obj)
def get_rate_limit_ip() -> str:
    """Return the IP to key rate limiting / risk control on for this request.

    Forwarding headers are used only when ``_TRUST_PROXY_HEADERS`` is on and
    the direct peer is itself a trusted proxy; in every other case the
    peer address (``request.remote_addr``, or "" if missing) is returned.
    """
    remote_addr = request.remote_addr or ""
    if _TRUST_PROXY_HEADERS:
        peer = _parse_ip_address(remote_addr)
        # Only a request arriving from a trusted proxy may supply the client IP.
        if peer is not None and _is_trusted_proxy_ip(peer):
            client_ip = _extract_real_ip_from_forwarded_chain()
            if client_ip:
                return client_ip
    return remote_addr
def is_safe_outbound_url(url: str) -> bool:
    """Return True if ``url`` is an http(s) URL pointing only at public addresses.

    Used to reduce SSRF risk. A literal IP host is checked directly; a host
    name is resolved with ``socket.getaddrinfo`` and every returned address
    must pass. Anything that cannot be parsed or resolved is unsafe.

    NOTE(review): the check resolves the name once; a later request may
    resolve it again and get a different address (DNS rebinding), so callers
    needing strong guarantees should connect to the vetted address.

    Args:
        url: URL about to be requested.

    Returns:
        bool: False for non-http(s) schemes, missing hosts, resolution
        failures, or any private / loopback / link-local / multicast /
        reserved / unspecified / otherwise non-global address.
    """
    try:
        parsed = urlparse(str(url or "").strip())
    except Exception:
        return False
    if parsed.scheme not in ("http", "https"):
        return False
    host = parsed.hostname
    if not host:
        return False

    try:
        ips = [ipaddress.ip_address(host)]
    except ValueError:
        # Not an IP literal: resolve the name.
        try:
            infos = socket.getaddrinfo(host, None)
            ips = [ipaddress.ip_address(info[4][0]) for info in infos]
        except Exception:
            return False

    # Nothing resolved means nothing was vetted.
    if not ips:
        return False

    for ip in ips:
        if (
            ip.is_private
            or ip.is_loopback
            or ip.is_link_local
            or ip.is_multicast
            or ip.is_reserved
            or ip.is_unspecified
            # Catches ranges the flags above miss, e.g. the shared address
            # space 100.64.0.0/10, which is not "private" yet not global.
            or not ip.is_global
        ):
            return False
    return True
if __name__ == '__main__':
    # Manual smoke run: path safety.
    print("文件路径安全测试:")
    safe_result = is_safe_path('/tmp', 'test.txt')
    unsafe_result = is_safe_path('/tmp', '../etc/passwd')
    print(f" 安全路径: {safe_result}")
    print(f" 危险路径: {unsafe_result}")

    # File-name cleaning.
    cleaned = sanitize_filename('../../../etc/passwd')
    print(f"\n文件名清理: {cleaned}")

    # Input validators.
    print("\n输入验证测试:")
    print(f" 用户名: {validate_username('test_user')}")
    print(f" 密码: {validate_password('Test123456')}")
    print(f" 邮箱: {validate_email('test@example.com')}")

    # Rate limiter: five failures against a limit of three.
    print("\nIP限流测试:")
    limiter = IPRateLimiter(max_attempts=3, window_seconds=60)
    ip = '192.168.1.1'
    for attempt_no in range(1, 6):
        locked = limiter.record_attempt(ip, success=False)
        remaining = limiter.get_remaining_attempts(ip)
        print(f" 尝试 {attempt_no}: 剩余次数={remaining}, 是否锁定={locked}")
    print(f" IP被锁定: {limiter.is_locked(ip)}")