"""
Provides functions to encrypt and decrypt files using AES cipher.

Tip:
    The name ``encryptor`` or something like that sounds more appropriate for
    the name of the module and the functions, but the damage is done already.

The Header
----------

The header is used to store the important bits of data that will be used to
identify and/or decrypt the encrypted file.

This is the structure of the header of an encrypted file:

+----------------------------------------+
|  Header Format (Big endian; 118 bytes) |
+========================================+
| Magic number (``I``)                   |
+----------------------------------------+
| Mode Value (``H``)                     |
+----------------------------------------+
| Nonce (``16s``)                        |
+----------------------------------------+
| Authentication Tag (``32s``)           |
+----------------------------------------+
| Metadata (``32s``)                     |
+----------------------------------------+
| Key Derivation Function Salt (``32s``) |
+----------------------------------------+

Note:
    The values in the brackets are the corresponding format characters used
    by the ``struct`` module.

Parts of Header
~~~~~~~~~~~~~~~

The header can be represented as a C struct:

.. code-block:: c

    typedef struct {
        unsigned int magic;
        unsigned short mode;
        char nonce[16];
        char tag[32];
        char metadata[32];
        char salt[32];
    } Header;

- Magic number (``unsigned int magic``):
    A unique number to identify the filetype.

- Mode value (``unsigned short mode``):
    The AES mode used to encrypt the file.

- Nonce (``char nonce[16]``):
    The ``nonce`` or ``initialization vector`` used for the AES cipher.

- Authentication Tag (``char tag[32]``):
    The tag generated by the cipher after the encryption is over.

- Metadata (``char metadata[32]``):
    Any binary data. **Only this can be specified by the user**. The maximum
    possible length of the metadata is defined in :py:const:`MAX_METADATA_LEN`.

- Key Derivation Function Salt (``char salt[32]``):
    The salt used for key derivation.


Operation details
-----------------

Password derivation
~~~~~~~~~~~~~~~~~~~

By default, the ``password`` is derived into a key with PBKDF2-HMAC, using a
random 32-byte salt, 150000 iterations and ``sha256`` as the hash algorithm.
A different key derivation function can be supplied via the ``kdf`` argument.

Cipher creation
~~~~~~~~~~~~~~~

The cipher is created with a 12-byte nonce if the mode is GCM, and a 16-byte
nonce otherwise. The nonce is stored as part of the ``Header``, along with the
other values needed to decrypt the file.

Authentication
~~~~~~~~~~~~~~

Before the operation begins, the authentication data is passed to the cipher.
The authenticated fields are::

  magic, mode, nonce, metadata, salt

in that order (see :py:data:`AUTHENTICATION_PAYLOAD`).

Finalization
~~~~~~~~~~~~

After completion of the entire operation, the tag created by the authenticator
of the cipher is written to the file as a part of ``Header``. If the file is
being decrypted, it is read from the ``Header`` for verifying the file
integrity and correct decryption.
"""
from __future__ import annotations

import os
import struct
import typing
from collections import namedtuple
from functools import partial
from hashlib import pbkdf2_hmac

from .ciphers import exc
from .ciphers.backends.symmetric import FileCipherWrapper
from .ciphers.interfaces import AES
from .ciphers.modes import AEAD, SPECIAL, Modes

#: A KDF callable. :any:`encryptf` and :any:`decryptf` call it with the
#: keyword arguments ``password``, ``salt`` and ``dklen`` and use the returned
#: bytes as the AES key.
KDFunc = typing.Callable[[bytes, bytes, int], bytes]

if typing.TYPE_CHECKING:  # pragma: no cover
    import io

    from .ciphers.backends import Backends

#: Maximum possible length of the metadata.
MAX_METADATA_LEN = 32

#: Maximum length of authentication tag (size of the header's tag field;
#: shorter tags are zero-padded by ``struct``).
MAX_TAG_LEN = 32

#: Maximum length of password derivation salt.
MAX_SALT_LEN = 32

#: Maximum length of AES cipher's nonce (size of the header's nonce field;
#: the 12-byte GCM nonce is zero-padded by ``struct``).
MAX_NONCE_LEN = 16

#: A struct that represents the data that is written to the encrypted
#: file as its header: magic, mode, nonce, tag, metadata, salt (big endian,
#: no padding between fields).
HEADER_PAYLOAD = struct.Struct(
    f">I H {MAX_NONCE_LEN}s {MAX_TAG_LEN}s {MAX_METADATA_LEN}s {MAX_SALT_LEN}s"
)

#: A struct that represents the data that is passed to the cipher's
#: authenticator: the header fields minus the tag, which is only known once
#: encryption has finished.
AUTHENTICATION_PAYLOAD = struct.Struct(
    f">I H {MAX_NONCE_LEN}s {MAX_METADATA_LEN}s {MAX_SALT_LEN}s"
)

#: The magic number of the encrypted file (first field of the header).
MAGIC = 0xC8E52E4A

#: The default key derivation function used by PyFLocker: PBKDF2-HMAC with
#: ``sha256`` and 150000 iterations. ``password``, ``salt`` and ``dklen`` are
#: supplied by the caller as keyword arguments.
PBKDF2_HMAC = partial(pbkdf2_hmac, hash_name="sha256", iterations=150000)

#: The default metadata written into (and expected in) the header.
METADATA = b"CREATED BY: PyFLocker"

#: Default extension of the encrypted file.
EXTENSION = ".pyflk"

# In-memory form of the header. The field order matches HEADER_PAYLOAD, so an
# instance can be packed directly with ``HEADER_PAYLOAD.pack(*header)``.
_Header = namedtuple("_Header", "magic mode nonce tag metadata salt")


def encryptf(
    infile: io.BufferedReader,
    outfile: typing.IO[bytes],
    password: bytes,
    *,
    kdf: KDFunc | None = None,
    aes_mode: Modes = Modes.MODE_GCM,
    blocksize: int = 16 * 1024,
    metadata: bytes = METADATA,
    dklen: int = 32,
    backend: Backends | None = None,
) -> None:
    """Encrypts the binary data using AES cipher and writes it to ``outfile``.

    The header is written at the current position of ``outfile``, followed by
    the encrypted data. Once encryption has finished, the authentication tag
    is written back into the header, and the stream position is restored to
    the end of the written data. ``outfile`` must therefore be seekable.

    Args:
        infile: The binary stream to read from.
        outfile: The binary stream to write the encrypted bytes into.
        password: Password to use to encrypt the binary data.

    Keyword Arguments:
        kdf:
            The key derivation function to use. It must be a callable that
            accepts 3 keyword arguments: ``password``, ``salt`` and ``dklen``.
            If ``kdf`` is ``None``, ``PBKDF2-HMAC-SHA256-150000`` is used
            instead.
        aes_mode:
            The AES mode to use for encryption/decryption. The mode can be any
            attribute from :any:`Modes` except those which are defined in
            :any:`modes.SPECIAL`. Defaults to :any:`Modes.MODE_GCM`. The AES
            mode is stored as a part of the encrypted file.
        blocksize:
            The amount of data to read from ``infile`` in each iteration.
            Defaults to 16384.
        metadata:
            The metadata to write to the file. It must be at most
            :py:const:`MAX_METADATA_LEN` (32) bytes long.
        dklen:
            The desired key length for the cipher, given either in bytes
            (16, 24, 32) or in bits (128, 192, 256). It specifies the strength
            of the AES cipher. Defaults to 32.
        backend:
            The backend to use to instantiate the AES cipher from. If ``None``
            is specified (the default), any available backend will be used.

    Raises:
        ValueError:
            If ``infile`` and ``outfile`` point to the same file, or if
            ``dklen`` is not a valid AES key length.
        NotImplementedError:
            Raised if ``aes_mode`` is not amongst the supported modes.
        OverflowError:
            Raised if length of metadata exceeded :py:const:`MAX_METADATA_LEN`.
    """
    _assert_unique_files(infile, outfile)

    if aes_mode in SPECIAL:
        raise NotImplementedError(f"{aes_mode} is not supported.")

    if len(metadata) > MAX_METADATA_LEN:
        raise OverflowError("maximum metadata length exceeded (limit: 32).")

    # Validate the key length before anything is written, so an invalid
    # ``dklen`` does not leave a header-only output behind.
    key_length = _check_key_length(dklen)

    # Create the salt and nonce. GCM uses a 12 byte nonce, other modes use
    # 16 bytes; the header field is zero-padded to MAX_NONCE_LEN either way.
    salt = os.urandom(MAX_SALT_LEN)
    nonce = os.urandom(12 if aes_mode == Modes.MODE_GCM else MAX_NONCE_LEN)

    # Pack the header with an empty tag placeholder (struct zero-pads it),
    # remembering where it starts so the tag can be patched in later.
    header = _Header(MAGIC, aes_mode.value, nonce, b"", metadata, salt)
    header_start = outfile.tell()
    outfile.write(HEADER_PAYLOAD.pack(*header))

    # Derive the key with the key derivation function.
    if kdf is None:
        kdf = PBKDF2_HMAC
    key = kdf(
        password=password,  # type: ignore
        salt=salt,
        dklen=key_length,
    )

    # create a cipher with the key
    cipher = AES.new(
        True,
        key,
        Modes(header.mode),
        header.nonce,
        file=infile,
        backend=backend,
        tag_length=None,
    )
    assert isinstance(cipher, FileCipherWrapper)

    # authenticate the payload (everything in the header except the tag)
    cipher.authenticate(
        AUTHENTICATION_PAYLOAD.pack(
            header.magic,
            header.mode,
            nonce,
            metadata,
            salt,
        )
    )

    cipher.update_into(outfile, blocksize=blocksize)

    # Put the tag back into its slot in the header (right after magic, mode
    # and nonce). The offset is relative to where the header was written,
    # not to the start of the stream. Afterwards, return to the end of the
    # written data so the stream is left in a sensible position.
    tag = cipher.calculate_tag()
    data_end = outfile.tell()
    outfile.seek(header_start + struct.calcsize(f">I H {MAX_NONCE_LEN}s"))
    outfile.write(tag)  # type: ignore
    outfile.seek(data_end)


def decryptf(
    infile: io.BufferedReader,
    outfile: typing.IO[bytes],
    password: bytes,
    *,
    kdf: KDFunc | None = None,
    blocksize: int = 16 * 1024,
    metadata: bytes = METADATA,
    dklen: int = 32,
    backend: Backends | None = None,
) -> None:
    """Decrypts the binary data using AES cipher and writes it to ``outfile``.

    The header is read from the current position of ``infile``. The AES mode,
    nonce, salt and tag are taken from it.

    Args:
        infile: The binary stream to read from.
        outfile: The binary stream to write the decrypted bytes into.
        password: Password to use to decrypt the binary data.

    Keyword Arguments:
        kdf:
            The key derivation function to use. It must be a callable that
            accepts 3 keyword arguments: ``password``, ``salt`` and ``dklen``.
            If ``kdf`` is ``None``, ``PBKDF2-HMAC-SHA256-150000`` is used
            instead.
        blocksize:
            The amount of data to read from ``infile`` in each iteration.
            Defaults to 16384.
        metadata:
            The metadata that was written to the file during encryption. It
            must match the stored metadata, as it is checked against the
            header and used for authentication.
        dklen:
            The desired key length for the cipher, given either in bytes
            (16, 24, 32) or in bits (128, 192, 256). It must match the value
            used for encryption. Defaults to 32.
        backend:
            The backend to use to instantiate the AES cipher from. If ``None``
            is specified (the default), any available backend will be used.

    Raises:
        ValueError:
            If ``infile`` and ``outfile`` point to the same file, or if
            ``dklen`` is not a valid AES key length.
        TypeError: If the header data is incorrect.
        DecryptionError: If the decryption fails.
    """
    _assert_unique_files(infile, outfile)

    # extract the header from the file
    header = _get_header(infile.read(HEADER_PAYLOAD.size), metadata)

    # Derive the key exactly once: the KDF is deliberately expensive
    # (PBKDF2 with 150000 iterations by default).
    if kdf is None:
        kdf = PBKDF2_HMAC
    key = kdf(
        password=password,  # type: ignore
        salt=header.salt,
        dklen=_check_key_length(dklen),
    )

    # create a cipher with the key
    cipher = AES.new(
        False,
        key,
        Modes(header.mode),
        header.nonce,
        file=infile,
        backend=backend,
        tag_length=None,
    )
    assert isinstance(cipher, FileCipherWrapper)

    # authenticate the payload; it must match what encryptf authenticated
    cipher.authenticate(
        AUTHENTICATION_PAYLOAD.pack(
            header.magic,
            header.mode,
            header.nonce,
            metadata,
            header.salt,
        )
    )

    # the stored tag is verified once all data has been processed
    cipher.update_into(outfile, blocksize=blocksize, tag=header.tag)


def encrypt(
    infile: str | os.PathLike,
    outfile: str | os.PathLike,
    password: bytes,
    remove: bool = True,
    **kwargs: typing.Any,
) -> None:
    """Encrypt the file at path ``infile`` into a new file at path ``outfile``.

    Args:
        infile: Path of the file whose contents are encrypted.
        outfile:
            Path of the encrypted file to create. It must not exist yet: the
            file is opened in exclusive-creation mode, so an existing file
            results in :py:exc:`FileExistsError`.
        password: Password used for the encryption.
        remove: If true (the default), ``infile`` is deleted after a
            successful encryption.

    Keyword Arguments:
        **kwargs:
            Extra options forwarded unchanged to :any:`encryptf`. See the
            documentation of :any:`encryptf` for the accepted names.

    Note:
        Any other errors are raised from :any:`encryptf` itself.

    Important:
        The removal of file is **NOT** secure, because it uses
        :py:func:`os.remove` to remove the file. With enough expertise and
        time, the original file can be restored. If you want to remove the
        original file securely, consider using ``shred`` or ``srm`` or some
        other secure file deletion tools.
    """
    return _encrypt_or_decrypt(
        encryptf, infile, outfile, password, remove=remove, **kwargs
    )


def decrypt(
    infile: str | os.PathLike,
    outfile: str | os.PathLike,
    password: bytes,
    remove: bool = True,
    **kwargs: typing.Any,
) -> None:
    """Decrypt the file at path ``infile`` into a new file at path ``outfile``.

    Args:
        infile: Path of the encrypted file to read.
        outfile:
            Path of the decrypted file to create. It must not exist yet: the
            file is opened in exclusive-creation mode, so an existing file
            results in :py:exc:`FileExistsError`.
        password: Password used for the decryption.
        remove: If true (the default), ``infile`` is deleted after a
            successful decryption.

    Keyword Arguments:
        **kwargs:
            Extra options forwarded unchanged to :any:`decryptf`. See the
            documentation of :any:`decryptf` for the accepted names.

    Note:
        Any other errors are raised from :any:`decryptf` itself.

    Important:
        The removal of file is **NOT** secure, because it uses
        :py:func:`os.remove` to remove the file. With enough expertise and
        time, the original file can be restored. If you want to remove the
        original file securely, consider using ``shred`` or ``srm`` or some
        other secure file deletion tools.
    """
    return _encrypt_or_decrypt(
        decryptf, infile, outfile, password, remove=remove, **kwargs
    )


def lockerf(
    infile: io.BufferedReader,
    outfile: typing.IO[bytes],
    password: bytes,
    encrypting: bool,
    **kwargs: typing.Any,
) -> None:
    """Encrypt or decrypt a binary stream, depending on ``encrypting``.

    This is a thin dispatcher. It forwards to :any:`encryptf` when
    ``encrypting`` is true and to :any:`decryptf` otherwise. Both read
    ``infile`` in chunks whose size is set by their ``blocksize`` keyword
    argument.

    Args:
        infile: The binary stream to read from.
        outfile: The binary stream to write the encrypted/decrypted bytes into.
        password: Password to use to encrypt/decrypt the binary data.
        encrypting:
            True to encrypt ``infile``, False to decrypt it.

    Keyword Arguments:
        **kwargs:
            Extra options for :any:`encryptf` or :any:`decryptf`. When
            decrypting, an ``aes_mode`` entry is discarded, because the mode
            is read from the file header.

    Note:
        See documentation of :any:`encryptf` and :any:`decryptf` for possible
        errors.
    """
    if not encrypting:
        # decryptf has no ``aes_mode`` parameter (the header supplies the
        # mode), so drop it rather than raising a TypeError.
        kwargs.pop("aes_mode", None)
        decryptf(infile, outfile, password, **kwargs)
    else:
        encryptf(infile, outfile, password, **kwargs)


def locker(
    file: str | os.PathLike[str],
    password: bytes,
    encrypting: bool | None = None,
    remove: bool = True,
    *,
    ext: str | None = None,
    newfile: str | os.PathLike[str] | None = None,
    **kwargs: typing.Any,
) -> None:
    """Encrypts or decrypts files with AES algorithm.

    Args:
        file: The actual location of the file.
        password: Password to use to encrypt/decrypt the file.
        encrypting:
            Whether the file is being locked (encrypted) or not.

            If ``encrypting`` is True, the file is encrypted no matter what
            the extension is.
            If ``encrypting`` is False, the file is decrypted no matter what
            the extension is.

            If ``encrypting`` is None (the default), it is guessed from the
            file name only: the file is decrypted if its name ends with
            ``ext``, and encrypted otherwise. The file header is not inspected.
        remove:
            Whether to remove the file after encryption/decryption. Default is
            ``True``.

    Keyword Arguments:
        ext:
            The extension used for encrypted files. If ``None``, the default
            value :py:const:`EXTENSION` is used. It is used to guess
            ``encrypting`` and, when ``newfile`` is not given, to build the
            output file name.
        newfile:
            The name of the file to be created. It must not be already present.
            If None is provided (default), the name is ``file`` plus ``ext``
            when encrypting. When decrypting, it is ``file`` with the trailing
            ``ext`` removed; if ``file`` does not end with ``ext``, its last
            extension is removed instead.
        **kwargs:
            The additional arguments to pass to :any:`encryptf` or
            :any:`decryptf`. See their documentation for more information.

    Raises:
        ValueError: If both ``newfile`` and ``ext`` are given.

    Note:
        See documentation of :any:`encryptf` and :any:`decryptf` for possible
        errors.

    Important:
        The removal of file is **NOT** secure, because it uses
        :py:func:`os.remove` to remove the file. With enough expertise and
        time, the original file can be restored. If you want to remove the
        original file securely, consider using ``shred`` or ``srm`` or some
        other secure file deletion tools.
    """
    if newfile and ext:
        raise ValueError("newfile and ext are mutually exclusive")

    ext = ext or EXTENSION
    file = os.fspath(file)
    has_ext = file.endswith(ext)

    # guess encrypting if not provided
    if encrypting is None:
        encrypting = not has_ext

    # make newfile name if not provided
    if newfile is None:
        if encrypting:
            newfile = file + ext
        elif has_ext and os.path.basename(file[: -len(ext)]):
            # Strip exactly the suffix that encryption appended. splitext()
            # only removes the last dot-suffix, which gives the wrong name
            # for extensions such as "_locked" or ".enc.bin".
            newfile = file[: -len(ext)]
        else:
            newfile = os.path.splitext(file)[0]

    if encrypting:
        encrypt(file, newfile, password, remove, **kwargs)
    else:
        # decryptf reads the mode from the header and has no ``aes_mode``
        kwargs.pop("aes_mode", None)
        decrypt(file, newfile, password, remove, **kwargs)


def extract_header_from_file(
    path: str | os.PathLike,
    metadata: bytes = METADATA,
) -> _Header:
    """Read and validate the header of the encrypted file at ``path``.

    Args:
        path: The path to the encrypted file.
        metadata: The metadata that is expected in the file's header.

    Returns:
        The parsed header data.

    Raises:
        TypeError: If the header is malformed or does not match ``metadata``.
    """
    with open(path, "rb") as stream:
        raw_header = stream.read(HEADER_PAYLOAD.size)
    return _get_header(raw_header, metadata)


def _assert_unique_files(
    infile: typing.IO[bytes],
    outfile: typing.IO[bytes],
) -> None:
    """Check if files are unique, else raise ValueError."""
    if os.path.samefile(infile.fileno(), outfile.fileno()):
        raise ValueError("infile and outfile are the same")


def _encrypt_or_decrypt(
    callable: typing.Callable[..., None],
    infile: str | os.PathLike,
    outfile: str | os.PathLike,
    password: bytes,
    remove: bool = True,
    **kwargs: typing.Any,
) -> None:
    try:
        with open(infile, "rb") as fin, open(outfile, "xb") as fout:
            callable(fin, fout, password, **kwargs)
    except (TypeError, exc.DecryptionError):
        # remove invalid file
        os.remove(outfile)
        raise
    else:
        # remove the original file
        if remove:
            os.remove(infile)


def _check_key_length(n: int) -> int:
    if n in (128, 192, 256):
        return n // 8
    elif n in (16, 24, 32):
        return n
    else:
        raise ValueError("invalid key length")


def _get_header(data: bytes, metadata: bytes = METADATA) -> _Header:
    """Parse and validate the raw header bytes of an encrypted file.

    Args:
        data: The first :py:data:`HEADER_PAYLOAD` ``.size`` bytes of the file.
        metadata: The metadata expected in the header.

    Returns:
        The header. Its nonce is trimmed to 12 bytes for GCM, and its tag to
        16 bytes for AEAD modes, undoing the zero padding added by ``struct``.

    Raises:
        TypeError:
            If ``data`` has the wrong size, the magic number or metadata does
            not match, or the stored AES mode is unknown.
    """
    try:
        (
            magic,
            mode,
            nonce,
            tag,
            metadata_h,
            salt,
        ) = HEADER_PAYLOAD.unpack(data)
    except struct.error as e:
        raise TypeError("The file format is invalid (Header mismatch).") from e

    # The stored metadata is zero-padded to MAX_METADATA_LEN, so compare only
    # its first len(metadata) bytes. (Slicing to len(metadata) -
    # MAX_METADATA_LEN gave an empty slice for full-length metadata, which
    # made such files impossible to decrypt.)
    if magic != MAGIC or metadata != metadata_h[: len(metadata)]:
        raise TypeError(
            "The file format is invalid (Metadata/magic number mismatch)."
        )

    try:
        aes_mode = Modes(mode)
    except ValueError as e:
        # Unknown mode values are a header problem, like the checks above.
        raise TypeError("The file format is invalid (Unknown AES mode).") from e

    if aes_mode == Modes.MODE_GCM:
        nonce = nonce[:12]
    if aes_mode in AEAD:
        tag = tag[:16]
    return _Header(magic, mode, nonce, tag, metadata, salt)
