# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/05d_methods.counternet.ipynb.

# %% ../../nbs/05d_methods.counternet.ipynb 3
from __future__ import annotations
from ..import_essentials import *
from ..module import MLP, BaseTrainingModule
from .base import BaseCFModule, BaseParametricCFModule, BasePredFnCFModule
from ..train import TrainingConfigs, train_model
from ..data import TabularDataModule
from cfnet.utils import (
    validate_configs,
    sigmoid,
    accuracy,
    proximity,
    make_model,
    init_net_opt,
    grad_update,
)
from functools import partial
from copy import deepcopy

# %% auto 0
__all__ = ['CounterNetModel', 'partition_trainable_params', 'project_immutable_features', 'CounterNetTrainingModuleConfigs',
           'CounterNetTrainingModule', 'CounterNetConfigs', 'CounterNet']

# %% ../../nbs/05d_methods.counternet.ipynb 5
class CounterNetModelConfigs(BaseParser):
    """Configurator of `CounterNetModel`."""

    # The three layer-size lists carry no default value, so every caller has
    # to supply them; `...` spells that requirement out explicitly.
    enc_sizes: List[int] = Field(..., description='Encoder sizes.')
    dec_sizes: List[int] = Field(..., description='Predictor sizes.')
    exp_sizes: List[int] = Field(..., description='CF generator sizes.')
    # Only the dropout rate falls back to a default.
    dropout_rate: float = Field(default=0.3, description='Dropout rate.')

# %% ../../nbs/05d_methods.counternet.ipynb 6
class CounterNetModel(hk.Module):
    """CounterNet network: a shared encoder feeding a predictor head and a CF-generator head."""

    def __init__(
        self,
        m_config: Dict | CounterNetModelConfigs,  # Model configs which contain configs in `CounterNetModelConfigs`.
        name: str | None = None,  # Name of the module.
    ):
        """Validate `m_config` against `CounterNetModelConfigs` and keep the result in `self.configs`."""
        super().__init__(name=name)
        self.configs = validate_configs(m_config, CounterNetModelConfigs)

    def __call__(
        self,
        x: jnp.ndarray,  # Input batch; the size of its last axis fixes the width of `cf`.
        is_training: bool = True,  # Forwarded unchanged to every `MLP` block.
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:  # `(y_hat, cf)`
        """Run one forward pass and return the prediction together with the raw CF output.

        `y_hat` is the sigmoid of a single-unit linear layer on top of the
        predictor features. `cf` is the output of a linear layer with as many
        units as `x` has features; no activation or constraint is applied to it
        here.
        """
        n_features = x.shape[-1]
        dropout_rate = self.configs.dropout_rate

        # NOTE: each MLP shares its name with the linear head that follows it.
        # Haiku disambiguates repeated names in creation order, and parameter
        # names (which `partition_trainable_params` is matched against) depend
        # on that, so keep the order in which the sub-modules are created.

        # encoder
        z = MLP(self.configs.enc_sizes, dropout_rate, name="Encoder")(x, is_training)

        # prediction
        pred = MLP(self.configs.dec_sizes, dropout_rate, name="Predictor")(
            z, is_training
        )
        y_hat = sigmoid(hk.Linear(1, name="Predictor")(pred))

        # explain: the generator sees the encoding and the predictor features
        z_exp = jnp.concatenate((z, pred), axis=-1)
        cf = MLP(self.configs.exp_sizes, dropout_rate, name="Explainer")(
            z_exp, is_training
        )
        cf = hk.Linear(n_features, name="Explainer")(cf)
        return y_hat, cf


# %% ../../nbs/05d_methods.counternet.ipynb 10
def partition_trainable_params(params: hk.Params, trainable_name: str):
    """Split `params` into a `(trainable, non_trainable)` pair by module name.

    A module's parameters go into the trainable part when `trainable_name`
    occurs anywhere in the module name (plain substring match); everything
    else goes into the non-trainable part.
    """

    def _belongs_to_trainable(module_name, param_name, value):
        # Only the module name decides; parameter name and value are ignored.
        return trainable_name in module_name

    selected, remainder = hk.data_structures.partition(_belongs_to_trainable, params)
    return selected, remainder


# %% ../../nbs/05d_methods.counternet.ipynb 11
def project_immutable_features(
    x: jnp.ndarray,  # Original inputs; indexed as `x[:, idx]`, so at least 2-D.
    cf: jnp.ndarray,  # Counterfactual candidates; indexed the same way as `x`.
    imutable_idx_list: List[int],  # Column indices that must keep the values of `x`.
) -> jnp.ndarray:
    """Return a copy of `cf` whose immutable columns are overwritten with those of `x`.

    `cf` itself is left untouched: `.at[...].set(...)` builds a new array.
    The (misspelled) parameter name is kept so keyword callers keep working.
    """
    return cf.at[:, imutable_idx_list].set(x[:, imutable_idx_list])

# %% ../../nbs/05d_methods.counternet.ipynb 12
class CounterNetTrainingModuleConfigs(BaseParser):
    """Configurator of `CounterNetTrainingModule`."""

    lr: float = Field(
        0.003, description="Learning rate shared by both Adam optimizers."
    )
    lambda_1: float = Field(
        1.0, description="Weight of the prediction loss (`loss_fn_1`)."
    )
    lambda_2: float = Field(
        0.2, description="Weight of the loss between the CF prediction and the flipped label (`loss_fn_2`)."
    )
    lambda_3: float = Field(
        0.1, description="Weight of the loss between the input and its CF (`loss_fn_3`)."
    )


# %% ../../nbs/05d_methods.counternet.ipynb 13
class CounterNetTrainingModule(BaseTrainingModule):
    """Training module that fits a `CounterNetModel` with two alternating updates.

    Every training step first updates all parameters on the prediction loss
    (optimizer `opt_1`), then updates only the parameters selected by
    `_EXPLAINER_SCOPE` on the counterfactual losses (optimizer `opt_2`).
    """

    _data_module: TabularDataModule
    # Substring of the Haiku module names whose parameters the explainer step
    # trains. `init_net_opt` and `_explainer_step` must select the same subset,
    # otherwise the state of `opt_2` would not match the gradients it receives.
    _EXPLAINER_SCOPE = "counter_net_model/Explainer"

    def __init__(
        self,
        m_configs: Dict[str, Any] | CounterNetConfigs,  # Model and training hyperparameters.
    ):
        """Build the network and one Adam optimizer per update phase."""
        self.save_hyperparameters(m_configs)
        self.net = make_model(m_configs, CounterNetModel)
        self.configs = validate_configs(m_configs, CounterNetTrainingModuleConfigs)
        # Same learning rate for both phases, but each optimizer keeps its own state.
        self.opt_1 = optax.adam(learning_rate=self.configs.lr)
        self.opt_2 = optax.adam(learning_rate=self.configs.lr)

    def init_net_opt(self, data_module: TabularDataModule, key):
        """Initialize the parameters and both optimizer states.

        Returns `(params, (opt_1_state, opt_2_state))`. `data_module` is stored
        on the instance because `forward` and `generate_cfs` call its
        `apply_constraints`.
        """
        # hook data_module
        self._data_module = data_module
        X, _ = data_module.train_dataset[:]

        # manually init multiple opts; the first 100 training rows are passed
        # to the initializer as sample input
        params, opt_1_state = init_net_opt(
            self.net, self.opt_1, X=X[:100], key=key
        )
        # `opt_2` only ever updates the explainer subset, so its state is
        # created from that subset alone.
        trainable_params, _ = partition_trainable_params(
            params, trainable_name=self._EXPLAINER_SCOPE
        )
        opt_2_state = self.opt_2.init(trainable_params)
        return params, (opt_1_state, opt_2_state)

    @partial(jax.jit, static_argnames=["self", "is_training"])
    def forward(self, params, rng_key, x, is_training: bool = True):
        """Return `(y_pred, cf, cf_y)` for the batch `x`.

        `cf` is the network's CF output after the data module's
        `apply_constraints` (called with `hard=not is_training`); `cf_y` is the
        network's prediction for that `cf`. Both passes use the same `rng_key`.
        """
        # first forward to get y_pred and normalized cf
        y_pred, cf = self.net.apply(params, rng_key, x, is_training=is_training)
        cf = self._data_module.apply_constraints(x, cf, hard=not is_training)

        # second forward to calculate cf_y
        cf_y, _ = self.net.apply(params, rng_key, cf, is_training=is_training)
        return y_pred, cf, cf_y

    def predict(self, params, rng_key, x):
        """Return the network's prediction for `x` with `is_training=False`."""
        y_pred, _ = self.net.apply(params, rng_key, x, is_training=False)
        return y_pred

    def generate_cfs(self, X: chex.ArrayBatched, params, rng_key) -> chex.ArrayBatched:
        """Return counterfactuals for `X`: the CF output passed through `apply_constraints(..., hard=True)`."""
        _, cfs = self.net.apply(params, rng_key, X, is_training=False)
        return self._data_module.apply_constraints(X, cfs, hard=True)

    # The three losses share one form: `optax.l2_loss` applied per example
    # (via `vmap` over the leading axis) and then averaged over all entries.

    def loss_fn_1(self, y_pred, y):
        """Prediction loss between `y_pred` and the labels `y`."""
        return jnp.mean(vmap(optax.l2_loss)(y_pred, y))

    def loss_fn_2(self, cf_y, y_prime):
        """Loss between the prediction for the CF (`cf_y`) and the target label `y_prime`."""
        return jnp.mean(vmap(optax.l2_loss)(cf_y, y_prime))

    def loss_fn_3(self, x, cf):
        """Loss between the input `x` and its counterfactual `cf`."""
        return jnp.mean(vmap(optax.l2_loss)(x, cf))

    def pred_loss_fn(self, params, rng_key, batch, is_training: bool = True):
        """Objective of the predictor step: `lambda_1 * loss_fn_1`."""
        x, y = batch
        y_pred, _ = self.net.apply(params, rng_key, x, is_training=is_training)
        return self.configs.lambda_1 * self.loss_fn_1(y_pred, y)

    def exp_loss_fn(
        self,
        trainable_params,
        non_trainable_params,
        rng_key,
        batch,
        is_training: bool = True,
    ):
        """Objective of the explainer step: `lambda_2 * loss_fn_2 + lambda_3 * loss_fn_3`.

        `trainable_params` comes first so that `jax.grad` differentiates with
        respect to it only. The labels in `batch` are not used; the target for
        the CF is the flipped rounded prediction.
        """
        # merge trainable and non_trainable params
        params = hk.data_structures.merge(trainable_params, non_trainable_params)
        x, _ = batch
        y_pred, cf, cf_y = self.forward(params, rng_key, x, is_training=is_training)
        y_prime = 1 - jnp.round(y_pred)
        loss_2, loss_3 = self.loss_fn_2(cf_y, y_prime), self.loss_fn_3(x, cf)
        return self.configs.lambda_2 * loss_2 + self.configs.lambda_3 * loss_3

    def _predictor_step(self, params, opt_state, rng_key, batch):
        """Update all parameters with `opt_1` on `pred_loss_fn`."""
        grads = jax.grad(self.pred_loss_fn)(params, rng_key, batch)
        upt_params, opt_state = grad_update(grads, params, opt_state, self.opt_1)
        return upt_params, opt_state

    def _explainer_step(self, params, opt_state, rng_key, batch):
        """Update only the `_EXPLAINER_SCOPE` parameters with `opt_2` on `exp_loss_fn`."""
        trainable_params, non_trainable_params = partition_trainable_params(
            params, trainable_name=self._EXPLAINER_SCOPE
        )
        grads = jax.grad(self.exp_loss_fn)(
            trainable_params, non_trainable_params, rng_key, batch
        )
        upt_trainable_params, opt_state = grad_update(
            grads, trainable_params, opt_state, self.opt_2
        )
        # Re-assemble the full parameter tree; the other parameters are unchanged.
        upt_params = hk.data_structures.merge(
            upt_trainable_params, non_trainable_params
        )
        return upt_params, opt_state

    @partial(jax.jit, static_argnames=["self"])
    def _training_step(
        self,
        params: hk.Params,
        opts_state: Tuple[optax.OptState, optax.OptState],
        rng_key: random.PRNGKey,
        batch: Tuple[jnp.ndarray, jnp.ndarray],
    ):
        """Jitted core of a training step: predictor update, then explainer update.

        The explainer step starts from the parameters produced by the
        predictor step. Both steps receive the same `rng_key`.
        """
        opt_1_state, opt_2_state = opts_state
        params, opt_1_state = self._predictor_step(params, opt_1_state, rng_key, batch)
        upt_params, opt_2_state = self._explainer_step(
            params, opt_2_state, rng_key, batch
        )
        return upt_params, (opt_1_state, opt_2_state)

    def _training_step_logs(self, params, rng_key, batch):
        """Return the three unweighted losses on `batch`, evaluated with `is_training=False`."""
        x, y = batch
        y_pred, cf, cf_y = self.forward(params, rng_key, x, is_training=False)
        y_prime = 1 - jnp.round(y_pred)

        loss_1, loss_2, loss_3 = (
            self.loss_fn_1(y_pred, y),
            self.loss_fn_2(cf_y, y_prime),
            self.loss_fn_3(x, cf),
        )
        logs = {
            "train/train_loss_1": loss_1.item(),
            "train/train_loss_2": loss_2.item(),
            "train/train_loss_3": loss_3.item(),
        }
        return logs

    def training_step(
        self,
        params: hk.Params,
        opts_state: Tuple[optax.OptState, optax.OptState],
        rng_key: random.PRNGKey,
        batch: Tuple[jnp.ndarray, jnp.ndarray],
    ) -> Tuple[hk.Params, Tuple[optax.OptState, optax.OptState]]:
        """Run one training step, log the losses of the updated parameters, and return the new state."""
        upt_params, (opt_1_state, opt_2_state) = self._training_step(
            params, opts_state, rng_key, batch
        )

        # Logging happens outside the jitted function and uses the updated parameters.
        logs = self._training_step_logs(upt_params, rng_key, batch)
        self.log_dict(logs)
        return upt_params, (opt_1_state, opt_2_state)

    def validation_step(self, params, rng_key, batch):
        """Compute, log and return the validation metrics for `batch`.

        `val/val_loss` is the plain (unweighted) sum of the three losses.
        """
        x, y = batch
        y_pred, cf, cf_y = self.forward(params, rng_key, x, is_training=False)
        y_prime = 1 - jnp.round(y_pred)

        loss_1, loss_2, loss_3 = (
            self.loss_fn_1(y_pred, y),
            self.loss_fn_2(cf_y, y_prime),
            self.loss_fn_3(x, cf),
        )
        loss_1, loss_2, loss_3 = map(np.asarray, (loss_1, loss_2, loss_3))
        logs = {
            "val/accuracy": accuracy(y, y_pred),
            "val/validity": accuracy(cf_y, y_prime),
            "val/proximity": proximity(x, cf),
            "val/val_loss_1": loss_1,
            "val/val_loss_2": loss_2,
            "val/val_loss_3": loss_3,
            "val/val_loss": loss_1 + loss_2 + loss_3,
        }
        self.log_dict(logs)
        return logs

# %% ../../nbs/05d_methods.counternet.ipynb 18
class CounterNetConfigs(CounterNetTrainingModuleConfigs, CounterNetModelConfigs):
    """Configurator of `CounterNet`."""

    enc_sizes: List[int] = Field(
        [50, 10], description="Sequence of layer sizes for encoder network."
    )
    dec_sizes: List[int] = Field(
        [10], description="Sequence of layer sizes for predictor."
    )
    exp_sizes: List[int] = Field(
        [50, 50], description="Sequence of layer sizes for CF generator."
    )

    dropout_rate: float = Field(
        0.3, description="Dropout rate."
    )
    lr: float = Field(
        0.003, description="Learning rate for training `CounterNet`."
    )
    # Raw strings keep the LaTeX backslashes literal; in a normal string
    # `\l` and `\m` are invalid escape sequences, which newer Python versions
    # warn about.
    lambda_1: float = Field(
        1.0, description=r" $\lambda_1$ for balancing the prediction loss $\mathcal{L}_1$."
    )
    lambda_2: float = Field(
        0.2, description=r" $\lambda_2$ for balancing the validity loss $\mathcal{L}_2$."
    )
    lambda_3: float = Field(
        0.1, description=r" $\lambda_3$ for balancing the proximity loss $\mathcal{L}_3$."
    )


# %% ../../nbs/05d_methods.counternet.ipynb 20
class CounterNet(BaseCFModule, BaseParametricCFModule, BasePredFnCFModule):
    """API for CounterNet Explanation Module."""

    params: hk.Params = None  # Network parameters; assigned by `train`.
    module: CounterNetTrainingModule
    name: str = 'CounterNet'

    def __init__(
        self,
        m_configs: dict | CounterNetConfigs = None  # configurator of hyperparameters; see `CounterNetConfigs`
    ):
        """Create the underlying `CounterNetTrainingModule`; `None` selects the `CounterNetConfigs` defaults."""
        if m_configs is None:
            m_configs = CounterNetConfigs()
        self.module = CounterNetTrainingModule(m_configs)

    def _is_module_trained(self) -> bool:
        """Return `True` once `params` has been set."""
        return self.params is not None

    def train(
        self,
        datamodule: TabularDataModule,  # data module
        t_configs: TrainingConfigs | dict = None  # training configs
    ):
        """Train the module on `datamodule` and store the resulting parameters in `self.params`.

        When `t_configs` is `None`, training runs for 100 epochs with a batch
        size of 128.
        """
        if t_configs is None:
            t_configs = dict(n_epochs=100, batch_size=128)
        params, _ = train_model(self.module, datamodule, t_configs)
        self.params = params

    def generate_cfs(self, X: jnp.ndarray, pred_fn = None) -> jnp.ndarray:
        """Return counterfactuals for `X` from the stored parameters.

        `pred_fn` is accepted but not used. A fixed `PRNGKey(0)` is passed to
        the module.
        """
        return self.module.generate_cfs(X, self.params, rng_key=jax.random.PRNGKey(0))

    def pred_fn(self, X: jnp.ndarray) -> jnp.ndarray:
        """Return the module's predictions for `X` from the stored parameters, using a fixed `PRNGKey(0)`."""
        rng_key = jax.random.PRNGKey(0)
        return self.module.predict(self.params, rng_key, X)

