#!/usr/bin/env python3 # -*- coding: utf-8 -*- from __future__ import annotations import json import time from typing import Any from flask import Request from webauthn import ( generate_authentication_options, generate_registration_options, verify_authentication_response, verify_registration_response, ) from webauthn.helpers import ( base64url_to_bytes, bytes_to_base64url, options_to_json, parse_authentication_credential_json, parse_registration_credential_json, ) from webauthn.helpers.structs import ( AuthenticatorSelectionCriteria, PublicKeyCredentialHint, PublicKeyCredentialDescriptor, ResidentKeyRequirement, UserVerificationRequirement, ) MAX_PASSKEYS_PER_OWNER = 3 CHALLENGE_TTL_SECONDS = 300 DEVICE_NAME_MAX_LENGTH = 40 def normalize_device_name(value: Any) -> str: text = str(value or "").strip() if not text: return "未命名设备" if len(text) > DEVICE_NAME_MAX_LENGTH: text = text[:DEVICE_NAME_MAX_LENGTH] return text def is_challenge_valid(created_at: Any, *, now_ts: float | None = None) -> bool: try: created_ts = float(created_at) except Exception: return False if now_ts is None: now_ts = time.time() return created_ts > 0 and (now_ts - created_ts) <= CHALLENGE_TTL_SECONDS def get_rp_id(request: Request) -> str: forwarded_host = str(request.headers.get("X-Forwarded-Host", "") or "").split(",", 1)[0].strip() host = forwarded_host or str(request.host or "").strip() host = host.split(":", 1)[0].strip().lower() if not host: raise ValueError("无法确定 RP ID") return host def get_expected_origins(request: Request) -> list[str]: host = str(request.host or "").strip() if not host: raise ValueError("无法确定 Origin") forwarded_proto = str(request.headers.get("X-Forwarded-Proto", "") or "").split(",", 1)[0].strip().lower() scheme = forwarded_proto if forwarded_proto in {"http", "https"} else str(request.scheme or "https").lower() origin = f"{scheme}://{host}" return [origin] def encode_credential_id(raw_credential_id: bytes) -> str: return bytes_to_base64url(raw_credential_id) def decode_credential_id(credential_id: str) -> bytes: return base64url_to_bytes(str(credential_id or "")) def _to_public_key_options_json(options) -> dict[str, Any]: return json.loads(options_to_json(options)) def make_registration_options( *, rp_id: str, rp_name: str, user_name: str, user_display_name: str, user_id_bytes: bytes, exclude_credential_ids: list[str], ) -> dict[str, Any]: exclude_credentials = [ PublicKeyCredentialDescriptor(id=decode_credential_id(credential_id)) for credential_id in (exclude_credential_ids or []) if credential_id ] authenticator_selection = AuthenticatorSelectionCriteria( resident_key=ResidentKeyRequirement.PREFERRED, require_resident_key=False, user_verification=UserVerificationRequirement.PREFERRED, ) options = generate_registration_options( rp_id=rp_id, rp_name=rp_name, user_name=user_name, user_display_name=user_display_name, user_id=user_id_bytes, timeout=120000, authenticator_selection=authenticator_selection, exclude_credentials=exclude_credentials, hints=[ PublicKeyCredentialHint.CLIENT_DEVICE, PublicKeyCredentialHint.HYBRID, ], ) return _to_public_key_options_json(options) def make_authentication_options( *, rp_id: str, allow_credential_ids: list[str] | None = None, ) -> dict[str, Any]: allow_credentials = [ PublicKeyCredentialDescriptor(id=decode_credential_id(credential_id)) for credential_id in (allow_credential_ids or []) if credential_id ] allow_credentials_value = allow_credentials if allow_credentials else None options = generate_authentication_options( rp_id=rp_id, timeout=120000, allow_credentials=allow_credentials_value, user_verification=UserVerificationRequirement.PREFERRED, ) return _to_public_key_options_json(options) def verify_registration( *, credential: dict[str, Any], expected_challenge: str, expected_rp_id: str, expected_origins: list[str], ): parsed = parse_registration_credential_json(credential) return verify_registration_response( credential=parsed, expected_challenge=base64url_to_bytes(expected_challenge), expected_rp_id=expected_rp_id, expected_origin=expected_origins, require_user_verification=True, ) def verify_authentication( *, credential: dict[str, Any], expected_challenge: str, expected_rp_id: str, expected_origins: list[str], credential_public_key: str, credential_current_sign_count: int, ): parsed = parse_authentication_credential_json(credential) verified = verify_authentication_response( credential=parsed, expected_challenge=base64url_to_bytes(expected_challenge), expected_rp_id=expected_rp_id, expected_origin=expected_origins, credential_public_key=base64url_to_bytes(credential_public_key), credential_current_sign_count=int(credential_current_sign_count or 0), require_user_verification=True, ) return parsed, verified def get_credential_transports(credential: dict[str, Any]) -> str: response = credential.get("response") if isinstance(credential, dict) else None transports = response.get("transports") if isinstance(response, dict) else None if isinstance(transports, list): normalized = sorted({str(item).strip() for item in transports if str(item).strip()}) return ",".join(normalized) return ""