"""Classes for dealing with Transformation Processing."""
from __future__ import annotations

import logging
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    List,
    Literal,
    Optional,
    Set,
    Tuple,
    Type,
    Union,
    cast,
)

import numpy as np
from pandas.core.dtypes.common import is_numeric_dtype

from bitfount.data.types import SemanticType
from bitfount.transformations.base_transformation import (
    MultiColumnOutputTransformation,
    Transformation,
)
from bitfount.transformations.batch_operations import (
    BatchTimeOperation,
    ImageTransformation,
)
from bitfount.transformations.binary_operations import (
    AdditionTransformation,
    ComparisonTransformation,
    DivisionTransformation,
    MultiplicationTransformation,
    SubtractionTransformation,
)
from bitfount.transformations.dataset_operations import (
    CleanDataTransformation,
    NormalizeDataTransformation,
)
from bitfount.transformations.exceptions import (
    InvalidBatchTransformationError,
    MissingColumnReferenceError,
    NotColumnReferenceError,
    TransformationApplicationError,
)
from bitfount.transformations.references import _extract_col_ref
from bitfount.transformations.unary_operations import (
    InclusionTransformation,
    OneHotEncodingTransformation,
)
from bitfount.types import _DataFrameLib, _DataFrameType, _SeriesType
from bitfount.utils import _get_df_library, _get_df_library_type

if TYPE_CHECKING:
    from bitfount.data.schema import TableSchema


# Module-level logger; TransformationProcessor uses it for debug/warning messages.
logger = logging.getLogger(__name__)


class TransformationProcessor:
    """Processes Transformations on a given dataframe.

    Args:
        transformations: The list of transformations to apply.
        schema: The schema of the data to be transformed.
        col_refs: The set of columns referenced in those transformations.

    Attributes:
        transformations: The list of transformations to apply.
        schema: The schema of the data to be transformed.
        col_refs: The set of columns referenced in those transformations.

    :::warning

    The Transformation processor does not add any of the newly created columns
    to the Schema. This must be done separately after processing the transformations.

    :::
    """

    def __init__(
        self,
        transformations: List[Transformation],
        schema: Optional[TableSchema] = None,
        col_refs: Optional[Set[str]] = None,
    ):
        self.transformations = transformations
        self.col_refs = set() if col_refs is None else col_refs

        self.schema = schema
        # Cache the schema's continuous/categorical feature names; they are used
        # by the dataset-level operations (e.g. for the "all" column selection).
        if self.schema is not None:
            self._schema_cont_cols = self.schema.get_feature_names(
                SemanticType.CONTINUOUS
            )
            self._schema_cat_cols = self.schema.get_feature_names(
                SemanticType.CATEGORICAL
            )
        else:
            self._schema_cont_cols = []
            self._schema_cat_cols = []

        # Dispatch table from transformation type to the method applying it.
        # Lookups use `type(transformation)`, so only exact types are matched.
        self._operators: Dict[Type[Transformation], Callable] = {
            AdditionTransformation: self._do_addition,
            CleanDataTransformation: self._do_clean_data,
            ComparisonTransformation: self._do_comparison,
            DivisionTransformation: self._do_division,
            InclusionTransformation: self._do_inclusion,
            MultiplicationTransformation: self._do_multiplication,
            NormalizeDataTransformation: self._do_normalize_data,
            OneHotEncodingTransformation: self._do_one_hot_encoding,
            SubtractionTransformation: self._do_subtraction,
            ImageTransformation: self._do_image_transformation,
        }

    def transform(self, data: _DataFrameType) -> _DataFrameType:
        """Performs `self.transformations` on `data` sequentially.

        Arguments to an operation are extracted by first checking if they are
        referencing another transformed column by checking for the name attribute.
        If not, we then check if they are referencing a non-transformed column by
        using a regular expression. Finally, if the regex comes back empty we take
        the argument 'as is' e.g. a string, integer, etc. After the transformations
        are complete, finally removes any columns that shouldn't be part of the final
        output.

        Args:
            data: The `pandas` or `koalas` dataframe to be transformed. Most
                operations assign their output columns directly onto it, so it
                may be modified in place.

        Returns:
            The transformed dataframe, without the output columns of any
            transformation whose `output` attribute is falsy.

        Raises:
            MissingColumnReferenceError: If any column in `self.col_refs` is not
                present in `data`.
            TransformationApplicationError: If any transformation's output column
                clashes with an existing column of `data`, or if applying a
                transformation raised a `TypeError`. All such problems are
                collected and reported together after every transformation has
                been attempted.

        :::warning

        Differing behaviour between `pandas` and `koalas` in this method.
        Pandas **will not** create columns where there is a type mismatch for the
        operation but koalas **will** create a column of `NaNs`/`None`s etc.

        :::
        """
        # NOTE: computed once, so output-name clashes are only checked against
        # the columns of the input dataframe, not against columns created by
        # earlier transformations in this run.
        data_columns = set(data.columns)

        # Check that all referenced columns are present in `data`
        missing_cols = sorted(self.col_refs.difference(data_columns))
        if missing_cols:
            raise MissingColumnReferenceError(
                [f"Reference to non-existent column: {c}" for c in missing_cols]
            )

        # Loop through transformations and perform them sequentially
        application_errors = []
        cols_to_drop = []
        for transformation in self.transformations:
            # Transformation name is always set
            transformation.name = cast(str, transformation.name)

            # Check transformation output doesn't clash with existing column
            if isinstance(transformation, MultiColumnOutputTransformation):
                clashes = data_columns.intersection(transformation.columns)
            else:
                clashes = data_columns.intersection([transformation.name])
            if clashes:
                application_errors.extend(
                    [
                        f"Output column {col_name}, "
                        f"from transformation {transformation.name}, "
                        f"clashes with an existing data column name."
                        for col_name in sorted(clashes)
                    ]
                )
                continue

            # Get operation as a function
            operation = self._operators[type(transformation)]

            # Add column(s) to `cols_to_drop` if it shouldn't end up in dataframe
            if not transformation.output:
                if isinstance(transformation, MultiColumnOutputTransformation):
                    cols_to_drop.extend(transformation.columns)
                else:  # normal transformation type
                    cols_to_drop.append(transformation.name)

            # Attempt to perform transformation. If there is a type mismatch
            # pandas will throw a type error which we catch and skip the
            # transformation. Koalas does not throw TypeErrors
            try:
                data = operation(data, transformation)
            except TypeError as e:
                application_errors.append(
                    f"Unable to apply transformation, skipping: "
                    f"{transformation.name}: {e}"
                )

        # Check all transformations were applied OK
        if application_errors:
            raise TransformationApplicationError(application_errors)

        # Drop any columns that should not be part of the final dataframe
        if cols_to_drop:
            return data.drop(cols_to_drop, axis=1)

        return data

    def batch_transform(
        self, data: np.ndarray, step: Literal["train", "validation"]
    ) -> np.ndarray:
        """Performs batch transformations.

        Transformations are applied in order; any whose `step` attribute does
        not equal `step` is skipped.

        Args:
            data: The data to be transformed at batch time as a numpy array.
            step: The step the batch belongs to ("train" or "validation"). Only
                transformations configured for this step are applied.

        Raises:
            InvalidBatchTransformationError: If one of the specified transformations
                does not inherit from `BatchTimeOperation`.

        Returns:
            np.ndarray: The transformed data as a numpy array.
        """
        for transformation in self.transformations:
            # Check transformation is a batch time operation
            if not isinstance(transformation, BatchTimeOperation):
                raise InvalidBatchTransformationError(
                    f"{transformation._registry_name} not a batch time operation"
                )

            # Skip the transformation if it does not apply to the step
            if transformation.step != step:
                logger.debug(f"Skipping transformation: {transformation.name}")
                continue

            logger.debug(f"Applying transformation: {transformation.name}")
            # Get operation as a function
            operation = self._operators[type(transformation)]
            data = operation(data, transformation)

        return data

    @staticmethod
    def _apply_arg_conversion(
        data: _DataFrameType, *args: Union[Any, str, Transformation]
    ) -> Union[_SeriesType, List[_SeriesType]]:
        """Applies argument conversion to each of the supplied *args.

        Check, in order, if it is a:
            - Transformation instance (use transformation output column from data)
            - A column reference (use original column from data)
            - Other (use arg as is)

        Args:
            data: Data in a dataframe.
            *args: The args to convert.

        Returns:
            A single converted arg if only one was supplied, or a list of converted
            args in the same order they were supplied.
        """
        converted = []
        for arg in args:
            # See if it is a transformation
            if isinstance(arg, Transformation):
                converted.append(data[arg.name])
                continue

            # See if it is a column reference
            try:
                col_ref = _extract_col_ref(arg)
                converted.append(data[col_ref])
                continue
            except NotColumnReferenceError:
                pass

            # Otherwise, just use as is
            converted.append(arg)

        # If only one arg supplied, return single arg
        if len(converted) == 1:
            return converted[0]
        # Otherwise return list
        return converted

    @staticmethod
    def _get_list_of_col_refs(
        data: _DataFrameType, cols: Union[str, List[str]], schema_cols: List[str]
    ) -> Tuple[List[str], bool]:
        """Gets the list of actual column references from a list of potentials.

        Generates a list of col names from a potential column list argument
        which could be:
            - a list of column references
            - "all": in which case use schema and data to generate list
            - a single column reference: wrap in a list

        Args:
            data:
                The data containing the columns.
            cols:
                The column reference(s) or "all".
            schema_cols:
                The list of schema columns to use to generate in case of "all".

        Returns:
            A tuple of the list of column references and a boolean indicating whether
            these have been "pre-extracted" (i.e. don't need to be compared to
            COLUMN_REFERENCE).
        """
        pre_extracted = False
        out_cols = cols
        if out_cols == "all":
            # Extract all target columns from schema, ignoring any schema
            # columns that are not present in this dataframe.
            pre_extracted = True
            out_cols = [col for col in schema_cols if col in data.columns]
        elif isinstance(out_cols, str):
            # Wrap single column in iterable
            out_cols = [out_cols]
        return out_cols, pre_extracted

    @staticmethod
    def _do_addition(data: _DataFrameType, t: AdditionTransformation) -> _DataFrameType:
        """Performs addition transformation on `data` and returns it."""
        arg1, arg2 = TransformationProcessor._apply_arg_conversion(data, t.arg1, t.arg2)
        data[t.name] = arg1 + arg2
        return data

    def _do_clean_data(
        self, data: _DataFrameType, t: CleanDataTransformation
    ) -> _DataFrameType:
        """Cleans categorical and continuous columns in the data.

        Categorical columns have NAs replaced with the string "nan". Continuous
        columns have +/-infinity replaced with NaN and then NAs replaced with 0.0.
        Selected columns that are in neither schema list are logged and skipped.
        """
        # Get columns
        cols, pre_extracted = self._get_list_of_col_refs(
            data, t.cols, self._schema_cat_cols + self._schema_cont_cols
        )
        if not pre_extracted:
            cols = [_extract_col_ref(col) for col in cols]

        # Clean columns
        for col_ref in cols:
            # categorical columns
            if col_ref in self._schema_cat_cols:
                data[col_ref] = data[col_ref].fillna("nan")
            # continuous columns
            elif col_ref in self._schema_cont_cols:
                data[col_ref] = data[col_ref].replace([np.inf, -np.inf], np.nan)
                data[col_ref] = data[col_ref].fillna(value=0.0)
            else:
                logger.warning(f"{col_ref} not found in Schema. Skipping cleaning.")

        return data

    @staticmethod
    def _do_comparison(
        data: _DataFrameType, t: ComparisonTransformation
    ) -> _DataFrameType:
        """Performs comparison between arg1 and arg2 of comparison transformation.

        The output column holds -1 where arg1 < arg2, 0 where they are equal,
        1 where arg1 > arg2, and NaN where none of these hold (e.g. when either
        side is NaN).
        """
        arg1, arg2 = TransformationProcessor._apply_arg_conversion(data, t.arg1, t.arg2)

        arg1 = arg1.to_numpy()
        if hasattr(arg2, "to_numpy"):
            # arg2 is series-like
            arg2 = arg2.to_numpy()
        # Otherwise arg2 is a constant: keep it as a scalar and let numpy
        # broadcast it in the comparisons below. Materialising it with
        # `np.full_like(arg1, arg2)` would cast it to arg1's dtype (e.g. 2.5
        # becomes 2 for an integer column), producing wrong comparison results.
        conditions = [arg1 < arg2, arg1 == arg2, arg1 > arg2]
        choices = [-1, 0, 1]

        data[t.name] = np.select(conditions, choices, default=np.nan).tolist()

        return data

    @staticmethod
    def _do_division(data: _DataFrameType, t: DivisionTransformation) -> _DataFrameType:
        """Performs division transformation on `data` and returns it."""
        arg1, arg2 = TransformationProcessor._apply_arg_conversion(data, t.arg1, t.arg2)
        data[t.name] = arg1 / arg2
        return data

    @staticmethod
    def _do_inclusion(
        data: _DataFrameType, t: InclusionTransformation
    ) -> _DataFrameType:
        """Performs inclusion transformation on `data` and returns it.

        The output column is the result of `str.contains(t.in_str)` on the
        argument column. Note that pandas' `str.contains` interprets its pattern
        as a regular expression by default.
        """
        # Only arg should be a column name in this transformation
        arg = TransformationProcessor._apply_arg_conversion(data, t.arg)
        arg = cast(_SeriesType, arg)
        data[t.name] = arg.str.contains(t.in_str)
        return data

    @staticmethod
    def _do_multiplication(
        data: _DataFrameType, t: MultiplicationTransformation
    ) -> _DataFrameType:
        """Performs multiplication transformation on `data` and returns it."""
        arg1, arg2 = TransformationProcessor._apply_arg_conversion(data, t.arg1, t.arg2)
        data[t.name] = arg1 * arg2
        return data

    def _do_normalize_data(
        self, data: _DataFrameType, t: NormalizeDataTransformation
    ) -> _DataFrameType:
        """Normalizes numeric columns using their mean and stddev.

        Results in mean of 0 and stddev of 1. A small epsilon (1e-7) is added to
        the stddev so that constant columns do not cause a division by zero.

        Afterwards, the dtype recorded in the schema for each continuous feature
        present in `data` is refreshed from the data, because normalization can
        change it (e.g. integer to float).

        Raises:
            TypeError: If a selected column is not numeric.
        """
        cols, pre_extracted = self._get_list_of_col_refs(
            data, t.cols, self._schema_cont_cols
        )
        if not pre_extracted:
            cols = [_extract_col_ref(col) for col in cols]

        for col_ref in cols:
            if not is_numeric_dtype(data[col_ref]):
                raise TypeError(
                    f'Cannot normalize column "{col_ref}" as it is not numeric.'
                )
            mean = data[col_ref].mean()
            stddev = data[col_ref].std()
            data[col_ref] = (data[col_ref] - mean) / (1e-7 + stddev)

        # `self._schema_cont_cols` is only non-empty when a schema was supplied.
        if self.schema is not None:
            for col in self._schema_cont_cols:
                # Schema features may be absent from this dataframe (the "all"
                # selection above tolerates this too); indexing a missing column
                # would raise a KeyError.
                if col in data.columns:
                    self.schema.features["continuous"][col].dtype = data[col].dtype

        return data

    @staticmethod
    def _do_one_hot_encoding(
        data: _DataFrameType, t: OneHotEncodingTransformation
    ) -> _DataFrameType:
        """Performs one hot encoding transformation on `data` and returns it.

        An int8, zero-filled frame with the (sorted) `t.columns` is built and
        concatenated column-wise onto `data`. Each non-null value in the argument
        column is marked in the column mapped to it by `t.values`, or in
        `t.unknown_col` if it has no mapping.
        """
        # Get unencoded data and columns
        arg: _SeriesType = TransformationProcessor._apply_arg_conversion(data, t.arg)
        ohe_cols: List[str] = sorted(t.columns)
        n_rows: int = len(arg)
        n_cols: int = len(ohe_cols)
        zeros = np.zeros(shape=(n_rows, n_cols), dtype=np.int8)

        # Create an appropriately sized dataframe, filled with zeros
        df_lib_type = _get_df_library_type(data)
        df_lib = _get_df_library(data)
        ohe_df: _DataFrameType
        if df_lib_type == _DataFrameLib.KOALAS:
            ohe_df = df_lib.DataFrame(data=zeros, columns=ohe_cols)
        else:
            # Share `arg`'s (i.e. `data`'s) index so the label-based `.loc`
            # assignments below and the final column-wise concat line up
            # row-for-row. With a default RangeIndex, the labels taken from
            # `arg.index` would not correspond to the right rows whenever
            # `data` has a non-default index (e.g. after filtering).
            ohe_df = df_lib.DataFrame(data=zeros, columns=ohe_cols, index=arg.index)

        # First, set all values in the unknown column to 1 which correspond to
        # non-null values in the arg Series. These will be marked into the correct
        # column as we find matches, and if no match is found, they should be in
        # this column anyway.
        if df_lib_type == _DataFrameLib.KOALAS:
            # koalas indexes aren't subscriptable so must use a BooleanIndex instead
            not_null_idxs = arg[arg.notnull()].index
            # koalas indexes don't support __iter__ so will break the .loc
            # setting below; instead we must convert it to a numpy array (which
            # still won't work as it doesn't match a protocol pyspark expects) and
            # then to a list. ks.Index.to_list() isn't implemented either.
            #
            # This approach is relatively slow compared to the native pandas
            # implementation and should probably be replaced with something
            # more performant.
            # TODO: [BIT-1042] Increase performance, avoid koalas multi-conversion.
            not_null_idxs = not_null_idxs.to_numpy().tolist()
        else:
            not_null_idxs = arg.index[arg.notnull()]
        ohe_df.loc[not_null_idxs, t.unknown_col] = 1

        # Iterate through the (value, target_col) pairs and set the correct column
        # to 1 for all locations the value is found. Sets the unknown column in
        # the same locations to 0.
        for val, target_col in t.values.items():
            if df_lib_type == _DataFrameLib.KOALAS:
                # Same issues as above
                match_idxs = arg[arg == val].index
                match_idxs = match_idxs.to_numpy().tolist()
            else:
                match_idxs = arg.index[arg == val]
            ohe_df.loc[match_idxs, target_col] = 1
            ohe_df.loc[match_idxs, t.unknown_col] = 0

        if df_lib_type == _DataFrameLib.KOALAS:
            with df_lib.option_context("compute.ops_on_diff_frames", True):  # type: ignore[attr-defined] # Reason: koalas has this function # noqa: B950
                return df_lib.concat(
                    [data, ohe_df], axis=1, sort=True
                )  # sort=True is needed for koalas
        else:
            return df_lib.concat([data, ohe_df], axis=1)

    @staticmethod
    def _do_subtraction(
        data: _DataFrameType, t: SubtractionTransformation
    ) -> _DataFrameType:
        """Performs subtraction transformation on `data` and returns it."""
        arg1, arg2 = TransformationProcessor._apply_arg_conversion(data, t.arg1, t.arg2)
        data[t.name] = arg1 - arg2
        return data

    @staticmethod
    def _do_image_transformation(
        data: np.ndarray, t: ImageTransformation
    ) -> np.ndarray:
        """Performs image transformation on `data` and returns it.

        The callable from `t.get_callable()` is invoked as `tfm(image=data)` and
        the `"image"` entry of its result is returned.
        """
        logger.debug(f"Applying image transformations: {t.transformations}")
        tfm = t.get_callable()
        return tfm(image=data)["image"]
