# Copyright 2019 IBM Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import pandas as pd
import sklearn.datasets
import sklearn.model_selection

import lale.lib.autoai_libs

# from lale.datasets.uci import fetch_household_power_consumption
from lale.lib.autoai_libs import float32_transform
from lale.lib.lale import Hyperopt
from lale.lib.sklearn import LogisticRegression as LR
from lale.lib.xgboost.xgb_classifier import XGBClassifier


class TestAutoaiLibs(unittest.TestCase):
    """Smoke tests for the lale wrappers around autoai_libs transformers.

    Every test builds one trainable operator and hands it to ``doTest``
    together with a train/test split of the iris dataset.
    """

    # Fixed seed so that every run (and every test) sees the same split;
    # an unseeded split makes failures impossible to reproduce.
    _SPLIT_SEED = 42
    # Order matches the return order of sklearn's train_test_split(X, y).
    _SPLIT_KEYS = ("train_X", "test_X", "train_y", "test_y")

    @classmethod
    def _split_iris(cls, as_frame=False):
        """Return a seeded iris split as keyword arguments for ``doTest``.

        Args:
            as_frame: forwarded to ``sklearn.datasets.load_iris``; when true
                the features are loaded as pandas objects instead of numpy
                arrays.

        Returns:
            A dict with the keys ``train_X``, ``test_X``, ``train_y`` and
            ``test_y``.
        """
        iris_X, iris_y = sklearn.datasets.load_iris(
            return_X_y=True, as_frame=as_frame
        )
        splits = sklearn.model_selection.train_test_split(
            iris_X, iris_y, random_state=cls._SPLIT_SEED
        )
        return dict(zip(cls._SPLIT_KEYS, splits))

    @staticmethod
    def _float32_columns():
        """Return fresh ``col_names``/``col_dtypes`` kwargs for four float32 columns."""
        float32 = np.dtype("float32")
        return {
            "col_names": ["a", "b", "c", "d"],
            "col_dtypes": [float32] * 4,
        }

    @classmethod
    def setUpClass(cls):
        """Split the numpy version of iris once for all tests in the class."""
        cls._iris = cls._split_iris()

    def doTest(self, trainable, train_X, train_y, test_X, test_y):
        """Run the shared checks on one trainable operator.

        Fits the operator and transforms ``test_X``, checks that calling
        ``transform`` on the trainable (rather than the trained) operator
        raises a ``DeprecationWarning``, serializes it with ``to_json``, and
        finally fits and predicts with the operator both inside a plain
        pipeline and inside a one-evaluation ``Hyperopt`` search.

        ``test_y`` is accepted so that a split dict can be passed with ``**``
        but it is not used.
        """
        trained = trainable.fit(train_X, train_y)
        _ = trained.transform(test_X)
        with self.assertWarns(DeprecationWarning):
            trainable.transform(train_X)
        trainable.to_json()
        trainable_pipeline = trainable >> float32_transform() >> LR()
        trained_pipeline = trainable_pipeline.fit(train_X, train_y)
        trained_pipeline.predict(test_X)
        hyperopt = Hyperopt(estimator=trainable_pipeline, max_evals=1, verbose=True)
        trained_hyperopt = hyperopt.fit(train_X, train_y)
        trained_hyperopt.predict(test_X)

    def test_NumpyColumnSelector(self):
        trainable = lale.lib.autoai_libs.NumpyColumnSelector()
        self.doTest(trainable, **self._iris)

    def test_NumpyColumnSelector_pandas(self):
        iris = self._split_iris(as_frame=True)
        self.assertIsInstance(iris["train_X"], pd.DataFrame)
        trainable = lale.lib.autoai_libs.NumpyColumnSelector(columns=[0, 2, 3])
        self.doTest(trainable, **iris)

    def test_CompressStrings(self):
        n_columns = self._iris["train_X"].shape[1]
        trainable = lale.lib.autoai_libs.CompressStrings(
            dtypes_list=["int_num"] * n_columns,
            # A comprehension (not ``[[]] * n``) so each column gets its own list.
            misslist_list=[[] for _ in range(n_columns)],
        )
        self.doTest(trainable, **self._iris)

    def test_NumpyReplaceMissingValues(self):
        trainable = lale.lib.autoai_libs.NumpyReplaceMissingValues()
        self.doTest(trainable, **self._iris)

    def test_NumpyReplaceUnknownValues(self):
        trainable = lale.lib.autoai_libs.NumpyReplaceUnknownValues(filling_values=42.0)
        self.doTest(trainable, **self._iris)

    def test_boolean2float(self):
        trainable = lale.lib.autoai_libs.boolean2float()
        self.doTest(trainable, **self._iris)

    def test_CatImputer(self):
        trainable = lale.lib.autoai_libs.CatImputer()
        self.doTest(trainable, **self._iris)

    def test_CatEncoder(self):
        trainable = lale.lib.autoai_libs.CatEncoder(
            encoding="ordinal",
            categories="auto",
            dtype="float64",
            handle_unknown="ignore",
        )
        self.doTest(trainable, **self._iris)

    def test_float32_transform(self):
        trainable = lale.lib.autoai_libs.float32_transform()
        self.doTest(trainable, **self._iris)

    def test_FloatStr2Float(self):
        n_columns = self._iris["train_X"].shape[1]
        trainable = lale.lib.autoai_libs.FloatStr2Float(
            dtypes_list=["int_num"] * n_columns
        )
        self.doTest(trainable, **self._iris)

    def test_OptStandardScaler(self):
        trainable = lale.lib.autoai_libs.OptStandardScaler()
        self.doTest(trainable, **self._iris)

    def test_NumImputer(self):
        trainable = lale.lib.autoai_libs.NumImputer()
        self.doTest(trainable, **self._iris)

    def test_NumpyPermuteArray(self):
        trainable = lale.lib.autoai_libs.NumpyPermuteArray(
            axis=0, permutation_indices=[2, 0, 1, 3]
        )
        self.doTest(trainable, **self._iris)

    def test_TNoOp(self):
        from autoai_libs.utils.fc_methods import is_not_categorical

        trainable = lale.lib.autoai_libs.TNoOp(
            fun=np.rint,
            name="do nothing",
            datatypes=["numeric"],
            feat_constraints=[is_not_categorical],
        )
        self.doTest(trainable, **self._iris)

    def test_TA1(self):
        from autoai_libs.utils.fc_methods import is_not_categorical

        trainable = lale.lib.autoai_libs.TA1(
            fun=np.rint,
            name="round",
            datatypes=["numeric"],
            feat_constraints=[is_not_categorical],
            **self._float32_columns(),
        )
        self.doTest(trainable, **self._iris)

    def test_TA2(self):
        from autoai_libs.utils.fc_methods import is_not_categorical

        trainable = lale.lib.autoai_libs.TA2(
            fun=np.add,
            name="sum",
            datatypes1=["numeric"],
            feat_constraints1=[is_not_categorical],
            datatypes2=["numeric"],
            feat_constraints2=[is_not_categorical],
            **self._float32_columns(),
        )
        self.doTest(trainable, **self._iris)

    def test_TB1(self):
        from autoai_libs.utils.fc_methods import is_not_categorical
        from sklearn.preprocessing import StandardScaler

        trainable = lale.lib.autoai_libs.TB1(
            tans_class=StandardScaler,
            name="stdscaler",
            datatypes=["numeric"],
            feat_constraints=[is_not_categorical],
            **self._float32_columns(),
        )
        self.doTest(trainable, **self._iris)

    def test_TB2(self):
        # TODO: not sure how to instantiate, what to pass for tans_class
        # Report the gap as a skip instead of a silently passing empty test.
        self.skipTest(
            "TB2 test not implemented: unclear how to instantiate it "
            "and what to pass for tans_class"
        )

    def test_TAM(self):
        from autoai_libs.cognito.transforms.transform_extras import (
            IsolationForestAnomaly,
        )

        trainable = lale.lib.autoai_libs.TAM(
            tans_class=IsolationForestAnomaly,
            name="isoforestanomaly",
            **self._float32_columns(),
        )
        self.doTest(trainable, **self._iris)

    def test_TGen(self):
        from autoai_libs.cognito.transforms.transform_extras import NXOR
        from autoai_libs.utils.fc_methods import is_not_categorical

        trainable = lale.lib.autoai_libs.TGen(
            fun=NXOR,
            name="nxor",
            arg_count=2,
            datatypes_list=[["numeric"], ["numeric"]],
            feat_constraints_list=[[is_not_categorical], [is_not_categorical]],
            **self._float32_columns(),
        )
        self.doTest(trainable, **self._iris)

    def test_FS1(self):
        trainable = lale.lib.autoai_libs.FS1(
            cols_ids_must_keep=[1],
            additional_col_count_to_keep=3,
            ptype="classification",
        )
        self.doTest(trainable, **self._iris)

    def test_FS2(self):
        from sklearn.ensemble import ExtraTreesClassifier

        trainable = lale.lib.autoai_libs.FS2(
            cols_ids_must_keep=[1],
            additional_col_count_to_keep=3,
            ptype="classification",
            eval_algo=ExtraTreesClassifier,
        )
        self.doTest(trainable, **self._iris)

    def test_ColumnSelector(self):
        trainable = lale.lib.autoai_libs.ColumnSelector()
        self.doTest(trainable, **self._iris)

    def test_ColumnSelector_pandas(self):
        iris = self._split_iris(as_frame=True)
        self.assertIsInstance(iris["train_X"], pd.DataFrame)
        trainable = lale.lib.autoai_libs.ColumnSelector(columns_indices_list=[0, 2, 3])
        self.doTest(trainable, **iris)


class TestAutoaiLibsText(unittest.TestCase):
    """Tests for the autoai_libs text transformers on two 20-newsgroups categories."""

    def setUp(self):
        """Load the train and test subsets as single-column arrays plus targets."""
        from sklearn.datasets import fetch_20newsgroups

        categories = ["alt.atheism", "sci.space"]

        def load(subset):
            bunch = fetch_20newsgroups(subset=subset, categories=categories)
            documents = np.array(bunch.data)
            # One row per document, with a single column holding the text.
            return documents.reshape((documents.shape[0], 1)), bunch.target

        self.train_X, self.train_y = load("train")
        self.test_X, self.test_y = load("test")

    def doTest(self, trainable, train_X, train_y, test_X, test_y):
        """Run the shared checks on one trainable text operator.

        Fits and transforms with the operator on its own, checks that
        ``transform`` on the trainable operator raises a
        ``DeprecationWarning``, serializes it with ``to_json``, then fits and
        predicts with it in front of ``XGBClassifier``, both directly and via
        a one-evaluation ``Hyperopt`` search. ``test_y`` is not used.
        """
        fitted = trainable.fit(train_X, train_y)
        fitted.transform(test_X)
        with self.assertWarns(DeprecationWarning):
            trainable.transform(train_X)
        trainable.to_json()

        pipeline = trainable >> float32_transform() >> XGBClassifier()
        pipeline.fit(train_X, train_y).predict(test_X)

        optimizer = Hyperopt(estimator=pipeline, max_evals=1, verbose=True)
        optimizer.fit(train_X, train_y).predict(test_X)

    @unittest.skip(
        "skipping for now because this does not work with the latest xgboost."
    )
    def test_TextTransformer(self):
        processing_options = {"word2vec": {"output_dim": 5}}
        trainable = lale.lib.autoai_libs.TextTransformer(
            drop_columns=True,
            columns_to_be_deleted=[0, 1],
            text_processing_options=processing_options,
        )
        self.doTest(
            trainable,
            train_X=self.train_X,
            train_y=self.train_y,
            test_X=self.test_X,
            test_y=self.test_y,
        )

    @unittest.skip(
        "skipping for now because this does not work with the latest xgboost."
    )
    def test_Word2VecTransformer(self):
        trainable = lale.lib.autoai_libs.Word2VecTransformer(
            drop_columns=True,
            output_dim=5,
        )
        self.doTest(
            trainable,
            train_X=self.train_X,
            train_y=self.train_y,
            test_X=self.test_X,
            test_y=self.test_y,
        )


# class TestDateTransformer(unittest.TestCase):
#     @classmethod
#     def setUpClass(cls):
#         data = fetch_household_power_consumption()
#         data = data.iloc[:5000, [0, 2, 3, 4, 5]]
#         cls.X_train = data.iloc[-1000:]
#         cls.X_test = data.iloc[:-1000]

#     def test_01_all_mini_options_with_headers(self):
#         transformer = lale.lib.autoai_libs.DateTransformer(
#             options=["all"], column_headers_list=self.X_train.columns.values.tolist()
#         )
#         fitted_transformer = transformer.fit(self.X_train.values)
#         X_test_transformed = fitted_transformer.transform(self.X_test.values)
#         X_train_transformed = fitted_transformer.transform(self.X_train.values)

#         header_list = fitted_transformer.impl.new_column_headers_list
#         print(f"New columns: {header_list}, new shape: {X_train_transformed.shape}")

#         self.assertEqual(
#             X_train_transformed.shape[1],
#             X_test_transformed.shape[1],
#             f"Number of columns after transform is different.:{X_train_transformed.shape[1]}, {X_test_transformed.shape[1]}",
#         )

#     def test_02_all_options_without_headers(self):
#         transformer = lale.lib.autoai_libs.DateTransformer(options=["all"])
#         fitted_transformer = transformer.fit(self.X_train.values)
#         X_train = fitted_transformer.transform(self.X_train.values)
#         X_test = transformer.transform(self.X_test.values)
#         header_list = fitted_transformer.impl.new_column_headers_list
#         print(f"New columns: {header_list}")

#         self.assertEqual(
#             X_train.shape[1], X_test.shape[1], msg="Shape after transform is different."
#         )

#     def test_03_specific_options_and_delete_source_columns(self):
#         transformer = lale.lib.autoai_libs.DateTransformer(
#             options=["FloatTimestamp", "DayOfWeek", "Hour", "Minute"],
#             delete_source_columns=True,
#             column_headers_list=self.X_train.columns.values.tolist(),
#         )
#         fitted_transformer = transformer.fit(self.X_train.values)
#         X_train = fitted_transformer.transform(self.X_train.values)
#         X_test = transformer.transform(self.X_test.values)
#         header_list = fitted_transformer.impl.new_column_headers_list
#         print(f"New columns: {header_list}")

#         self.assertEqual(
#             X_train.shape[1], X_test.shape[1], msg="Shape after transform is different."
#         )

#     def test_04_option_Datetime_and_delete_source_columns(self):
#         transformer = lale.lib.autoai_libs.DateTransformer(
#             options=["Datetime"],
#             delete_source_columns=True,
#             column_headers_list=self.X_train.columns.values.tolist(),
#         )
#         fitted_transformer = transformer.fit(self.X_train.values)
#         X_train = fitted_transformer.transform(self.X_train.values)
#         X_test = transformer.transform(self.X_test.values)
#         header_list = fitted_transformer.impl.new_column_headers_list
#         print(f"New columns: {header_list}")

#         self.assertEqual(
#             X_train.shape[1], X_test.shape[1], msg="Shape after transform is different."
#         )
