refactor: optimize structure, stability and runtime performance
This commit is contained in:
@@ -6,19 +6,51 @@ import db_pool
|
||||
from crypto_utils import decrypt_password, encrypt_password
|
||||
from db.utils import get_cst_now_str
|
||||
|
||||
_ACCOUNT_STATUS_QUERY_SQL = """
|
||||
SELECT status, login_fail_count, last_login_error
|
||||
FROM accounts
|
||||
WHERE id = ?
|
||||
"""
|
||||
|
||||
|
||||
def _decode_account_password(account_dict: dict) -> dict:
|
||||
account_dict["password"] = decrypt_password(account_dict.get("password", ""))
|
||||
return account_dict
|
||||
|
||||
|
||||
def _normalize_account_ids(account_ids) -> list[str]:
|
||||
normalized = []
|
||||
seen = set()
|
||||
for account_id in account_ids or []:
|
||||
if not account_id:
|
||||
continue
|
||||
account_key = str(account_id)
|
||||
if account_key in seen:
|
||||
continue
|
||||
seen.add(account_key)
|
||||
normalized.append(account_key)
|
||||
return normalized
|
||||
|
||||
|
||||
def create_account(user_id, account_id, username, password, remember=True, remark=""):
|
||||
"""创建账号(密码加密存储)"""
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cst_time = get_cst_now_str()
|
||||
encrypted_password = encrypt_password(password)
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO accounts (id, user_id, username, password, remember, remark, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(account_id, user_id, username, encrypted_password, 1 if remember else 0, remark, cst_time),
|
||||
(
|
||||
account_id,
|
||||
user_id,
|
||||
username,
|
||||
encrypted_password,
|
||||
1 if remember else 0,
|
||||
remark,
|
||||
get_cst_now_str(),
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
return cursor.lastrowid
|
||||
@@ -29,12 +61,7 @@ def get_user_accounts(user_id):
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT * FROM accounts WHERE user_id = ? ORDER BY created_at DESC", (user_id,))
|
||||
accounts = []
|
||||
for row in cursor.fetchall():
|
||||
account = dict(row)
|
||||
account["password"] = decrypt_password(account.get("password", ""))
|
||||
accounts.append(account)
|
||||
return accounts
|
||||
return [_decode_account_password(dict(row)) for row in cursor.fetchall()]
|
||||
|
||||
|
||||
def get_account(account_id):
|
||||
@@ -43,11 +70,9 @@ def get_account(account_id):
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT * FROM accounts WHERE id = ?", (account_id,))
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
account = dict(row)
|
||||
account["password"] = decrypt_password(account.get("password", ""))
|
||||
return account
|
||||
return None
|
||||
if not row:
|
||||
return None
|
||||
return _decode_account_password(dict(row))
|
||||
|
||||
|
||||
def update_account_remark(account_id, remark):
|
||||
@@ -78,33 +103,21 @@ def increment_account_login_fail(account_id, error_message):
|
||||
if not row:
|
||||
return False
|
||||
|
||||
fail_count = (row["login_fail_count"] or 0) + 1
|
||||
|
||||
if fail_count >= 3:
|
||||
cursor.execute(
|
||||
"""
|
||||
UPDATE accounts
|
||||
SET login_fail_count = ?,
|
||||
last_login_error = ?,
|
||||
status = 'suspended'
|
||||
WHERE id = ?
|
||||
""",
|
||||
(fail_count, error_message, account_id),
|
||||
)
|
||||
conn.commit()
|
||||
return True
|
||||
fail_count = int(row["login_fail_count"] or 0) + 1
|
||||
is_suspended = fail_count >= 3
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
UPDATE accounts
|
||||
SET login_fail_count = ?,
|
||||
last_login_error = ?
|
||||
last_login_error = ?,
|
||||
status = CASE WHEN ? = 1 THEN 'suspended' ELSE status END
|
||||
WHERE id = ?
|
||||
""",
|
||||
(fail_count, error_message, account_id),
|
||||
(fail_count, error_message, 1 if is_suspended else 0, account_id),
|
||||
)
|
||||
conn.commit()
|
||||
return False
|
||||
return is_suspended
|
||||
|
||||
|
||||
def reset_account_login_status(account_id):
|
||||
@@ -129,29 +142,22 @@ def get_account_status(account_id):
|
||||
"""获取账号状态信息"""
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT status, login_fail_count, last_login_error
|
||||
FROM accounts
|
||||
WHERE id = ?
|
||||
""",
|
||||
(account_id,),
|
||||
)
|
||||
cursor.execute(_ACCOUNT_STATUS_QUERY_SQL, (account_id,))
|
||||
return cursor.fetchone()
|
||||
|
||||
|
||||
def get_account_status_batch(account_ids):
|
||||
"""批量获取账号状态信息"""
|
||||
account_ids = [str(account_id) for account_id in (account_ids or []) if account_id]
|
||||
if not account_ids:
|
||||
normalized_ids = _normalize_account_ids(account_ids)
|
||||
if not normalized_ids:
|
||||
return {}
|
||||
|
||||
results = {}
|
||||
chunk_size = 900 # 避免触发 SQLite 绑定参数上限
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
for idx in range(0, len(account_ids), chunk_size):
|
||||
chunk = account_ids[idx : idx + chunk_size]
|
||||
for idx in range(0, len(normalized_ids), chunk_size):
|
||||
chunk = normalized_ids[idx : idx + chunk_size]
|
||||
placeholders = ",".join("?" for _ in chunk)
|
||||
cursor.execute(
|
||||
f"""
|
||||
|
||||
Reference in New Issue
Block a user