from __future__ import annotations

from .comparison import Comparison
from .comparison_library_utils import (
    comparison_at_thresholds_error_logger,
    datediff_error_logger,
)
from .misc import ensure_is_iterable


class ExactMatchBase(Comparison):
    def __init__(
        self,
        col_name,
        term_frequency_adjustments=False,
        m_probability_exact_match=None,
        m_probability_else=None,
        include_colname_in_charts_label=False,
    ):
        """Build a comparison on `col_name` with an exact match level and an
        'anything else' level, preceded by a null level.

        The level builders used here (`_null_level`, `_exact_match_level` and
        `_else_level`) are not defined on this class; they must be available
        on the concrete class that is instantiated.

        Args:
            col_name (str): The name of the column to compare.
            term_frequency_adjustments (bool, optional): Forwarded to the exact
                match level to request term frequency adjustments.
                Defaults to False.
            m_probability_exact_match (optional): If provided, overrides the
                default m probability for the exact match level.
                Defaults to None.
            m_probability_else (optional): If provided, overrides the default
                m probability for the 'anything else' level. Defaults to None.
            include_colname_in_charts_label (bool, optional): Forwarded to the
                exact match level; if True, the column name is appended to the
                label used in charts. Defaults to False.

        Returns:
            Comparison: A comparison that can be included in the Splink settings
                dictionary.
        """

        # Levels are built in the order in which they appear in the comparison.
        null_level = self._null_level(col_name)
        exact_level = self._exact_match_level(
            col_name,
            term_frequency_adjustments=term_frequency_adjustments,
            m_probability=m_probability_exact_match,
            include_colname_in_charts_label=include_colname_in_charts_label,
        )
        else_level = self._else_level(m_probability=m_probability_else)

        super().__init__(
            {
                "comparison_description": "Exact match vs. anything else",
                "comparison_levels": [null_level, exact_level, else_level],
            }
        )


class DistanceFunctionAtThresholdsComparisonBase(Comparison):
    def __init__(
        self,
        col_name: str,
        distance_function_name: str,
        distance_threshold_or_thresholds: int | list,
        higher_is_more_similar: bool = True,
        include_exact_match_level=True,
        term_frequency_adjustments=False,
        m_probability_exact_match=None,
        m_probability_or_probabilities_lev: float | list = None,
        m_probability_else=None,
    ):
        """Build a comparison on `col_name` whose middle similarity levels are
        defined by a named distance function evaluated at one or more thresholds.

        The user-provided distance function must exist in the SQL backend.

        With `distance_function_name = "jaccard"` and
        `distance_threshold_or_thresholds = [0.9, 0.7]` the levels are, in order:
        - Null
        - Exact match (only if `include_exact_match_level` is True)
        - jaccard at threshold 0.9
        - jaccard at threshold 0.7
        - Anything else

        Args:
            col_name (str): The name of the column to compare.
            distance_function_name (str): The name of the distance function.
            distance_threshold_or_thresholds (Union[int, list]): The
                threshold(s) to use for the middle similarity level(s). A single
                value is treated as a one-element collection.
            higher_is_more_similar (bool): If True, a higher value of the distance
                function indicates a higher similarity (e.g. jaro_winkler).
                If False, a higher value indicates a lower similarity
                (e.g. levenshtein). Only forwarded to the level builder when
                `_is_distance_subclass` is False.
            include_exact_match_level (bool, optional): If True, include an exact
                match level. Defaults to True.
            term_frequency_adjustments (bool, optional): If True, apply term
                frequency adjustments to the exact match level. Defaults to False.
            m_probability_exact_match (optional): If provided, overrides the
                default m probability for the exact match level. Defaults to None.
            m_probability_or_probabilities_lev (Union[float, list], optional):
                If provided, overrides the default m probabilities for the
                thresholds specified. Thresholds and probabilities are paired
                positionally with `zip`, so pairing stops at the shorter of the
                two. Defaults to None.
            m_probability_else (optional): If provided, overrides the default
                m probability for the 'anything else' level. Defaults to None.

        Returns:
            Comparison:
        """

        distance_thresholds = ensure_is_iterable(distance_threshold_or_thresholds)

        # No overrides supplied: one `None` placeholder per threshold.
        if m_probability_or_probabilities_lev is None:
            m_probability_or_probabilities_lev = [None] * len(distance_thresholds)
        m_probabilities = ensure_is_iterable(m_probability_or_probabilities_lev)

        # Validate user inputs
        comparison_at_thresholds_error_logger("distance_function", distance_thresholds)

        levels = [self._null_level(col_name)]
        if include_exact_match_level:
            levels.append(
                self._exact_match_level(
                    col_name,
                    term_frequency_adjustments=term_frequency_adjustments,
                    m_probability=m_probability_exact_match,
                )
            )

        for threshold, m_prob in zip(distance_thresholds, m_probabilities):
            # Arguments common to every kind of distance level.
            level_kwargs = {
                "col_name": col_name,
                "distance_threshold": threshold,
                "m_probability": m_prob,
            }
            # The function name and direction are only passed through for a
            # user-supplied function; predefined subclasses report
            # `_is_distance_subclass` as True and do not receive them.
            if not self._is_distance_subclass:
                level_kwargs.update(
                    distance_function_name=distance_function_name,
                    higher_is_more_similar=higher_is_more_similar,
                )
            levels.append(self._distance_level(**level_kwargs))

        levels.append(self._else_level(m_probability=m_probability_else))

        # Human-readable summary, e.g.
        # "Exact match vs. jaccard at thresholds 0.9, 0.7 vs. anything else"
        exact_prefix = "Exact match vs. " if include_exact_match_level else ""
        joined_thresholds = ", ".join(str(t) for t in distance_thresholds)
        suffix = "s" if len(distance_thresholds) != 1 else ""
        description = (
            f"{exact_prefix}{distance_function_name} at threshold{suffix} "
            f"{joined_thresholds} vs. anything else"
        )

        super().__init__(
            {
                "comparison_description": description,
                "comparison_levels": levels,
            }
        )

    @property
    def _is_distance_subclass(self):
        """False here: this class takes the distance function from the caller."""
        return False


class LevenshteinAtThresholdsComparisonBase(DistanceFunctionAtThresholdsComparisonBase):
    def __init__(
        self,
        col_name: str,
        distance_threshold_or_thresholds: int | list = [1, 2],
        include_exact_match_level=True,
        term_frequency_adjustments=False,
        m_probability_exact_match=None,
        m_probability_or_probabilities_lev: float | list = None,
        m_probability_else=None,
    ):
        """Build a comparison on `col_name` that uses the levenshtein distance for
        the middle similarity levels.

        With the default arguments (`distance_threshold_or_thresholds = [1, 2]`)
        the levels described are:
        - Exact match
        - levenshtein distance <= 1
        - levenshtein distance <= 2
        - Anything else

        Args:
            col_name (str): The name of the column to compare.
            distance_threshold_or_thresholds (Union[int, list], optional): The
                threshold(s) to use for the middle similarity level(s).
                Defaults to [1, 2].
            include_exact_match_level (bool, optional): If True, include an exact
                match level. Defaults to True.
            term_frequency_adjustments (bool, optional): If True, apply term
                frequency adjustments to the exact match level. Defaults to False.
            m_probability_exact_match (optional): If provided, overrides the
                default m probability for the exact match level. Defaults to None.
            m_probability_or_probabilities_lev (Union[float, list], optional):
                If provided, overrides the default m probabilities for the
                thresholds specified. Defaults to None.
            m_probability_else (optional): If provided, overrides the default
                m probability for the 'anything else' level. Defaults to None.

        Returns:
            Comparison:
        """

        # `_levenshtein_name` is not defined on this class; it must be provided
        # by the concrete class. A larger levenshtein value means less similar,
        # hence `higher_is_more_similar=False`.
        super().__init__(
            col_name=col_name,
            distance_function_name=self._levenshtein_name,
            distance_threshold_or_thresholds=distance_threshold_or_thresholds,
            higher_is_more_similar=False,
            include_exact_match_level=include_exact_match_level,
            term_frequency_adjustments=term_frequency_adjustments,
            m_probability_exact_match=m_probability_exact_match,
            m_probability_or_probabilities_lev=m_probability_or_probabilities_lev,
            m_probability_else=m_probability_else,
        )

    @property
    def _is_distance_subclass(self):
        """True: the distance function is fixed by this class, not the caller."""
        return True


class JaccardAtThresholdsComparisonBase(DistanceFunctionAtThresholdsComparisonBase):
    def __init__(
        self,
        col_name: str,
        distance_threshold_or_thresholds: int | list = [0.9, 0.7],
        include_exact_match_level=True,
        term_frequency_adjustments=False,
        m_probability_exact_match=None,
        m_probability_or_probabilities_lev: float | list = None,
        m_probability_else=None,
    ):
        """Build a comparison on `col_name` that uses the jaccard measure for the
        middle similarity levels.

        With the default arguments
        (`distance_threshold_or_thresholds = [0.9, 0.7]`) the levels are:
        - Exact match
        - Jaccard at threshold 0.9
        - Jaccard at threshold 0.7
        - Anything else

        Args:
            col_name (str): The name of the column to compare.
            distance_threshold_or_thresholds (Union[int, list], optional): The
                threshold(s) to use for the middle similarity level(s).
                Defaults to [0.9, 0.7].
            include_exact_match_level (bool, optional): If True, include an exact
                match level. Defaults to True.
            term_frequency_adjustments (bool, optional): If True, apply term
                frequency adjustments to the exact match level. Defaults to False.
            m_probability_exact_match (optional): If provided, overrides the
                default m probability for the exact match level. Defaults to None.
            m_probability_or_probabilities_lev (Union[float, list], optional):
                If provided, overrides the default m probabilities for the
                thresholds specified. Defaults to None.
            m_probability_else (optional): If provided, overrides the default
                m probability for the 'anything else' level. Defaults to None.

        Returns:
            Comparison:
        """

        # `_jaccard_name` is not defined on this class; it must be provided by
        # the concrete class. The parent is told that higher values are more
        # similar.
        super().__init__(
            col_name=col_name,
            distance_function_name=self._jaccard_name,
            distance_threshold_or_thresholds=distance_threshold_or_thresholds,
            higher_is_more_similar=True,
            include_exact_match_level=include_exact_match_level,
            term_frequency_adjustments=term_frequency_adjustments,
            m_probability_exact_match=m_probability_exact_match,
            m_probability_or_probabilities_lev=m_probability_or_probabilities_lev,
            m_probability_else=m_probability_else,
        )

    @property
    def _is_distance_subclass(self):
        """True: the distance function is fixed by this class, not the caller."""
        return True


class JaroWinklerAtThresholdsComparisonBase(DistanceFunctionAtThresholdsComparisonBase):
    def __init__(
        self,
        col_name: str,
        distance_threshold_or_thresholds: int | list = [0.9, 0.7],
        include_exact_match_level=True,
        term_frequency_adjustments=False,
        m_probability_exact_match=None,
        m_probability_or_probabilities_lev: float | list = None,
        m_probability_else=None,
    ):
        """Build a comparison on `col_name` that uses the jaro_winkler measure for
        the middle similarity levels.

        With the default arguments
        (`distance_threshold_or_thresholds = [0.9, 0.7]`) the levels are:
        - Exact match
        - jaro_winkler at threshold 0.9
        - jaro_winkler at threshold 0.7
        - Anything else

        Args:
            col_name (str): The name of the column to compare.
            distance_threshold_or_thresholds (Union[int, list], optional): The
                threshold(s) to use for the middle similarity level(s).
                Defaults to [0.9, 0.7].
            include_exact_match_level (bool, optional): If True, include an exact
                match level. Defaults to True.
            term_frequency_adjustments (bool, optional): If True, apply term
                frequency adjustments to the exact match level. Defaults to False.
            m_probability_exact_match (optional): If provided, overrides the
                default m probability for the exact match level. Defaults to None.
            m_probability_or_probabilities_lev (Union[float, list], optional):
                If provided, overrides the default m probabilities for the
                thresholds specified. Defaults to None.
            m_probability_else (optional): If provided, overrides the default
                m probability for the 'anything else' level. Defaults to None.

        Returns:
            Comparison:
        """

        # `_jaro_winkler_name` is not defined on this class; it must be provided
        # by the concrete class. The parent is told that higher values are more
        # similar.
        super().__init__(
            col_name=col_name,
            distance_function_name=self._jaro_winkler_name,
            distance_threshold_or_thresholds=distance_threshold_or_thresholds,
            higher_is_more_similar=True,
            include_exact_match_level=include_exact_match_level,
            term_frequency_adjustments=term_frequency_adjustments,
            m_probability_exact_match=m_probability_exact_match,
            m_probability_or_probabilities_lev=m_probability_or_probabilities_lev,
            m_probability_else=m_probability_else,
        )

    @property
    def _is_distance_subclass(self):
        """True: the distance function is fixed by this class, not the caller."""
        return True


class ArrayIntersectAtSizesComparisonBase(Comparison):
    def __init__(
        self,
        col_name: str,
        size_or_sizes: int | list = [1],
        m_probability_or_probabilities_sizes: float | list = None,
        m_probability_else=None,
    ):
        """A comparison of the data in array column `col_name` with various
        intersection sizes to assess similarity levels.

        An example of the output with `size_or_sizes = [3, 1]` would be
        - Null
        - Intersection has at least 3 elements
        - Intersection has at least 1 element (i.e. 1 or 2)
        - Anything else (i.e. empty intersection)

        Args:
            col_name (str): The name of the column to compare
            size_or_sizes (Union[int, list], optional): The size(s) of intersection
                to use for the non-'else' similarity level(s). Should be in
                descending order. Defaults to [1].
            m_probability_or_probabilities_sizes (Union[float, list], optional):
                If provided, overrides the default m probabilities
                for the sizes specified. Sizes and probabilities are paired
                positionally with `zip`. Defaults to None.
            m_probability_else (_type_, optional): If provided, overrides the
                default m probability for the 'anything else' level. Defaults to None.

        Raises:
            ValueError: If `size_or_sizes` is empty, or if any size is not
                strictly positive.

        Returns:
            Comparison:
        """

        sizes = ensure_is_iterable(size_or_sizes)
        if len(sizes) == 0:
            raise ValueError(
                "`size_or_sizes` must have at least one element, so that Comparison "
                "has more than just an 'else' level"
            )
        if any(size <= 0 for size in sizes):
            raise ValueError("All entries of `size_or_sizes` must be positive")

        # No overrides supplied: one `None` placeholder per size.
        if m_probability_or_probabilities_sizes is None:
            m_probability_or_probabilities_sizes = [None] * len(sizes)
        m_probabilities = ensure_is_iterable(m_probability_or_probabilities_sizes)

        comparison_levels = [self._null_level(col_name)]

        for size_intersect, m_prob in zip(sizes, m_probabilities):
            level = self._array_intersect_level(
                col_name, m_probability=m_prob, min_intersection=size_intersect
            )
            comparison_levels.append(level)

        comparison_levels.append(
            self._else_level(m_probability=m_probability_else),
        )

        size_desc = ", ".join([str(s) for s in sizes])
        plural = "" if len(sizes) == 1 else "s"
        comparison_desc = (
            f"Array intersection at minimum size{plural} {size_desc} vs. "
            "anything else"
        )

        comparison_dict = {
            "comparison_description": comparison_desc,
            "comparison_levels": comparison_levels,
        }
        super().__init__(comparison_dict)

    @property
    def _array_intersect_level(self):
        # Placeholder: accessing this on the base class raises. A concrete
        # class is expected to replace it with a real level builder.
        raise NotImplementedError("Intersect level not defined on base class")


class DateDiffAtThresholdsComparisonBase(Comparison):
    def __init__(
        self,
        col_name: str,
        date_thresholds: int | list = [1],
        date_metrics: str | list = ["year"],
        include_exact_match_level=True,
        term_frequency_adjustments=False,
        m_probability_exact_match=None,
        m_probability_or_probabilities_sizes: float | list = None,
        m_probability_else=None,
    ):
        """A comparison of the data in the date column `col_name` with various
        date thresholds and metrics to assess similarity levels.

        An example of the output with settings
        `date_thresholds = [1]` and `date_metrics = ['day']` would be
        - Null
        - Exact match (only if `include_exact_match_level` is True, the default)
        - The two input dates are within 1 day of one another
        - Anything else (i.e. all other dates lie outside this range)

        `date_thresholds` and `date_metrics` should be used in conjunction
        with one another; they are paired positionally.
        For example, `date_thresholds = [10, 12, 15]` with
        `date_metrics = ['day', 'month', 'year']` would result in the following checks:
        - The two dates are within 10 days of one another
        - The two dates are within 12 months of one another
        - And the two dates are within 15 years of one another

        Args:
            col_name (str): The name of the date column to compare.
            date_thresholds (Union[int, list], optional): The size(s) of given date
                thresholds, to assess whether two dates fall within a given time
                interval.
                These values can be any integer value and should be used in tandem with
                `date_metrics`. Defaults to [1].
            date_metrics (Union[str, list], optional): The unit of time you wish your
                `date_thresholds` to be measured against.
                Metrics should be one of `day`, `month` or `year`.
                Defaults to ["year"].
            include_exact_match_level (bool, optional): If True, include an exact match
                level. Defaults to True.
            term_frequency_adjustments (bool, optional): If True, apply term frequency
                adjustments to the exact match level. Defaults to False.
            m_probability_exact_match (_type_, optional): If provided, overrides the
                default m probability for the exact match level. Defaults to None.
            m_probability_or_probabilities_sizes (Union[float, list], optional):
                If provided, overrides the default m probabilities
                for the thresholds specified. Defaults to None.
            m_probability_else (_type_, optional): If provided, overrides the
                default m probability for the 'anything else' level. Defaults to None.

        Returns:
            Comparison: A comparison that can be included in the Splink settings
                dictionary.
        """

        thresholds = ensure_is_iterable(date_thresholds)
        metrics = ensure_is_iterable(date_metrics)

        # Validate user inputs. The normalised `thresholds` collection is
        # checked (not the raw argument) so that a single scalar threshold is
        # validated in the same way as a list.
        comparison_at_thresholds_error_logger("datediff", thresholds)
        datediff_error_logger(thresholds, metrics)

        # No overrides supplied: one `None` placeholder per threshold.
        if m_probability_or_probabilities_sizes is None:
            m_probability_or_probabilities_sizes = [None] * len(thresholds)
        m_probabilities = ensure_is_iterable(m_probability_or_probabilities_sizes)

        comparison_levels = [self._null_level(col_name)]
        if include_exact_match_level:
            level = self._exact_match_level(
                col_name,
                term_frequency_adjustments=term_frequency_adjustments,
                m_probability=m_probability_exact_match,
            )
            comparison_levels.append(level)

        for date_thres, date_metr, m_prob in zip(thresholds, metrics, m_probabilities):
            level = self._datediff_level(
                col_name,
                date_threshold=date_thres,
                date_metric=date_metr,
                m_probability=m_prob,
            )
            comparison_levels.append(level)

        comparison_levels.append(
            self._else_level(m_probability=m_probability_else),
        )

        comparison_desc = ""
        if include_exact_match_level:
            comparison_desc += "Exact match vs. "

        thres_desc = ", ".join(
            [f"{m.title()}(s): {v}" for m, v in zip(metrics, thresholds)]
        )
        plural = "" if len(thresholds) == 1 else "s"
        comparison_desc += (
            f"Dates within the following threshold{plural} {thres_desc} vs. "
        )
        comparison_desc += "anything else"

        comparison_dict = {
            "comparison_description": comparison_desc,
            "comparison_levels": comparison_levels,
        }
        super().__init__(comparison_dict)

    @property
    def _datediff_level(self):
        # Placeholder: accessing this on the base class raises. A concrete
        # class is expected to replace it with a real level builder.
        raise NotImplementedError("Datediff level not defined on base class")


class DistanceInKMAtThresholdsComparisonBase(Comparison):
    def __init__(
        self,
        lat_col: str,
        long_col: str,
        km_thresholds: int | list = [0.1, 1],
        include_exact_match_level=False,
        m_probability_exact_match=None,
        m_probability_or_probabilities_lev: float | list = None,
        m_probability_else=None,
    ):
        """A comparison of the coordinates defined in 'lat_col' and
        'long col' giving the haversine distance between them in km.

        An example of the output with `km_thresholds = [1]` and the other
        arguments left at their defaults would be
        - Null (either latitude or either longitude is NULL)
        - The two coordinates within 1 km of one another
        - Anything else (i.e.  the distance between all coordinate lie outside
        this range)

        Args:
            lat_col (str): The name of the column containing the latitude of the
                coordinates.
            long_col (str): The name of the column containing the longitude of the
                coordinates.
            km_thresholds (Union[int, list], optional): The distance threshold(s)
                in km, to assess whether two coordinates fall within a given
                distance. A single value is treated as a one-element
                collection. Defaults to [0.1, 1].
            include_exact_match_level (bool, optional): If True, include an exact match
                level (both latitude and longitude equal). Defaults to False.
            m_probability_exact_match (_type_, optional): If provided, overrides the
                default m probability for the exact match level. Defaults to None.
            m_probability_or_probabilities_lev (Union[float, list], optional):
                If provided, overrides the default m probabilities
                for the thresholds specified. Defaults to None.
            m_probability_else (_type_, optional): If provided, overrides the
                default m probability for the 'anything else' level. Defaults to None.

        Returns:
            Comparison: A comparison that can be included in the Splink settings
                dictionary.
        """

        thresholds = ensure_is_iterable(km_thresholds)

        # Default to one `None` placeholder per threshold. The same name is
        # assigned and read, so user-supplied probabilities are honoured
        # instead of hitting an unbound local variable.
        if m_probability_or_probabilities_lev is None:
            m_probability_or_probabilities_lev = [None] * len(thresholds)
        m_probabilities = ensure_is_iterable(m_probability_or_probabilities_lev)

        comparison_levels = []

        # This comparison spans two columns, so the null level is written out
        # here: it matches when any of the four coordinate values is NULL.
        null_level = {
            "sql_condition": f"({lat_col}_l IS NULL OR {lat_col}_r IS NULL) \n"
            f"OR ({long_col}_l IS NULL OR {long_col}_r IS NULL)",
            "label_for_charts": "Null",
            "is_null_level": True,
        }
        comparison_levels.append(null_level)

        if include_exact_match_level:
            label_suffix = f" {lat_col}, {long_col}"
            level = {
                "sql_condition": f"({lat_col}_l = {lat_col}_r) \n"
                f"AND ({long_col}_l = {long_col}_r)",
                "label_for_charts": f"Exact match{label_suffix}",
            }

            if m_probability_exact_match:
                level["m_probability"] = m_probability_exact_match

            comparison_levels.append(level)

        # Iterate the normalised `thresholds` so a scalar `km_thresholds`
        # works as well as a list.
        for km_thres, m_prob in zip(thresholds, m_probabilities):
            level = self._distance_in_km_level(
                lat_col,
                long_col,
                km_threshold=km_thres,
                m_probability=m_prob,
            )
            comparison_levels.append(level)

        comparison_levels.append(
            self._else_level(m_probability=m_probability_else),
        )

        comparison_desc = ""
        if include_exact_match_level:
            comparison_desc += "Exact match vs. "

        thres_desc = ", ".join([f"Km threshold(s): {thres}" for thres in thresholds])
        plural = "" if len(thresholds) == 1 else "s"
        comparison_desc += (
            f"Km distance within the following threshold{plural} {thres_desc} vs. "
        )
        comparison_desc += "anything else"

        comparison_dict = {
            "comparison_description": comparison_desc,
            "comparison_levels": comparison_levels,
        }
        super().__init__(comparison_dict)
