主要功能: - 多用户管理系统 - 浏览器自动化(Playwright) - 任务编排和执行 - Docker容器化部署 - 数据持久化和日志管理 技术栈: - Flask 3.0.0 - Playwright 1.40.0 - SQLite with connection pooling - Docker + Docker Compose 部署说明详见README.md
436 lines
11 KiB
Python
Executable File
436 lines
11 KiB
Python
Executable File
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
安全工具模块
|
||
提供各种安全相关的功能
|
||
"""
|
||
|
||
import os
|
||
import re
|
||
import time
|
||
import hashlib
|
||
import secrets
|
||
from pathlib import Path
|
||
from typing import Optional
|
||
from functools import wraps
|
||
from flask import request, jsonify, session
|
||
from collections import defaultdict
|
||
import threading
|
||
|
||
|
||
# ==================== 文件路径安全 ====================
|
||
|
||
def is_safe_path(basedir, path, follow_symlinks=True):
    """
    Check that ``path`` stays inside ``basedir`` (path-traversal guard).

    Args:
        basedir: Base directory that ``path`` must not escape.
        path: Relative path supplied by the caller (possibly untrusted).
        follow_symlinks: If True, resolve symlinks (``os.path.realpath``) on
            both the base directory and the joined path before comparing, so
            a symlink inside ``basedir`` that points elsewhere is rejected.
            If False, normalise lexically only (``os.path.abspath``).

    Returns:
        bool: True if the resolved path is ``basedir`` itself or lies below it.
    """
    basedir = os.fspath(basedir)
    path = os.fspath(path)

    # Cheap early rejection of obvious traversal / absolute-path attempts.
    if '..' in path or path.startswith('/') or path.startswith('\\'):
        return False

    resolve = os.path.realpath if follow_symlinks else os.path.abspath
    # Resolve the base the same way as the candidate; otherwise a symlinked
    # basedir (e.g. /tmp -> /private/tmp) makes every path look unsafe.
    base = resolve(basedir)
    target = resolve(os.path.join(basedir, path))

    # Compare whole path components: a plain string-prefix test would treat
    # "/srv/data_evil" as being inside "/srv/data".
    try:
        return os.path.commonpath([base, target]) == base
    except ValueError:
        # Raised for paths on different drives (Windows) or mixed
        # absolute/relative inputs - never "inside" the base.
        return False
|
||
|
||
|
||
def sanitize_filename(filename):
    """
    Reduce ``filename`` to a safe single path component.

    Every character outside ``[A-Za-z0-9._-]`` (including ``/`` and ``\\``)
    becomes ``_``. The result is capped at 255 characters, keeping the
    extension when it fits. Names consisting only of dots (``.``, ``..``)
    are neutralised so they cannot refer to the current/parent directory.

    Args:
        filename: Original file name (str).

    Returns:
        str: The sanitised file name. An empty input yields an empty string.
    """
    # Path separators are outside the allowed set, so this removes them too.
    filename = re.sub(r'[^a-zA-Z0-9._-]', '_', filename)

    # Limit length
    if len(filename) > 255:
        name, ext = os.path.splitext(filename)
        if len(ext) <= 255:
            filename = name[:255 - len(ext)] + ext
        else:
            # The "extension" alone is too long (e.g. "a." + 300 chars).
            # Slicing ``name`` with a negative bound would not shorten the
            # result, so truncate the whole string instead.
            filename = filename[:255]

    # "." and ".." pass the character filter but are directory references.
    if filename and not filename.strip('.'):
        filename = '_' * len(filename)

    return filename
|
||
|
||
|
||
# ==================== IP限流和黑名单 ====================
|
||
|
||
class IPRateLimiter:
    """
    Per-IP failed-attempt limiter (sliding window + temporary lock).

    Each ``record_attempt`` call stores ``(timestamp, success)`` for an IP.
    An IP is locked for ``lock_duration`` seconds once its failed attempts
    within the last ``window_seconds`` reach ``max_attempts``. Every public
    method holds ``self._lock`` while touching shared state, so one instance
    can be shared between threads.
    """

    def __init__(self, max_attempts=10, window_seconds=3600, lock_duration=3600):
        """
        Initialise the limiter.

        Args:
            max_attempts: Failed attempts allowed inside the window before locking.
            window_seconds: Length of the sliding window, in seconds.
            lock_duration: How long a locked IP stays locked, in seconds.
        """
        self.max_attempts = max_attempts
        self.window_seconds = window_seconds
        self.lock_duration = lock_duration

        # ip -> [(timestamp, success), ...] in insertion (time) order
        self._attempts = defaultdict(list)
        # ip -> unix timestamp until which the IP is locked
        self._locked = {}
        self._lock = threading.Lock()

    def _prune(self, ip_address, now):
        """
        Drop attempts older than the window; caller must hold ``self._lock``.

        Returns the surviving records for ``ip_address``. An IP with no
        records left is removed from ``_attempts``. Reading through
        ``.get`` avoids the defaultdict inserting an empty list just because
        an unknown IP was queried.
        """
        cutoff_time = now - self.window_seconds
        recent = [
            (ts, succ)
            for ts, succ in self._attempts.get(ip_address, ())
            if ts > cutoff_time
        ]
        if recent:
            self._attempts[ip_address] = recent
        else:
            self._attempts.pop(ip_address, None)
        return recent

    def is_locked(self, ip_address):
        """
        Report whether an IP is currently locked.

        An expired lock is removed as a side effect.

        Args:
            ip_address: IP address.

        Returns:
            bool: True while the lock is still active.
        """
        with self._lock:
            lock_until = self._locked.get(ip_address)
            if lock_until is None:
                return False
            if time.time() < lock_until:
                return True
            # Lock has expired; forget it.
            del self._locked[ip_address]
            return False

    def record_attempt(self, ip_address, success=True):
        """
        Record one attempt and lock the IP if it has failed too often.

        Args:
            ip_address: IP address.
            success: Whether the attempt succeeded (only failures count).

        Returns:
            bool: True if the failure count in the window has reached
            ``max_attempts``. The lock is then (re)set to expire
            ``lock_duration`` seconds from now.
        """
        with self._lock:
            now = time.time()
            recent = self._prune(ip_address, now)
            recent.append((now, success))
            self._attempts[ip_address] = recent

            failed_attempts = sum(1 for _, succ in recent if not succ)
            if failed_attempts >= self.max_attempts:
                self._locked[ip_address] = now + self.lock_duration
                return True
            return False

    def get_remaining_attempts(self, ip_address):
        """
        Return how many more failures are allowed before the IP is locked.

        Args:
            ip_address: IP address.

        Returns:
            int: ``max_attempts`` minus failures in the current window, floored at 0.
        """
        with self._lock:
            recent = self._prune(ip_address, time.time())
            failed_attempts = sum(1 for _, succ in recent if not succ)
            return max(0, self.max_attempts - failed_attempts)

    def cleanup(self):
        """Purge attempts outside the window and locks that have expired."""
        with self._lock:
            now = time.time()
            for ip in list(self._attempts):
                self._prune(ip, now)
            for ip, lock_until in list(self._locked.items()):
                if now >= lock_until:
                    del self._locked[ip]
|
||
|
||
|
||
# Module-level limiter instance consulted by require_ip_not_locked.
# Arguments spelled out: 10 failures within 3600 s -> locked for 3600 s.
ip_rate_limiter = IPRateLimiter(
    max_attempts=10,
    window_seconds=3600,
    lock_duration=3600,
)
|
||
|
||
|
||
def require_ip_not_locked(f):
    """
    Decorator: answer HTTP 429 instead of calling ``f`` while the caller's IP is locked.

    The lock state comes from ``ip_rate_limiter`` and is keyed on
    ``request.remote_addr``. NOTE(review): behind a reverse proxy this is the
    proxy's address, so all clients would share one lock - confirm deployment.
    """
    @wraps(f)
    def decorated_function(*args, **kwargs):
        client_ip = request.remote_addr
        if not ip_rate_limiter.is_locked(client_ip):
            return f(*args, **kwargs)

        payload = {
            "error": "由于多次失败尝试,您的IP已被临时锁定",
            "locked_until": ip_rate_limiter._locked.get(client_ip, 0),
        }
        return jsonify(payload), 429

    return decorated_function
|
||
|
||
|
||
# ==================== 输入验证 ====================
|
||
|
||
def validate_username(username):
    """
    Validate the format of a username.

    Rules:
        - it must be non-empty;
        - it must be 3-50 characters long;
        - it may contain only word characters (letters, digits and
          underscore; Unicode-aware in Python 3) or CJK ideographs in
          U+4E00..U+9FA5.

    Args:
        username: Username to check.

    Returns:
        tuple: ``(True, None)`` if valid, otherwise ``(False, error_message)``.
    """
    if not username:
        return False, "用户名不能为空"

    if len(username) < 3:
        return False, "用户名长度不能少于3个字符"

    if len(username) > 50:
        return False, "用户名长度不能超过50个字符"

    # fullmatch instead of match + "$": "$" also matches just before a
    # trailing newline, which let values like "admin\n" pass.
    if not re.fullmatch(r'[\w\u4e00-\u9fa5]+', username):
        return False, "用户名只能包含字母、数字、下划线和中文字符"

    return True, None
|
||
|
||
|
||
def validate_password(password, *, require_complexity=False):
    """
    Validate password length and, optionally, character-class complexity.

    Args:
        password: Password to check.
        require_complexity: When True, additionally require at least one
            uppercase ASCII letter, one lowercase ASCII letter and one digit.
            Defaults to False, which keeps the length-only checks.

    Returns:
        tuple: ``(True, None)`` if valid, otherwise ``(False, error_message)``.
    """
    if not password:
        return False, "密码不能为空"

    if len(password) < 6:
        return False, "密码长度不能少于6个字符"

    if len(password) > 128:
        return False, "密码长度不能超过128个字符"

    # Opt-in complexity policy (previously present only as commented-out code).
    if require_complexity:
        has_upper = re.search(r'[A-Z]', password) is not None
        has_lower = re.search(r'[a-z]', password) is not None
        has_digit = re.search(r'\d', password) is not None
        if not (has_upper and has_lower and has_digit):
            return False, "密码必须包含大写字母、小写字母和数字"

    return True, None
|
||
|
||
|
||
def validate_email(email):
    """
    Validate an (optional) e-mail address.

    Args:
        email: E-mail address; empty or None is accepted because the field
            is optional.

    Returns:
        tuple: ``(True, None)`` if valid or absent, otherwise
        ``(False, error_message)``.
    """
    if not email:
        return True, None  # e-mail is optional

    # Check the length first so an oversized value never reaches the regex.
    if len(email) > 255:
        return False, "邮箱长度不能超过255个字符"

    # Deliberately simple pattern. fullmatch (instead of match + "$") stops a
    # trailing newline such as "a@b.com\n" from being accepted.
    pattern = r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}'
    if not re.fullmatch(pattern, email):
        return False, "邮箱格式不正确"

    return True, None
|
||
|
||
|
||
# ==================== 会话安全 ====================
|
||
|
||
def generate_session_token():
    """Return a fresh URL-safe session token built from 32 random bytes (``secrets``)."""
    random_bytes = 32
    token = secrets.token_urlsafe(random_bytes)
    return token
|
||
|
||
|
||
def hash_token(token):
    """
    Return the hex SHA-256 digest of ``token`` (UTF-8 encoded).

    Intended for storage, so the raw token itself need not be persisted.
    """
    hasher = hashlib.sha256()
    hasher.update(token.encode('utf-8'))
    return hasher.hexdigest()
|
||
|
||
|
||
# ==================== CSRF保护 ====================
|
||
|
||
def generate_csrf_token():
    """
    Return the CSRF token stored in the Flask session.

    On first use, a new random token (``secrets.token_urlsafe(32)``) is
    created and stored under ``'csrf_token'``.
    """
    if 'csrf_token' in session:
        return session['csrf_token']

    session['csrf_token'] = secrets.token_urlsafe(32)
    return session['csrf_token']
|
||
|
||
|
||
def validate_csrf_token(token):
    """
    Check a submitted CSRF token against the one stored in the session.

    Args:
        token: Token submitted by the client (may be None when absent).

    Returns:
        bool: True only if the submitted and the stored token are both
        non-empty strings and are equal.
    """
    expected = session.get('csrf_token')

    # The plain ``token == session.get(...)`` returned True for
    # None == None, i.e. a request without a token passed whenever the
    # session had no token yet.
    if not isinstance(token, str) or not isinstance(expected, str):
        return False
    if not token or not expected:
        return False

    # Constant-time comparison so response timing does not reveal how many
    # leading characters match. Encode first: compare_digest rejects
    # non-ASCII str.
    return secrets.compare_digest(token.encode('utf-8'), expected.encode('utf-8'))
|
||
|
||
|
||
# ==================== 内容安全 ====================
|
||
|
||
def escape_html(text):
    """
    Escape HTML-special characters to mitigate XSS.

    Replaces ``& < > " ' /`` with their character references. Falsy input
    (``None``, ``""``) is returned unchanged.

    Args:
        text: String to escape.

    Returns:
        str: The escaped string (or the falsy input as-is).
    """
    if not text:
        return text

    # The previous table mapped every character to itself, so nothing was
    # escaped. str.translate works in a single pass: the "&" inside the
    # replacement strings is never re-escaped, so ordering does not matter.
    table = str.maketrans({
        '&': '&amp;',
        '<': '&lt;',
        '>': '&gt;',
        '"': '&quot;',
        "'": '&#x27;',
        '/': '&#x2F;',
    })
    return text.translate(table)
|
||
|
||
|
||
def sanitize_sql_like_pattern(pattern):
    """
    Escape SQL LIKE metacharacters by prefixing them with a backslash.

    ``%`` and ``_`` are escaped, as is the backslash itself, so that
    user text matches literally.

    NOTE(review): this only has an effect if the query declares the escape
    character, e.g. ``... LIKE ? ESCAPE '\\'`` - confirm at the call sites.

    Args:
        pattern: Raw text to embed in a LIKE pattern.

    Returns:
        str: The escaped text.
    """
    # One pass over the string; equivalent to replacing "\" first,
    # then "%", then "_".
    escape_map = {'\\': '\\\\', '%': '\\%', '_': '\\_'}
    return pattern.translate(str.maketrans(escape_map))
|
||
|
||
|
||
# ==================== 安全配置检查 ====================
|
||
|
||
def check_security_config():
    """
    Inspect the current Flask app's config for common security misconfigurations.

    Must be called inside an application context (it reads ``current_app``).

    Returns:
        list: Human-readable issue descriptions; empty when nothing was flagged.
    """
    from flask import current_app

    cfg = current_app.config
    found = []

    # SECRET_KEY must exist and be reasonably long.
    key = cfg.get('SECRET_KEY')
    if not key or len(key) < 32:
        found.append("SECRET_KEY过短或未设置")

    if cfg.get('DEBUG'):
        found.append("DEBUG模式在生产环境应该关闭")

    # Cookie flags that must be truthy.
    for flag, message in (
        ('SESSION_COOKIE_HTTPONLY', "SESSION_COOKIE_HTTPONLY应该设置为True"),
        ('SESSION_COOKIE_SECURE', "生产环境应该启用SESSION_COOKIE_SECURE(需要HTTPS)"),
    ):
        if not cfg.get(flag):
            found.append(message)

    return found
|
||
|
||
|
||
# ==================== 辅助函数 ====================
|
||
|
||
def get_client_ip():
    """
    Best-effort client IP for the current Flask request.

    Lookup order:
        1. the first entry of ``X-Forwarded-For``;
        2. ``X-Real-IP``;
        3. ``request.remote_addr``.

    NOTE(review): both headers are client-controlled unless a trusted reverse
    proxy overwrites them, so this value can be spoofed. Do not use it for
    rate limiting or access control without confirming the proxy setup.

    Returns:
        str: IP address.
    """
    forwarded = request.headers.get('X-Forwarded-For')
    if forwarded:
        return forwarded.split(',')[0].strip()

    real_ip = request.headers.get('X-Real-IP')
    if real_ip:
        return real_ip

    return request.remote_addr
|
||
|
||
|
||
if __name__ == '__main__':
    # --- path-safety helper ---
    print("文件路径安全测试:")
    print(f" 安全路径: {is_safe_path('/tmp', 'test.txt')}")
    print(f" 危险路径: {is_safe_path('/tmp', '../etc/passwd')}")

    # --- filename sanitising ---
    print(f"\n文件名清理: {sanitize_filename('../../../etc/passwd')}")

    # --- input validators ---
    print("\n输入验证测试:")
    for caption, outcome in (
        ("用户名", validate_username('test_user')),
        ("密码", validate_password('Test123456')),
        ("邮箱", validate_email('test@example.com')),
    ):
        print(f" {caption}: {outcome}")

    # --- rate limiter: 5 consecutive failures against a limit of 3 ---
    print("\nIP限流测试:")
    demo_limiter = IPRateLimiter(max_attempts=3, window_seconds=60)
    demo_ip = '192.168.1.1'

    for attempt_no in range(1, 6):
        was_locked = demo_limiter.record_attempt(demo_ip, success=False)
        print(f" 尝试 {attempt_no}: 剩余次数={demo_limiter.get_remaining_attempts(demo_ip)}, 是否锁定={was_locked}")

    print(f" IP被锁定: {demo_limiter.is_locked(demo_ip)}")
|