180 lines
5.5 KiB
Python
180 lines
5.5 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
from __future__ import annotations
|
||
|
||
import db_pool
|
||
from crypto_utils import decrypt_password, encrypt_password
|
||
from db.utils import get_cst_now_str
|
||
|
||
|
||
def create_account(user_id, account_id, username, password, remember=True, remark=""):
    """Create an account row, storing the password encrypted.

    Args:
        user_id: Owner of the account (written to ``user_id``).
        account_id: Value written to the ``id`` column.
        username: Account login name.
        password: Plain-text password; passed through ``encrypt_password``
            before being stored.
        remember: Stored as 1 when truthy, otherwise 0.
        remark: Free-form note stored with the account.

    Returns:
        ``cursor.lastrowid`` of the INSERT. NOTE(review): this is SQLite's
        rowid, which equals ``account_id`` only if ``id`` is an INTEGER
        PRIMARY KEY; confirm what callers expect.
    """
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        # Timestamp first, then encryption, same evaluation order as before.
        created_at = get_cst_now_str()
        stored_password = encrypt_password(password)
        remember_flag = int(bool(remember))
        row_values = (account_id, user_id, username, stored_password, remember_flag, remark, created_at)
        cur.execute(
            """
            INSERT INTO accounts (id, user_id, username, password, remember, remark, created_at)
            VALUES (?, ?, ?, ?, ?, ?, ?)
            """,
            row_values,
        )
        conn.commit()
        return cur.lastrowid
|
||
|
||
|
||
def get_user_accounts(user_id):
    """Return all accounts belonging to ``user_id``, passwords decrypted.

    Rows are ordered newest first (``created_at DESC``). Each row becomes a
    ``dict`` whose ``password`` value is replaced by ``decrypt_password``'s
    output; ``""`` is passed only if the row has no ``password`` key.

    Returns:
        A list of dicts (possibly empty).
    """
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute("SELECT * FROM accounts WHERE user_id = ? ORDER BY created_at DESC", (user_id,))
        records = [dict(raw) for raw in cur.fetchall()]
        for record in records:
            record["password"] = decrypt_password(record.get("password", ""))
        return records
|
||
|
||
|
||
def get_account(account_id):
    """Fetch a single account by id, password decrypted.

    Returns:
        A ``dict`` of the row with its ``password`` value passed through
        ``decrypt_password``, or ``None`` when no row has that id.
    """
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute("SELECT * FROM accounts WHERE id = ?", (account_id,))
        found = cur.fetchone()
        if not found:
            return None
        record = dict(found)
        record["password"] = decrypt_password(record.get("password", ""))
        return record
|
||
|
||
|
||
def update_account_remark(account_id, remark):
    """Overwrite the ``remark`` column of one account.

    Returns:
        ``True`` if the UPDATE affected at least one row, otherwise ``False``.
    """
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        params = (remark, account_id)
        cur.execute("UPDATE accounts SET remark = ? WHERE id = ?", params)
        conn.commit()
        changed = cur.rowcount
        return changed > 0
|
||
|
||
|
||
def delete_account(account_id):
    """Delete one account by id.

    Returns:
        ``True`` if a row was deleted, otherwise ``False``.
    """
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute("DELETE FROM accounts WHERE id = ?", (account_id,))
        conn.commit()
        removed = cur.rowcount
        return removed > 0
|
||
|
||
|
||
def increment_account_login_fail(account_id, error_message, *, max_fail_count=3):
    """Record one failed login and suspend the account at the threshold.

    The failure counter is incremented inside a single ``UPDATE`` statement
    (``COALESCE(login_fail_count, 0) + 1``). The previous code read the count
    and wrote it back, so two concurrent failures could both read the same old
    value and lose an increment. The same statement stores ``error_message``
    in ``last_login_error``. It also sets ``status = 'suspended'`` once the
    new count reaches ``max_fail_count``; otherwise ``status`` is unchanged.

    Args:
        account_id: ``id`` of the account row.
        error_message: Value stored in ``last_login_error``.
        max_fail_count: Failure count at which the account is suspended.
            Defaults to 3, the previously hard-coded threshold.

    Returns:
        ``True`` if the account's failure count is now at or above
        ``max_fail_count``. ``False`` otherwise, including when no account has
        ``account_id``.
    """
    with db_pool.get_db() as conn:
        cursor = conn.cursor()
        # In SQLite every SET expression sees the row's *old* values, so the
        # CASE compares exactly the incremented count that is being written.
        cursor.execute(
            """
            UPDATE accounts
            SET login_fail_count = COALESCE(login_fail_count, 0) + 1,
                last_login_error = ?,
                status = CASE
                    WHEN COALESCE(login_fail_count, 0) + 1 >= ? THEN 'suspended'
                    ELSE status
                END
            WHERE id = ?
            """,
            (error_message, max_fail_count, account_id),
        )
        if cursor.rowcount == 0:
            # Unknown account. Still end the (possibly implicit) transaction so
            # the connection is not handed back with it open.
            conn.commit()
            return False

        # Read back the value written above. This happens before commit, so with
        # the sqlite3 module's default transaction handling it is this call's
        # own increment.
        cursor.execute("SELECT login_fail_count FROM accounts WHERE id = ?", (account_id,))
        row = cursor.fetchone()
        conn.commit()
        if row is None:
            return False
        return (row["login_fail_count"] or 0) >= max_fail_count
|
||
|
||
|
||
def reset_account_login_status(account_id):
    """Reset an account's login status; call after its password is changed.

    Sets ``login_fail_count`` to 0, ``last_login_error`` to NULL and
    ``status`` to ``'active'``.

    Returns:
        ``True`` if a row with ``account_id`` was updated, otherwise ``False``.
    """
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute(
            """
            UPDATE accounts
            SET login_fail_count = 0,
                last_login_error = NULL,
                status = 'active'
            WHERE id = ?
            """,
            (account_id,),
        )
        conn.commit()
        updated = cur.rowcount
        return updated > 0
|
||
|
||
|
||
def get_account_status(account_id):
    """Return the status-related columns of one account.

    Returns:
        The row from ``cursor.fetchone()`` with columns ``status``,
        ``login_fail_count`` and ``last_login_error``, or ``None`` if no
        account has ``account_id``.
    """
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute(
            """
            SELECT status, login_fail_count, last_login_error
            FROM accounts
            WHERE id = ?
            """,
            (account_id,),
        )
        status_row = cur.fetchone()
        return status_row
|
||
|
||
|
||
def get_account_status_batch(account_ids):
    """Fetch status columns for many accounts in as few queries as possible.

    Falsy ids (and a ``None`` argument) are skipped, and the remaining ids are
    converted to ``str``. Ids are queried in slices of 900 so that each
    statement stays under SQLite's bound-parameter limit.

    Returns:
        A dict mapping ``str(id)`` to a dict with ``status``,
        ``login_fail_count`` and ``last_login_error``. Ids with no matching row
        are absent. Returns ``{}`` when there are no usable ids.
    """
    ids = [str(candidate) for candidate in (account_ids or []) if candidate]
    if not ids:
        return {}

    slice_len = 900  # stay below SQLite's bound-parameter limit
    statuses = {}
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        offset = 0
        while offset < len(ids):
            part = ids[offset : offset + slice_len]
            marks = ",".join(["?"] * len(part))
            cur.execute(
                f"""
                SELECT id, status, login_fail_count, last_login_error
                FROM accounts
                WHERE id IN ({marks})
                """,
                part,
            )
            for fetched in cur.fetchall():
                info = dict(fetched)
                key = str(info.pop("id", ""))
                if key:
                    statuses[key] = info
            offset += slice_len

    return statuses
|
||
|
||
|
||
def delete_user_accounts(user_id):
    """Delete every account owned by ``user_id``.

    Returns:
        ``cursor.rowcount`` after the DELETE, i.e. the number of rows removed.
    """
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute("DELETE FROM accounts WHERE user_id = ?", (user_id,))
        conn.commit()
        deleted_count = cur.rowcount
        return deleted_count
|