itcloud/backend/src/app/infra/security.py

256 lines
6.8 KiB
Python
Raw Normal View History

2025-12-30 15:36:41 +01:00
"""
Security utilities for authentication and authorization.

Design goals:
- Strong password hashing with Argon2id (no bcrypt 72-byte limit, no pre-hashing needed)
- Distinct token types (access/refresh) with strict validation
- Optional issuer/audience claims, enforced only when configured in settings
- Token IDs (jti) to support future revocation
- Settings read at call time via a dependency-friendly getter (no module-level binding)
"""
from __future__ import annotations
from dataclasses import dataclass
2025-12-30 14:51:56 +01:00
from datetime import datetime, timedelta, timezone
2025-12-30 15:36:41 +01:00
from typing import Any, Mapping, Optional
from uuid import uuid4
2025-12-30 13:35:19 +01:00
from jose import JWTError, jwt
from passlib.context import CryptContext
from app.infra.config import get_settings
2025-12-30 15:36:41 +01:00
# ---- Password hashing ----
# Single-scheme passlib context: every new hash is Argon2.
# deprecated="auto" tells passlib to treat any scheme other than the default as
# deprecated, which only matters if more schemes are listed here later.
# Argon2 cost parameters (time_cost / memory_cost / parallelism) fall back to
# passlib's argon2 handler defaults; per-scheme keyword settings on this
# constructor can tune them if needed.
_pwd_context: CryptContext = CryptContext(schemes=["argon2"], deprecated="auto")
def hash_password(password: str) -> str:
    """
    Produce a storable hash of ``password`` using the module's Argon2 context.

    Args:
        password: Plain-text password; must be a non-empty ``str``.

    Returns:
        The encoded hash string returned by the passlib context, suitable for
        persisting and later checking with ``verify_password``.

    Raises:
        ValueError: If ``password`` is not a string or is empty.
    """
    acceptable = isinstance(password, str) and password != ""
    if not acceptable:
        raise ValueError("password must be a non-empty string")
    return _pwd_context.hash(password)
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """
    Check a plain-text password against a stored hash.

    Args:
        plain_password: Candidate password supplied by the user.
        hashed_password: Stored hash, as produced by ``hash_password``.

    Returns:
        True if the password matches the hash. False if it does not match,
        if either argument is empty, or if the stored hash / arguments are
        malformed (rejected by passlib with ValueError or TypeError).

    Raises:
        Any other exception from the hashing backend propagates. Such errors
        signal a deployment problem (e.g. a missing Argon2 backend) rather
        than bad user input, and must not be reported as "wrong password".
    """
    if not plain_password or not hashed_password:
        return False
    try:
        return _pwd_context.verify(plain_password, hashed_password)
    except (ValueError, TypeError):
        # Unrecognised/malformed hashes (ValueError subclasses in passlib) and
        # wrong argument types should not crash the auth flow: treat as a
        # failed check. Configuration errors are deliberately not caught.
        return False
# ---- JWT tokens ----
2025-12-30 13:35:19 +01:00
2025-12-30 15:36:41 +01:00
TokenType = str # "access" | "refresh"
@dataclass(frozen=True, slots=True)
class TokenClaims:
"""
Standardized JWT claims we issue and validate.
2025-12-30 13:35:19 +01:00
"""
2025-12-30 15:36:41 +01:00
sub: str # subject (e.g., user_id)
typ: TokenType # "access" or "refresh"
exp: datetime
iat: datetime
jti: str
2025-12-30 13:35:19 +01:00
2025-12-30 15:36:41 +01:00
# Optional hardening
iss: Optional[str] = None
aud: Optional[str] = None
# Extra custom claims may be included in the encoded token, but are not represented here.
def _utcnow() -> datetime:
return datetime.now(timezone.utc)
def _encode_jwt(claims: Mapping[str, Any]) -> str:
    """
    Sign ``claims`` into a JWT string.

    Uses ``settings.jwt_secret`` as the key and ``settings.jwt_algorithm`` as
    the signing algorithm. The mapping is copied into a plain ``dict`` first.
    """
    cfg = get_settings()
    payload = dict(claims)
    return jwt.encode(payload, cfg.jwt_secret, algorithm=cfg.jwt_algorithm)
def _decode_jwt(token: str) -> Optional[dict[str, Any]]:
2025-12-30 13:35:19 +01:00
"""
2025-12-30 15:36:41 +01:00
Decode and verify JWT signature + exp using configured algorithm/secret.
Does NOT by itself enforce token type (access/refresh); use decode_* functions.
"""
if not token:
return None
settings = get_settings()
# Optional issuer/audience validation: only enforced if present in settings
options: dict[str, Any] = {
"verify_signature": True,
"verify_exp": True,
"verify_nbf": True,
"verify_iat": False, # iat isn't always required; we set it, but don't hard-fail on clients w/ clock drift
"require_exp": True,
}
kwargs: dict[str, Any] = {"algorithms": [settings.jwt_algorithm]}
# If your settings provide these, they will be enforced.
iss = getattr(settings, "jwt_issuer", None)
aud = getattr(settings, "jwt_audience", None)
if iss:
kwargs["issuer"] = iss
if aud:
kwargs["audience"] = aud
2025-12-30 13:35:19 +01:00
2025-12-30 15:36:41 +01:00
try:
payload = jwt.decode(token, settings.jwt_secret, options=options, **kwargs)
if not isinstance(payload, dict):
return None
return payload
except JWTError:
return None
2025-12-30 13:35:19 +01:00
2025-12-30 15:36:41 +01:00
def _build_claims(
*,
subject: str,
token_type: TokenType,
ttl: timedelta,
extra: Optional[Mapping[str, Any]] = None,
) -> dict[str, Any]:
if not subject:
raise ValueError("subject must be a non-empty string")
if token_type not in {"access", "refresh"}:
raise ValueError("token_type must be 'access' or 'refresh'")
now = _utcnow()
exp = now + ttl
settings = get_settings()
iss = getattr(settings, "jwt_issuer", None)
aud = getattr(settings, "jwt_audience", None)
claims: dict[str, Any] = {
"sub": subject,
"typ": token_type, # use "typ" (common) to avoid confusion with header "typ"
"iat": int(now.timestamp()),
"exp": int(exp.timestamp()),
"jti": uuid4().hex,
}
if iss:
claims["iss"] = iss
if aud:
claims["aud"] = aud
if extra:
# Prevent overriding reserved claims
reserved = {"sub", "typ", "iat", "exp", "jti", "iss", "aud", "nbf"}
for k, v in extra.items():
if k in reserved:
continue
claims[k] = v
return claims
def create_access_token(*, subject: str, extra: Optional[Mapping[str, Any]] = None) -> str:
    """
    Issue a signed access token for ``subject``.

    The token lifetime is ``settings.jwt_access_ttl_seconds``.

    Args:
        subject: User identifier (user_id) as a string.
        extra: Optional additional claims (e.g. roles). Keys that collide
            with reserved claim names are ignored.

    Returns:
        The encoded JWT.

    Raises:
        ValueError: If ``subject`` is empty.
    """
    lifetime = timedelta(seconds=get_settings().jwt_access_ttl_seconds)
    return _encode_jwt(
        _build_claims(subject=subject, token_type="access", ttl=lifetime, extra=extra)
    )
def create_refresh_token(*, subject: str, extra: Optional[Mapping[str, Any]] = None) -> str:
    """
    Issue a signed refresh token for ``subject``.

    The token lifetime is ``settings.jwt_refresh_ttl_seconds``.

    Args:
        subject: User identifier (user_id) as a string.
        extra: Optional additional claims. Keys that collide with reserved
            claim names are ignored.

    Returns:
        The encoded JWT.

    Raises:
        ValueError: If ``subject`` is empty.
    """
    lifetime = timedelta(seconds=get_settings().jwt_refresh_ttl_seconds)
    return _encode_jwt(
        _build_claims(subject=subject, token_type="refresh", ttl=lifetime, extra=extra)
    )
def decode_access_token(token: str) -> Optional[dict[str, Any]]:
    """
    Decode ``token`` and accept it only as an access token.

    A token is accepted when ``_decode_jwt`` verifies it, its ``typ`` claim
    equals "access", and its ``sub`` claim is non-empty.

    Returns:
        The decoded payload, or None if any check fails.
    """
    payload = _decode_jwt(token)
    is_access = bool(payload) and payload.get("typ") == "access" and bool(payload.get("sub"))
    return payload if is_access else None
def decode_refresh_token(token: str) -> Optional[dict[str, Any]]:
    """
    Decode ``token`` and accept it only as a refresh token.

    A token is accepted when ``_decode_jwt`` verifies it, its ``typ`` claim
    equals "refresh", and its ``sub`` claim is non-empty.

    Returns:
        The decoded payload, or None if any check fails.
    """
    payload = _decode_jwt(token)
    if payload and payload.get("typ") == "refresh" and payload.get("sub"):
        return payload
    return None
def get_subject(payload: Mapping[str, Any]) -> Optional[str]:
    """
    Return the ``sub`` claim of a decoded payload as a string.

    Args:
        payload: Decoded JWT claims.

    Returns:
        ``str(payload["sub"])`` when the claim is present and truthy,
        otherwise None.
    """
    subject = payload.get("sub")
    if not subject:
        return None
    return str(subject)