94 lines
2.7 KiB
Python
94 lines
2.7 KiB
Python
from datetime import UTC, datetime, timedelta
from typing import Any

from argon2 import PasswordHasher
from argon2.exceptions import InvalidHashError, VerificationError, VerifyMismatchError
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from sqlalchemy.orm import Session

from core.config import settings
from core.db import get_db
|
|
|
|
# Single Argon2 hasher instance used by hash_password / verify_password.
_ph: PasswordHasher = PasswordHasher()

# Bearer-token extractor for dependencies. With auto_error=False it yields
# None (rather than raising) when no bearer token is sent, so callers can
# choose between rejecting the request and treating it as anonymous.
# tokenUrl is the login endpoint advertised to clients.
oauth2_scheme = OAuth2PasswordBearer(auto_error=False, tokenUrl="/api/auth/login")
|
|
|
|
|
|
def hash_password(pw: str) -> str:
    """Hash a plaintext password with the module-level Argon2 hasher.

    Args:
        pw: The plaintext password.

    Returns:
        The encoded hash string from ``PasswordHasher.hash``; store it and
        later pass it to ``verify_password``.
    """
    encoded = _ph.hash(pw)
    return encoded
|
|
|
|
|
|
def verify_password(pw: str, hashed: str) -> bool:
    """Check a plaintext password against a stored Argon2 hash.

    Args:
        pw: The plaintext password supplied by the user.
        hashed: The stored hash, as produced by ``hash_password``.

    Returns:
        True if ``pw`` matches ``hashed``. False if it does not match, or if
        ``hashed`` cannot be verified at all (for example it is corrupt or not
        an Argon2 hash string).
    """
    try:
        return _ph.verify(hashed, pw)
    except (VerificationError, InvalidHashError):
        # VerifyMismatchError (wrong password) is a subclass of
        # VerificationError. InvalidHashError means the stored value is not a
        # usable Argon2 hash. Either way the caller is not authenticated:
        # report a failed check instead of letting it surface as a 500.
        return False
|
|
|
|
|
|
def _make_token(sub: str, role: str, delta: timedelta, token_type: str) -> str:
    """Build and sign an HS256 JWT.

    Args:
        sub: Value for the ``sub`` claim.
        role: Value for the ``role`` claim.
        delta: Token lifetime; ``exp`` is set to the current UTC time plus this.
        token_type: Value for the ``type`` claim (this module uses
            ``"access"`` and ``"refresh"``).

    Returns:
        The encoded token, signed with ``settings.JWT_SECRET``.
    """
    expires_at = datetime.now(UTC) + delta
    claims = {
        "sub": sub,
        "role": role,
        "type": token_type,
        "exp": expires_at,
    }
    return jwt.encode(claims, settings.JWT_SECRET, algorithm="HS256")
|
|
|
|
|
|
def make_access_token(user_id: int, role: str) -> str:
    """Issue an access token for ``user_id``.

    The token's ``type`` claim is ``"access"`` and it expires
    ``settings.JWT_ACCESS_MINUTES`` minutes after issue.

    Args:
        user_id: The user's id; stored as a string in ``sub``.
        role: Stored in the ``role`` claim.

    Returns:
        The signed JWT string.
    """
    subject = str(user_id)
    lifetime = timedelta(minutes=settings.JWT_ACCESS_MINUTES)
    return _make_token(subject, role, lifetime, "access")
|
|
|
|
|
|
def make_refresh_token(user_id: int, role: str) -> str:
    """Issue a refresh token for ``user_id``.

    The token's ``type`` claim is ``"refresh"`` and it expires
    ``settings.JWT_REFRESH_DAYS`` days after issue.

    Args:
        user_id: The user's id; stored as a string in ``sub``.
        role: Stored in the ``role`` claim.

    Returns:
        The signed JWT string.
    """
    subject = str(user_id)
    lifetime = timedelta(days=settings.JWT_REFRESH_DAYS)
    return _make_token(subject, role, lifetime, "refresh")
|
|
|
|
|
|
def decode_token(token: str) -> dict[str, Any]:
    """Decode and verify an HS256 JWT signed with ``settings.JWT_SECRET``.

    Any failure reported by ``jwt.decode`` as a ``JWTError`` (bad signature,
    malformed token, or a failed claim check such as expiry) is turned into
    an HTTP 401.

    Args:
        token: The encoded JWT string.

    Returns:
        The decoded claims dictionary.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header if the
            token cannot be decoded or verified.
    """
    try:
        return jwt.decode(token, settings.JWT_SECRET, algorithms=["HS256"])
    except JWTError as e:
        # RFC 7235 §3.1 / RFC 6750 §3: a 401 must carry a WWW-Authenticate
        # challenge so clients know to re-authenticate with a bearer token.
        raise HTTPException(
            status.HTTP_401_UNAUTHORIZED,
            f"Invalid token: {e}",
            headers={"WWW-Authenticate": "Bearer"},
        ) from e
|
|
|
|
|
|
def current_user_claims(token: str | None = Depends(oauth2_scheme)) -> dict[str, Any]:
    """FastAPI dependency: return verified claims of the request's access token.

    Args:
        token: Bearer token extracted by ``oauth2_scheme``; None when absent.

    Returns:
        The decoded claims, guaranteed to have ``type == "access"``.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when the
            token is missing or is not an access token (e.g. a refresh token).
            ``decode_token`` raises its own 401 for invalid tokens.
    """
    # RFC 7235 §3.1 / RFC 6750 §3: every 401 needs a bearer challenge header.
    bearer_challenge = {"WWW-Authenticate": "Bearer"}
    if not token:
        raise HTTPException(
            status.HTTP_401_UNAUTHORIZED, "Missing token", headers=bearer_challenge
        )
    claims = decode_token(token)
    # Refresh tokens are signed with the same key; refuse them here so they
    # cannot be used to call protected endpoints.
    if claims.get("type") != "access":
        raise HTTPException(
            status.HTTP_401_UNAUTHORIZED, "Wrong token type", headers=bearer_challenge
        )
    return claims
|
|
|
|
|
|
def require_admin(claims: dict = Depends(current_user_claims)) -> dict:
    """FastAPI dependency: allow the request only for the ``admin`` role.

    Args:
        claims: Access-token claims supplied by ``current_user_claims``.

    Returns:
        The same claims, unchanged, when ``claims["role"] == "admin"``.

    Raises:
        HTTPException: 403 when the ``role`` claim is absent or not ``"admin"``.
    """
    is_admin = claims.get("role") == "admin"
    if is_admin:
        return claims
    raise HTTPException(status.HTTP_403_FORBIDDEN, "Admin role required")
|
|
|
|
|
|
def optional_user(token: str | None = Depends(oauth2_scheme)) -> dict | None:
    """FastAPI dependency: access-token claims if present and valid, else None.

    Unlike ``current_user_claims`` this never rejects the request. A missing
    token, an undecodable token, or a non-access token all yield None, so the
    endpoint can treat the caller as anonymous.

    Args:
        token: Bearer token extracted by ``oauth2_scheme``; None when absent.

    Returns:
        The decoded claims for a valid access token, otherwise None.
    """
    if token:
        try:
            claims = decode_token(token)
        except HTTPException:
            # decode_token reports bad tokens as HTTPException; swallow it
            # deliberately because authentication is optional here.
            return None
        if claims.get("type") == "access":
            return claims
    return None
|
|
|
|
|
|
def get_current_user_id(claims: dict = Depends(current_user_claims)) -> int:
    """FastAPI dependency: the caller's user id taken from the ``sub`` claim.

    Args:
        claims: Access-token claims supplied by ``current_user_claims``.

    Returns:
        ``int(claims["sub"])``.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header if
            ``sub`` is missing or cannot be converted to an int.
    """
    try:
        return int(claims["sub"])
    except (KeyError, TypeError, ValueError) as e:
        # A signed token whose subject is not a user id still cannot identify
        # the caller. Reject it as unauthenticated instead of letting the
        # conversion error escape as a 500.
        raise HTTPException(
            status.HTTP_401_UNAUTHORIZED,
            "Invalid token subject",
            headers={"WWW-Authenticate": "Bearer"},
        ) from e
|
|
|
|
|
|
# Public API of this module. ``get_db`` and ``Session`` are not defined here;
# they are re-exported so importers can get the DB dependency and its type
# from the same place as the auth helpers.
__all__ = [
    # Password hashing
    "hash_password",
    "verify_password",
    # Token issuing / decoding
    "make_access_token",
    "make_refresh_token",
    "decode_token",
    # FastAPI dependencies
    "current_user_claims",
    "require_admin",
    "optional_user",
    "get_current_user_id",
    "oauth2_scheme",
    # Re-exports
    "get_db",
    "Session",
]
|