import base64
import hashlib
from datetime import datetime, timezone
from typing import Any, Dict, Optional, Union

from jwcrypto import jwk, jwt

from fief.models import Client, User
from fief.schemas.user import UserDB


def get_user_claims(user: Union[UserDB, User]) -> Dict[str, Any]:
    """
    Build the identity claims describing a user.

    :param user: The user whose identifier, email and tenant are exposed.
    :return: A dictionary with the `sub`, `email` and `tenant_id` claims.
    `sub` and `tenant_id` are the string forms of the user's `id`
    and `tenant_id` attributes.
    """
    subject = str(user.id)
    email = user.email
    tenant = str(user.tenant_id)

    return dict(sub=subject, email=email, tenant_id=tenant)


def generate_id_token(
    signing_key: jwk.JWK,
    host: str,
    client: Client,
    authenticated_at: datetime,
    user: Union[UserDB, User],
    lifetime_seconds: int,
    *,
    nonce: Optional[str] = None,
    code: Optional[str] = None,
    access_token: Optional[str] = None,
    encryption_key: Optional[jwk.JWK] = None,
) -> str:
    """
    Generate an ID Token for an authenticated user.

    It's a JWT signed with RS256, with claims following the OpenID
    specification.

    :param signing_key: The JWK to sign the JWT.
    :param host: The issuer host, set as the `iss` claim.
    :param client: The client used to authenticate the user. Its `client_id`
    is used for both the `aud` and `azp` claims.
    :param authenticated_at: Date and time at which the user authenticated.
    :param user: The authenticated user.
    :param lifetime_seconds: Lifetime of the JWT, in seconds from now.
    :param nonce: Optional nonce value associated with the authorization request.
    :param code: Optional authorization code associated to the ID Token.
    When provided, its validation hash is added as the `c_hash` claim.
    :param access_token: Optional access token associated to the ID Token.
    When provided, its validation hash is added as the `at_hash` claim.
    :param encryption_key: Optional JWK to further encrypt the signed token.
    In this case, it becomes a Nested JWT, as defined in rfc7519.
    :return: The serialized token: the signed JWT, or the encrypted JWT
    wrapping it if `encryption_key` is provided.
    """
    iat = int(datetime.now(timezone.utc).timestamp())
    exp = iat + lifetime_seconds

    claims: Dict[str, Any] = {
        **get_user_claims(user),
        "iss": host,
        "aud": [client.client_id],
        "exp": exp,
        "iat": iat,
        # NOTE(review): `timestamp()` interprets a naive datetime as local
        # time; presumably callers pass a timezone-aware value -- confirm.
        "auth_time": int(authenticated_at.timestamp()),
        "azp": client.client_id,
    }

    # Optional claims are only emitted when the corresponding value is given.
    if nonce is not None:
        claims["nonce"] = nonce
    if code is not None:
        claims["c_hash"] = get_validation_hash(code)
    if access_token is not None:
        claims["at_hash"] = get_validation_hash(access_token)

    signed_token = jwt.JWT(header={"alg": "RS256"}, claims=claims)
    signed_token.make_signed_token(signing_key)

    if encryption_key is not None:
        encrypted_token = jwt.JWT(
            header={
                "alg": "RSA-OAEP-256",
                "enc": "A256CBC-HS512",
                # rfc7519 section 5.2: "cty" MUST be present with the value
                # "JWT" when nested signing or encryption is employed, so
                # that recipients know the payload is itself a JWT.
                "cty": "JWT",
            },
            claims=signed_token.serialize(),
        )
        encrypted_token.make_encrypted_token(encryption_key)
        return encrypted_token.serialize()

    return signed_token.serialize()


def get_validation_hash(value: str) -> str:
    """
    Computes a hash value to be embedded in the ID Token, like at_hash and c_hash.

    The value is hashed with SHA-256, the left-most half of the digest is kept
    and then encoded in URL-safe Base64 without padding.

    Specification: https://openid.net/specs/openid-connect-core-1_0.html#toc
    (see the `at_hash` and `c_hash` definitions, sections 3.1.3.6 and 3.3.2.11)

    :param value: The token or code to hash.
    :return: The unpadded URL-safe Base64 encoding of the left half of the digest.
    """
    digest = hashlib.sha256(value.encode("utf-8")).digest()

    # Only the left-most half of the digest is kept
    left_half = digest[: len(digest) // 2]

    # Strip the Base64 padding, whatever its length, instead of assuming
    # it is always exactly "=="
    encoded = base64.urlsafe_b64encode(left_half).rstrip(b"=")

    return encoded.decode("utf-8")
