Files
zsglpt/app_security.py
Yu Yon 0fd7137cea Initial commit: 知识管理平台
主要功能:
- 多用户管理系统
- 浏览器自动化(Playwright)
- 任务编排和执行
- Docker容器化部署
- 数据持久化和日志管理

技术栈:
- Flask 3.0.0
- Playwright 1.40.0
- SQLite with connection pooling
- Docker + Docker Compose

部署说明详见README.md
2025-11-16 19:03:07 +08:00

436 lines
11 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
安全工具模块
提供各种安全相关的功能
"""
import os
import re
import time
import hashlib
import secrets
from pathlib import Path
from typing import Optional
from functools import wraps
from flask import request, jsonify, session
from collections import defaultdict
import threading
# ==================== 文件路径安全 ====================
def is_safe_path(basedir, path, follow_symlinks=True):
    """
    Check that ``path`` stays inside ``basedir`` (path-traversal guard).

    Args:
        basedir: Directory that access must be confined to.
        path: Untrusted, relative path to validate.
        follow_symlinks: When True, both ``basedir`` and the joined path are
            resolved with ``os.path.realpath``; otherwise with
            ``os.path.abspath``.

    Returns:
        bool: True when the resolved target is ``basedir`` itself or lies
        beneath it, False otherwise.
    """
    # Cheap textual rejection first: any ".." substring (this deliberately
    # also rejects names such as "a..b") and anything that looks absolute.
    if '..' in path or path.startswith(('/', '\\')):
        return False

    # Resolve BOTH sides the same way. Resolving only the target made every
    # file look unsafe whenever basedir itself was reached through a symlink.
    resolve = os.path.realpath if follow_symlinks else os.path.abspath
    base = resolve(basedir)
    target = resolve(os.path.join(basedir, path))

    # Compare whole path components instead of raw string prefixes, so that
    # "/srv/data_secret" is not mistaken for a child of "/srv/data".
    try:
        return os.path.commonpath([base, target]) == base
    except ValueError:
        # Raised e.g. for paths on different drives: not inside basedir.
        return False
def sanitize_filename(filename):
    """
    Reduce a file name to a conservative, filesystem-safe form.

    Path separators and every character outside ``[a-zA-Z0-9._-]`` are
    replaced with ``_``. Names made up only of dots (``.``, ``..``) are
    replaced with underscores of the same length, and the result is capped
    at 255 characters, keeping the extension when it fits.

    Args:
        filename: Original (untrusted) file name.

    Returns:
        str: The cleaned name. An empty input yields an empty string.
    """
    max_length = 255

    # Neutralise path separators, then whitelist the remaining characters.
    filename = filename.replace('/', '_').replace('\\', '_')
    filename = re.sub(r'[^a-zA-Z0-9._-]', '_', filename)

    # "." and ".." survive the whitelist but are directory references, not
    # usable file names.
    if filename and not filename.strip('.'):
        filename = '_' * len(filename)

    if len(filename) > max_length:
        name, ext = os.path.splitext(filename)
        if len(ext) <= max_length:
            filename = name[:max_length - len(ext)] + ext
        else:
            # The "extension" alone exceeds the cap; a negative slice on the
            # stem would leave the result too long, so cut the whole name.
            filename = filename[:max_length]
    return filename
# ==================== IP限流和黑名单 ====================
class IPRateLimiter:
    """Per-IP attempt tracker that locks an address after too many failures."""

    def __init__(self, max_attempts=10, window_seconds=3600, lock_duration=3600):
        """
        Set up an empty limiter.

        Args:
            max_attempts: Failed attempts tolerated inside one window.
            window_seconds: Length of the sliding window, in seconds.
            lock_duration: How long a triggered lock lasts, in seconds.
        """
        self.max_attempts = max_attempts
        self.window_seconds = window_seconds
        self.lock_duration = lock_duration
        # Attempt history: {ip: [(timestamp, success), ...]}
        self._attempts = defaultdict(list)
        # Active locks: {ip: timestamp at which the lock expires}
        self._locked = {}
        # Guards both dictionaries above.
        self._lock = threading.Lock()

    def _recent(self, ip_address, now):
        """Return the attempts of ``ip_address`` still inside the window.

        The caller must already hold ``self._lock``.
        """
        cutoff = now - self.window_seconds
        return [entry for entry in self._attempts[ip_address] if entry[0] > cutoff]

    def is_locked(self, ip_address):
        """
        Report whether ``ip_address`` is currently locked out.

        An expired lock is removed as a side effect.

        Args:
            ip_address: Address to look up.

        Returns:
            bool: True while the lock is still in force.
        """
        with self._lock:
            if ip_address not in self._locked:
                return False
            if time.time() < self._locked[ip_address]:
                return True
            # The lock has run out; forget it.
            del self._locked[ip_address]
            return False

    def record_attempt(self, ip_address, success=True):
        """
        Store one attempt and lock the address if it has failed too often.

        Args:
            ip_address: Address that made the attempt.
            success: Whether the attempt succeeded.

        Returns:
            bool: True when the failure count inside the window has reached
            ``max_attempts`` and the address was (re)locked.
        """
        with self._lock:
            now = time.time()
            history = self._recent(ip_address, now)
            history.append((now, success))
            self._attempts[ip_address] = history

            failures = sum(1 for _, ok in history if not ok)
            if failures < self.max_attempts:
                return False
            self._locked[ip_address] = now + self.lock_duration
            return True

    def get_remaining_attempts(self, ip_address):
        """
        Return how many more failures are allowed inside the current window.

        Args:
            ip_address: Address to look up.

        Returns:
            int: Remaining failures, never below zero.
        """
        with self._lock:
            history = self._recent(ip_address, time.time())
            self._attempts[ip_address] = history
            failures = len([1 for _, ok in history if not ok])
            return max(0, self.max_attempts - failures)

    def cleanup(self):
        """Drop attempts that left the window and locks that have expired."""
        with self._lock:
            now = time.time()
            for ip in list(self._attempts):
                kept = self._recent(ip, now)
                if kept:
                    self._attempts[ip] = kept
                else:
                    del self._attempts[ip]

            expired = [ip for ip, until in self._locked.items() if now >= until]
            for ip in expired:
                del self._locked[ip]
# Module-wide limiter instance, constructed with the class defaults.
ip_rate_limiter = IPRateLimiter()
def require_ip_not_locked(f):
    """
    Decorator that refuses a request whose remote address is locked.

    A locked caller receives a JSON body with an ``error`` message and the
    ``locked_until`` timestamp together with HTTP status 429; any other
    caller is passed straight through to ``f``.
    """
    @wraps(f)
    def guarded(*args, **kwargs):
        client_ip = request.remote_addr
        if not ip_rate_limiter.is_locked(client_ip):
            return f(*args, **kwargs)

        payload = {
            "error": "由于多次失败尝试您的IP已被临时锁定",
            "locked_until": ip_rate_limiter._locked.get(client_ip, 0)
        }
        return jsonify(payload), 429

    return guarded
# ==================== 输入验证 ====================
def validate_username(username):
    """
    Validate the format of a user name.

    A valid name is 3 to 50 characters long and consists only of ASCII
    letters, ASCII digits, underscores and CJK characters in the range
    U+4E00..U+9FA5.

    Args:
        username: Candidate user name.

    Returns:
        tuple: ``(is_valid, error_message)``; the message is None when the
        name is valid.
    """
    if not username:
        return False, "用户名不能为空"
    if len(username) < 3:
        return False, "用户名长度不能少于3个字符"
    if len(username) > 50:
        return False, "用户名长度不能超过50个字符"

    # fullmatch, because "$" would also accept a trailing newline.
    # re.ASCII keeps \w to [a-zA-Z0-9_]; without it \w matches letters and
    # digits of every script, which contradicts the message below and lets
    # look-alike (homoglyph) names through. The explicit CJK range is
    # unaffected by the flag.
    if not re.fullmatch(r'[\w\u4e00-\u9fa5]+', username, re.ASCII):
        return False, "用户名只能包含字母、数字、下划线和中文字符"
    return True, None
def validate_password(password):
    """
    Validate a password against the length policy.

    Only the length (6 to 128 characters) is enforced; no character-class
    complexity rule (upper case / lower case / digit) is applied.

    Args:
        password: Candidate password.

    Returns:
        tuple: ``(is_valid, error_message)``; the message is None when the
        password is acceptable.
    """
    if not password:
        error = "密码不能为空"
    elif len(password) < 6:
        error = "密码长度不能少于6个字符"
    elif len(password) > 128:
        error = "密码长度不能超过128个字符"
    else:
        error = None
    return error is None, error
def validate_email(email):
    """
    Validate the format of an e-mail address.

    The address is optional: an empty value is accepted as valid.

    Args:
        email: Candidate e-mail address.

    Returns:
        tuple: ``(is_valid, error_message)``; the message is None when the
        address is acceptable.
    """
    if not email:
        return True, None  # the address is optional

    # Deliberately simple shape check: local@domain.tld
    pattern = r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}'
    # fullmatch instead of match + "$": "$" also matches just before a
    # trailing newline, which let "user@example.com\n" through.
    if not re.fullmatch(pattern, email):
        return False, "邮箱格式不正确"
    if len(email) > 255:
        return False, "邮箱长度不能超过255个字符"
    return True, None
# ==================== 会话安全 ====================
def generate_session_token():
    """Return a fresh URL-safe session token built from 32 random bytes."""
    token_bytes = 32
    token = secrets.token_urlsafe(token_bytes)
    return token
def hash_token(token):
    """Return the hexadecimal SHA-256 digest of ``token`` (for storage)."""
    digest = hashlib.sha256()
    digest.update(token.encode())
    return digest.hexdigest()
# ==================== CSRF保护 ====================
def generate_csrf_token():
    """Return the session's CSRF token, creating and storing one if absent."""
    if 'csrf_token' in session:
        return session['csrf_token']
    # First request in this session: mint a token and remember it.
    session['csrf_token'] = secrets.token_urlsafe(32)
    return session['csrf_token']
def validate_csrf_token(token):
    """
    Check a submitted CSRF token against the one stored in the session.

    Args:
        token: Token supplied by the client.

    Returns:
        bool: True only when the session holds a token and the submitted
        value matches it.
    """
    expected = session.get('csrf_token')
    # Without this guard a request carrying no token was accepted whenever
    # the session had no token either (None == None).
    if not isinstance(token, str) or not isinstance(expected, str) or not expected:
        return False
    # Constant-time comparison; compare bytes because compare_digest rejects
    # non-ASCII str arguments, and the submitted token is untrusted.
    return secrets.compare_digest(token.encode('utf-8'), expected.encode('utf-8'))
# ==================== 内容安全 ====================
def escape_html(text):
    """
    Escape HTML-special characters in ``text`` to prevent XSS.

    The characters ``& < > " ' /`` are replaced by entities; a falsy
    ``text`` (empty string, None) is returned unchanged.
    """
    if not text:
        return text

    entities = {
        '&': '&amp;',
        '<': '&lt;',
        '>': '&gt;',
        '"': '&quot;',
        "'": '&#x27;',
        '/': '&#x2F;',
    }
    # One pass over the input: each character maps to its entity or itself.
    return ''.join(entities.get(char, char) for char in text)
def sanitize_sql_like_pattern(pattern):
    """
    Backslash-escape the characters that are special in a SQL LIKE pattern.

    A backslash is placed in front of every backslash, ``%`` and ``_``.

    NOTE(review): presumably the query using the result must declare the
    backslash as its ESCAPE character - confirm against the callers.

    Args:
        pattern: Raw LIKE pattern text.

    Returns:
        str: The escaped pattern.
    """
    special = ('\\', '%', '_')
    pieces = []
    for char in pattern:
        if char in special:
            pieces.append('\\')
        pieces.append(char)
    return ''.join(pieces)
# ==================== 安全配置检查 ====================
def check_security_config():
    """
    Inspect the current Flask app configuration for security weaknesses.

    Checks the SECRET_KEY (presence and length), the DEBUG flag and the
    session-cookie flags HTTPONLY and SECURE.

    Returns:
        list: Human-readable descriptions of the problems found; empty when
        none were found.
    """
    from flask import current_app

    config = current_app.config
    key = config.get('SECRET_KEY')

    # (problem detected?, message) pairs, kept in reporting order.
    findings = (
        (not key or len(key) < 32, "SECRET_KEY过短或未设置"),
        (bool(config.get('DEBUG')), "DEBUG模式在生产环境应该关闭"),
        (not config.get('SESSION_COOKIE_HTTPONLY'), "SESSION_COOKIE_HTTPONLY应该设置为True"),
        (not config.get('SESSION_COOKIE_SECURE'), "生产环境应该启用SESSION_COOKIE_SECURE需要HTTPS"),
    )
    return [message for failed, message in findings if failed]
# ==================== 辅助函数 ====================
def get_client_ip():
    """
    Determine the client's IP address for the current request.

    Preference order: the first entry of ``X-Forwarded-For``, then
    ``X-Real-IP``, then ``request.remote_addr``. A header that is present
    but blank is skipped instead of being returned as an empty string.

    NOTE(review): both headers are supplied by the client unless a trusted
    reverse proxy overwrites them, so the result can be spoofed - confirm
    the deployment before using it for security decisions.

    Returns:
        str: The IP address.
    """
    # Take the first hop of X-Forwarded-For; a blank value falls through.
    forwarded = request.headers.get('X-Forwarded-For', '')
    first_hop = forwarded.split(',')[0].strip()
    if first_hop:
        return first_hop

    real_ip = request.headers.get('X-Real-IP', '').strip()
    if real_ip:
        return real_ip

    return request.remote_addr
def _run_self_checks():
    """Print a small manual smoke test of the helpers in this module."""
    # Path safety
    print("文件路径安全测试:")
    safe_result = is_safe_path('/tmp', 'test.txt')
    print(f" 安全路径: {safe_result}")
    unsafe_result = is_safe_path('/tmp', '../etc/passwd')
    print(f" 危险路径: {unsafe_result}")

    # File-name cleaning
    cleaned = sanitize_filename('../../../etc/passwd')
    print(f"\n文件名清理: {cleaned}")

    # Input validation
    print("\n输入验证测试:")
    print(f" 用户名: {validate_username('test_user')}")
    print(f" 密码: {validate_password('Test123456')}")
    print(f" 邮箱: {validate_email('test@example.com')}")

    # IP rate limiting: five failures against a limit of three
    print("\nIP限流测试:")
    limiter = IPRateLimiter(max_attempts=3, window_seconds=60)
    ip = '192.168.1.1'
    for attempt_no in range(1, 6):
        locked = limiter.record_attempt(ip, success=False)
        remaining = limiter.get_remaining_attempts(ip)
        print(f" 尝试 {attempt_no}: 剩余次数={remaining}, 是否锁定={locked}")
    print(f" IP被锁定: {limiter.is_locked(ip)}")


if __name__ == '__main__':
    _run_self_checks()