from __future__ import annotations

import pytest

from daft.dataframe import DataFrame
from daft.expressions import col
from daft.internal.rule_runner import Once, RuleBatch, RuleRunner
from daft.logical.logical_plan import Filter, Join, LogicalPlan
from daft.logical.optimizer import PushDownPredicates
from daft.logical.schema import ExpressionList
from tests.optimizer.conftest import assert_plan_eq


@pytest.fixture(scope="function")
def optimizer() -> RuleRunner[LogicalPlan]:
    """Provide a rule runner containing a single ``pred_pushdown`` batch.

    The batch holds only ``PushDownPredicates`` and is configured with the
    ``Once`` strategy.
    """
    pushdown_batch = RuleBatch("pred_pushdown", Once, [PushDownPredicates()])
    return RuleRunner([pushdown_batch])


def test_no_pushdown_on_modified_column(optimizer) -> None:
    """Check that a filter referencing a derived column leaves the plan unchanged.

    The predicate compares ``ints`` against ``modified`` (a column created by
    ``with_column``), so the optimized plan is expected to equal the input plan.
    """
    df = DataFrame.from_pydict({"ints": list(range(3)), "ints_dup": list(range(3))})
    df = df.with_column(
        "modified",
        col("ints_dup") + 1,
    ).where(col("ints") == col("modified").alias("ints_dup"))

    # Optimizer cannot push down the filter because it uses a column that was projected
    assert_plan_eq(optimizer(df.plan()), df.plan())


def test_filter_pushdown_select(valid_data: list[dict[str, float]], optimizer) -> None:
    """A filter placed after a plain column selection should end up beneath that selection."""
    source = DataFrame.from_pylist(valid_data)
    kept = ["sepal_length", "sepal_width"]

    # Input plan: select first, then filter.
    projected = source.select(*kept)
    before = projected.where(col("sepal_length") > 4.8)

    # Expected plan: filter first, then select.
    prefiltered = source.where(col("sepal_length") > 4.8)
    expected = prefiltered.select(*kept)

    assert before.column_names == kept
    assert_plan_eq(optimizer(before.plan()), expected.plan())


def test_filter_pushdown_select_alias(valid_data: list[dict[str, float]], optimizer) -> None:
    """Same as the plain select case, but the predicate's column carries an alias."""

    def aliased_predicate():
        # Built fresh on every call so the two plans never share an expression object.
        return col("sepal_length").alias("foo") > 4.8

    source = DataFrame.from_pylist(valid_data)
    kept = ["sepal_length", "sepal_width"]

    before = source.select(*kept).where(aliased_predicate())
    expected = source.where(aliased_predicate()).select(*kept)

    assert before.column_names == kept
    assert_plan_eq(optimizer(before.plan()), expected.plan())


def test_filter_pushdown_with_column(valid_data: list[dict[str, float]], optimizer) -> None:
    """A filter on an original column should move beneath a ``with_column`` that adds ``foo``."""
    base = DataFrame.from_pylist(valid_data)

    # Input plan: add the column, then filter on a column that already existed.
    extended = base.with_column("foo", col("sepal_length") + 1)
    before = extended.where(col("sepal_length") > 4.8)

    # Expected plan: filter first, then add the column.
    prefiltered = base.where(col("sepal_length") > 4.8)
    expected = prefiltered.with_column("foo", col("sepal_length") + 1)

    assert before.column_names == [*base.column_names, "foo"]
    assert_plan_eq(optimizer(before.plan()), expected.plan())


def test_filter_pushdown_with_column_partial_predicate_pushdown(valid_data: list[dict[str, float]], optimizer) -> None:
    """Only the filter on the original column moves down; the filter on ``foo`` stays on top."""
    base = DataFrame.from_pylist(valid_data)

    # Input plan: with_column -> filter on source column -> filter on the new column.
    extended = base.with_column("foo", col("sepal_length") + 1)
    source_filtered = extended.where(col("sepal_length") > 4.8)
    before = source_filtered.where(col("foo") > 4.8)

    # Expected plan: the source-column filter sits below with_column, the `foo` filter above it.
    prefiltered = base.where(col("sepal_length") > 4.8)
    re_extended = prefiltered.with_column("foo", col("sepal_length") + 1)
    expected = re_extended.where(col("foo") > 4.8)

    assert before.column_names == [*base.column_names, "foo"]
    assert_plan_eq(optimizer(before.plan()), expected.plan())


def test_filter_pushdown_with_column_alias(valid_data: list[dict[str, float]], optimizer) -> None:
    """Pushdown past ``with_column`` when both the new column and the predicate use an aliased column."""

    def aliased_length():
        # Built fresh on every call so no expression object is shared between plans.
        return col("sepal_length").alias("foo")

    base = DataFrame.from_pylist(valid_data)

    extended = base.with_column("foo", aliased_length() + 1)
    before = extended.where(aliased_length() > 4.8)

    prefiltered = base.where(aliased_length() > 4.8)
    expected = prefiltered.with_column("foo", aliased_length() + 1)

    assert before.column_names == [*base.column_names, "foo"]
    assert_plan_eq(optimizer(before.plan()), expected.plan())


def test_filter_merge(valid_data: list[dict[str, float]], optimizer) -> None:
    """Two stacked filters should be combined into one filter carrying both predicates."""
    base = DataFrame.from_pylist(valid_data)
    length_filtered = base.where((col("sepal_length") > 4.8).alias("foo"))
    stacked = length_filtered.where((col("sepal_width") > 2.4).alias("foo"))

    # HACK: We manually modify the plan here because currently CombineFilters works by combining predicates as an ExpressionList rather than taking the & of the two predicates
    placeholder = col("sepal_width") > 100
    width_predicate = (col("sepal_width") > 2.4).alias("foo")
    length_predicate = (col("sepal_length") > 4.8).alias("foo").alias("copy.foo")
    merged_predicates = ExpressionList([width_predicate, length_predicate])

    # Build a single-filter plan with a throwaway predicate, then swap in the merged list.
    expected = base.where(placeholder)
    expected._plan._predicate = merged_predicates

    assert_plan_eq(optimizer(stacked.plan()), expected.plan())


def test_filter_pushdown_sort(valid_data: list[dict[str, float]], optimizer) -> None:
    """A filter above sort + select should end up beneath both of them."""
    source = DataFrame.from_pylist(valid_data)
    kept = ["sepal_length", "sepal_width"]

    # Input plan: sort -> select -> filter.
    ordered = source.sort("sepal_length")
    before = ordered.select(*kept).where(col("sepal_length") > 4.8)

    # Expected plan: filter -> sort -> select.
    prefiltered = source.where(col("sepal_length") > 4.8)
    expected = prefiltered.sort("sepal_length").select(*kept)

    assert before.column_names == kept
    assert_plan_eq(optimizer(before.plan()), expected.plan())


def test_filter_pushdown_repartition(valid_data: list[dict[str, float]], optimizer) -> None:
    """A filter above repartition + select should end up beneath both of them."""
    source = DataFrame.from_pylist(valid_data)
    kept = ["sepal_length", "sepal_width"]

    # Input plan: repartition -> select -> filter.
    split = source.repartition(2)
    before = split.select(*kept).where(col("sepal_length") > 4.8)

    # Expected plan: filter -> repartition -> select.
    prefiltered = source.where(col("sepal_length") > 4.8)
    expected = prefiltered.repartition(2).select(*kept)

    assert before.column_names == kept
    assert_plan_eq(optimizer(before.plan()), expected.plan())


def test_filter_join_pushdown(valid_data: list[dict[str, float]], optimizer) -> None:
    """Filters that each reference a single join side should move onto that side's input.

    The predicate on ``right.sepal_width`` is expected to appear on the right
    input as a predicate on ``sepal_width``.
    """
    left = DataFrame.from_pylist(valid_data)
    right = DataFrame.from_pylist(valid_data)

    # Input plan: join, then one filter per side stacked on top.
    after_join = left.join(right, on="variety")
    left_only = after_join.where(col("sepal_length") > 4.8)
    both_filtered = left_only.where(col("right.sepal_width") > 4.8)

    actual_plan = optimizer(both_filtered.plan())

    # Expected plan: each side filtered before the join, leaving the join as the root.
    left_filtered = left.where(col("sepal_length") > 4.8)
    right_filtered = right.where(col("sepal_width") > 4.8)
    expected = left_filtered.join(right_filtered, on="variety")

    assert isinstance(actual_plan, Join)
    assert isinstance(expected.plan(), Join)
    assert_plan_eq(actual_plan, expected.plan())


def test_filter_join_pushdown_aliases(valid_data: list[dict[str, float]], optimizer) -> None:
    """Join pushdown when both per-side predicates alias their column to the same name ``foo``."""
    left = DataFrame.from_pylist(valid_data)
    right = DataFrame.from_pylist(valid_data)

    # Input plan: join, then two filters whose columns are both aliased to `foo`.
    after_join = left.join(right, on="variety")
    left_only = after_join.where(col("sepal_length").alias("foo") > 4.8)
    both_filtered = left_only.where(col("right.sepal_width").alias("foo") > 4.8)

    actual_plan = optimizer(both_filtered.plan())

    # Filter merging creates a `copy.*` column when merging predicates with the same name
    left_filtered = left.where((col("sepal_length").alias("foo") > 4.8).alias("copy.foo"))
    right_filtered = right.where(col("sepal_width").alias("foo") > 4.8)
    expected = left_filtered.join(right_filtered, on="variety")

    assert isinstance(actual_plan, Join)
    assert isinstance(expected.plan(), Join)
    assert_plan_eq(actual_plan, expected.plan())


def test_filter_join_pushdown_nonvalid(valid_data: list[dict[str, float]], optimizer) -> None:
    """A predicate comparing a right-side column with a left-side column must stay above the join."""
    left = DataFrame.from_pylist(valid_data)
    right = DataFrame.from_pylist(valid_data)

    after_join = left.join(right, on="variety")
    cross_side = after_join.where(col("right.sepal_width") > col("sepal_length"))

    actual_plan = optimizer(cross_side.plan())

    # The filter is still the root and the plan equals the unoptimized one.
    assert isinstance(actual_plan, Filter)
    assert_plan_eq(actual_plan, cross_side.plan())


def test_filter_join_pushdown_nonvalid_aliases(valid_data: list[dict[str, float]], optimizer) -> None:
    """Aliasing the right-side column to a left-side name must not make a two-sided predicate pushable."""
    left = DataFrame.from_pylist(valid_data)
    right = DataFrame.from_pylist(valid_data)

    after_join = left.join(right, on="variety")
    cross_side = after_join.where(col("right.sepal_width").alias("sepal_width") > col("sepal_length"))

    actual_plan = optimizer(cross_side.plan())

    # The filter is still the root and the plan equals the unoptimized one.
    assert isinstance(actual_plan, Filter)
    assert_plan_eq(actual_plan, cross_side.plan())


def test_filter_join_partial_predicate_pushdown(valid_data: list[dict[str, float]], optimizer) -> None:
    """Single-side filters move below the join while a predicate spanning both sides stays on top."""
    left = DataFrame.from_pylist(valid_data)
    right = DataFrame.from_pylist(valid_data)

    # Input plan: join, then a left-only filter, a right-only filter and a two-sided OR filter.
    after_join = left.join(right, on="variety")
    left_only = after_join.where(col("sepal_length") > 4.8)
    right_only = left_only.where(col("right.sepal_width") > 4.8)
    all_filtered = right_only.where(((col("sepal_length") > 4.8) | (col("right.sepal_length") > 4.8)).alias("foo"))

    actual_plan = optimizer(all_filtered.plan())

    # Expected plan: both single-side filters applied before the join, the OR filter left as root.
    left_filtered = left.where(col("sepal_length") > 4.8)
    right_filtered = right.where(col("sepal_width") > 4.8)
    rejoined = left_filtered.join(right_filtered, on="variety")
    expected = rejoined.where(((col("sepal_length") > 4.8) | (col("right.sepal_length") > 4.8)).alias("foo"))

    assert isinstance(actual_plan, Filter)
    assert isinstance(expected.plan(), Filter)
    assert_plan_eq(actual_plan, expected.plan())
