from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import builtins
import  math
import warnings
import inspect
from functools import partial
import tensorflow as tf
from trident.backend.common import TensorShape
from trident.backend.tensorflow_backend import *
from trident.backend.tensorflow_ops import *
from trident.backend.common import get_function, camel2snake
# Names exported by a star-import of this module. `uniform`, `normal` and
# `get_initializer` are defined in this file but are not listed here.
__all__ = ['kaiming_uniform', 'kaiming_normal','xavier_uniform','xavier_normal','trunc_normal','fill_zeros','fill_ones']

def calculate_gain(nonlinearity, param=None):
    r"""Look up the recommended initialization gain for a nonlinearity.

    ================= ====================================================
    nonlinearity      gain
    ================= ====================================================
    Linear / Identity :math:`1`
    Conv{1,2,3}D      :math:`1`
    Sigmoid           :math:`1`
    Tanh              :math:`\frac{5}{3}`
    ReLU              :math:`\sqrt{2}`
    Leaky Relu        :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
    ================= ====================================================

    Args:
        nonlinearity: name of the non-linear function, e.g. ``'relu'``
        param: optional parameter of the non-linear function. Only read for
            ``'leaky_relu'``, where it is the negative slope (``0.01`` when
            omitted).

    Returns:
        The gain as a number (the integer ``1`` for the linear family and
        sigmoid, a float otherwise).

    Raises:
        ValueError: if ``nonlinearity`` is not one of the supported names, or
            if ``param`` is given for ``'leaky_relu'`` but is not an int or
            float (booleans are rejected).

    Examples:
        >>> gain = calculate_gain('leaky_relu', 0.2)  # leaky_relu with negative_slope=0.2
    """
    unit_gain_names = ('linear', 'conv1d', 'conv2d', 'conv3d',
                       'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d',
                       'sigmoid')
    if nonlinearity in unit_gain_names:
        return 1
    if nonlinearity == 'tanh':
        return 5.0 / 3
    if nonlinearity == 'relu':
        return math.sqrt(2.0)
    if nonlinearity != 'leaky_relu':
        raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))

    if param is None:
        slope = 0.01
    else:
        # bool is a subclass of int, so it has to be excluded explicitly.
        is_plain_int = isinstance(param, int) and not isinstance(param, bool)
        if not (is_plain_int or isinstance(param, float)):
            raise ValueError("negative_slope {} not a valid number".format(param))
        slope = param
    return math.sqrt(2.0 / (1 + slope ** 2))

def _calculate_fan_in_and_fan_out(tensor):
    """Compute the fan-in and fan-out of a weight tensor.

    The first axis is taken as the number of output feature maps and the last
    axis as the number of input feature maps. All axes in between are treated
    as the receptive field (kernel extent).
    NOTE(review): this assumes weights are laid out as
    ``(out, *kernel, in)`` - confirm against the layer implementations.

    The receptive field is derived from the static shape alone. The previous
    ``tensor[0][0].numel()`` both depended on a PyTorch-style ``numel`` method
    and counted the input-channel axis a second time (for a rank-3 weight it
    yielded ``in`` instead of the kernel length).

    Args:
        tensor: a weight tensor or variable with at least 2 dimensions

    Returns:
        A ``(fan_in, fan_out)`` tuple.

    Raises:
        ValueError: if ``tensor`` has fewer than 2 dimensions.
    """
    dimensions = len(tensor.shape)
    if dimensions < 2:
        raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")

    shape = int_shape(tensor)
    num_input_fmaps = shape[-1]
    num_output_fmaps = shape[0]

    # Product of the interior axes; stays 1 for a plain 2-D weight matrix.
    receptive_field_size = 1
    for extent in shape[1:-1]:
        receptive_field_size *= extent

    fan_in = num_input_fmaps * receptive_field_size
    fan_out = num_output_fmaps * receptive_field_size
    return fan_in, fan_out

def _calculate_correct_fan(tensor, mode):
    """Return the fan-in or the fan-out of ``tensor`` depending on ``mode``.

    Args:
        tensor: weight tensor forwarded to ``_calculate_fan_in_and_fan_out``
        mode: ``'fan_in'`` or ``'fan_out'`` (matched case-insensitively)

    Returns:
        The selected fan value.

    Raises:
        ValueError: if ``mode`` is neither ``'fan_in'`` nor ``'fan_out'``.
    """
    supported = ['fan_in', 'fan_out']
    selected = mode.lower()
    if selected not in supported:
        # The message reports the lower-cased mode, not the raw argument.
        raise ValueError("Mode {} not supported, please use one of {}".format(selected, supported))

    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if selected == 'fan_in':
        return fan_in
    return fan_out



def uniform(tensor, a=0., b=1.):
    # type: (Tensor, float, float) -> None
    r"""Overwrite trainable weights with samples from :math:`\mathcal{U}(a, b)`.

    If ``tensor`` is a ``tf.Module``, every entry of its
    ``named_parameters()`` that is trainable and whose name does not contain
    ``'bias'`` is re-assigned. If it is a trainable ``tf.Variable``, that
    variable is re-assigned. Any other input is ignored.

    Args:
        tensor: a ``tf.Module`` or ``tf.Variable`` to initialize in place
        a: the lower bound of the uniform distribution
        b: the upper bound of the uniform distribution

    Returns:
        None; the new values are written with ``assign``.
    """
    if isinstance(tensor, tf.Module):
        # Lazy filter so that lookup and assignment stay interleaved.
        targets = (param for key, param in tensor.named_parameters()
                   if param.trainable and 'bias' not in key)
    elif isinstance(tensor, tf.Variable) and tensor.trainable:
        targets = (tensor,)
    else:
        targets = ()

    for param in targets:
        param.assign(random_uniform_like(param, a=a, b=b))


def normal(tensor, mean=0., std=1.):
    # type: (Tensor, float, float) -> None
    r"""Overwrite trainable weights with samples from
    :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.

    If ``tensor`` is a ``tf.Module``, every entry of its
    ``named_parameters()`` that is trainable and whose name does not contain
    ``'bias'`` is re-assigned. If it is a trainable ``tf.Variable``, that
    variable is re-assigned. Any other input is ignored.

    Args:
        tensor: a ``tf.Module`` or ``tf.Variable`` to initialize in place
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution

    Returns:
        None; the new values are written with ``assign``.
    """
    if isinstance(tensor, tf.Module):
        for param_name, param in tensor.named_parameters():
            if not param.trainable or 'bias' in param_name:
                continue
            param.assign(random_normal_like(param, mean=mean, std=std))
        return

    if isinstance(tensor, tf.Variable) and tensor.trainable:
        tensor.assign(random_normal_like(tensor, mean=mean, std=std))



def fill_zeros(tensor):
    # type: (Tensor) -> None
    r"""Set trainable weights to the scalar value `0`.

    If ``tensor`` is a ``tf.Module``, every trainable entry of its
    ``named_parameters()`` is zeroed. No name filter is applied here, so
    parameters whose name contains ``'bias'`` are zeroed as well. If
    ``tensor`` is a trainable ``tf.Variable``, that variable is zeroed. Any
    other input is ignored.

    Args:
        tensor: a ``tf.Module`` or ``tf.Variable`` to initialize in place

    Returns:
        None; the new values are written with ``assign``.
    """
    if isinstance(tensor, tf.Module):
        for _, param in tensor.named_parameters():
            if not param.trainable:
                continue
            param.assign(zeros_like(param))
        return

    if isinstance(tensor, tf.Variable) and tensor.trainable:
        tensor.assign(zeros_like(tensor))


def fill_ones(tensor):
    # type: (Tensor) -> None
    r"""Set trainable weights to the scalar value `1`.

    If ``tensor`` is a ``tf.Module``, every entry of its
    ``named_parameters()`` that is trainable and whose name does not contain
    ``'bias'`` is set to ones. If it is a trainable ``tf.Variable``, that
    variable is set to ones. Any other input is ignored.

    Args:
        tensor: a ``tf.Module`` or ``tf.Variable`` to initialize in place

    Returns:
        None; the new values are written with ``assign``.
    """
    if isinstance(tensor, tf.Module):
        targets = (param for key, param in tensor.named_parameters()
                   if param.trainable and 'bias' not in key)
    elif isinstance(tensor, tf.Variable) and tensor.trainable:
        targets = (tensor,)
    else:
        targets = ()

    for param in targets:
        param.assign(ones_like(param))


def kaiming_uniform(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
    r"""Apply He (Kaiming) uniform initialization in place.

    Implements the scheme from `Delving deep into rectifiers: Surpassing
    human-level performance on ImageNet classification` - He, K. et al.
    (2015). New values are sampled from
    :math:`\mathcal{U}(-\text{bound}, \text{bound})` where

    .. math::
        \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}

    If ``tensor`` is a ``tf.Module``, the function is applied to every entry
    of its ``named_parameters()`` that is trainable and whose name does not
    contain ``'bias'``. If it is a trainable ``tf.Variable``, that variable is
    re-assigned. Any other input is ignored.

    Args:
        tensor: a ``tf.Module`` or ``tf.Variable`` to initialize in place
        a: the negative slope of the rectifier used after this layer (only
            used with ``'leaky_relu'``)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backwards pass.
        nonlinearity: name of the non-linear function, passed to
            ``calculate_gain``; recommended to use only with ``'relu'`` or
            ``'leaky_relu'`` (default).

    Returns:
        None; the new values are written with ``assign``.
    """
    if isinstance(tensor, tf.Module):
        for param_name, param in tensor.named_parameters():
            if param.trainable and 'bias' not in param_name:
                kaiming_uniform(param, a, mode, nonlinearity)
        return

    if not (isinstance(tensor, tf.Variable) and tensor.trainable):
        return

    current = tensor.value()
    fan = to_numpy(_calculate_correct_fan(current, mode)).mean()
    std = true_divide(calculate_gain(nonlinearity, a), math.sqrt(fan))
    # A uniform distribution on (-bound, bound) has standard deviation bound / sqrt(3).
    bound = math.sqrt(3.0) * std
    tensor.assign(random_uniform_like(current, -bound, bound, current.dtype))


def kaiming_normal(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
    r"""Apply He (Kaiming) normal initialization in place.

    Implements the scheme from `Delving deep into rectifiers: Surpassing
    human-level performance on ImageNet classification` - He, K. et al.
    (2015). New values are sampled from
    :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::
        \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}

    If ``tensor`` is a ``tf.Module``, the function is applied to every entry
    of its ``named_parameters()`` that is trainable and whose name does not
    contain ``'bias'``. If it is a trainable ``tf.Variable``, that variable is
    re-assigned. Any other input is ignored.

    Args:
        tensor: a ``tf.Module`` or ``tf.Variable`` to initialize in place
        a: the negative slope of the rectifier used after this layer (only
            used with ``'leaky_relu'``)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backwards pass.
        nonlinearity: name of the non-linear function, passed to
            ``calculate_gain``; recommended to use only with ``'relu'`` or
            ``'leaky_relu'`` (default).

    Returns:
        None; the new values are written with ``assign``.
    """
    if isinstance(tensor, tf.Module):
        eligible = (param for key, param in tensor.named_parameters()
                    if param.trainable and 'bias' not in key)
        for param in eligible:
            kaiming_normal(param, a, mode, nonlinearity)
        return

    if not (isinstance(tensor, tf.Variable) and tensor.trainable):
        return

    current = tensor.value()
    fan = to_numpy(_calculate_correct_fan(current, mode)).mean()
    std = true_divide(calculate_gain(nonlinearity, a), math.sqrt(fan))
    tensor.assign(random_normal_like(current, 0, std, current.dtype))


def xavier_uniform(tensor, gain=1.):
    # type: (Tensor, float) -> None
    r"""Apply Glorot (Xavier) uniform initialization in place.

    Implements the scheme from `Understanding the difficulty of training deep
    feedforward neural networks` - Glorot, X. & Bengio, Y. (2010). New values
    are sampled from :math:`\mathcal{U}(-a, a)` where

    .. math::
        a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}}

    If ``tensor`` is a ``tf.Module``, the function is applied to every entry
    of its ``named_parameters()`` that is trainable and whose name does not
    contain ``'bias'``. If it is a trainable ``tf.Variable``, that variable is
    re-assigned. Any other input is ignored.

    Args:
        tensor: a ``tf.Module`` or ``tf.Variable`` to initialize in place
        gain: an optional scaling factor

    Returns:
        None; the new values are written with ``assign``.
    """
    if isinstance(tensor, tf.Module):
        eligible = (param for key, param in tensor.named_parameters()
                    if param.trainable and 'bias' not in key)
        for param in eligible:
            xavier_uniform(param, gain)
    elif isinstance(tensor, tf.Variable) and tensor.trainable:
        fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
        fan_total = float(fan_in + fan_out)
        std = gain * math.sqrt(2.0 / fan_total)
        # A uniform distribution on (-limit, limit) has standard deviation limit / sqrt(3).
        limit = math.sqrt(3.0) * std
        tensor.assign(random_uniform_like(tensor, -limit, limit))

def xavier_normal(tensor, gain=1.):
    # type: (Tensor, float) -> None
    r"""Apply Glorot (Xavier) normal initialization in place.

    Implements the scheme from `Understanding the difficulty of training deep
    feedforward neural networks` - Glorot, X. & Bengio, Y. (2010). New values
    are sampled from :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::
        \text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan\_in} + \text{fan\_out}}}

    If ``tensor`` is a ``tf.Module``, the function is applied to every entry
    of its ``named_parameters()`` that is trainable and whose name does not
    contain ``'bias'``. If it is a trainable ``tf.Variable``, that variable is
    re-assigned. Any other input is ignored.

    Args:
        tensor: a ``tf.Module`` or ``tf.Variable`` to initialize in place
        gain: an optional scaling factor

    Returns:
        None; the new values are written with ``assign``.
    """
    if isinstance(tensor, tf.Module):
        for param_name, param in tensor.named_parameters():
            if not param.trainable or 'bias' in param_name:
                continue
            xavier_normal(param, gain)
        return

    if isinstance(tensor, tf.Variable) and tensor.trainable:
        fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
        fan_total = float(fan_in + fan_out)
        std = gain * math.sqrt(2.0 / fan_total)
        tensor.assign(random_normal_like(tensor, 0, std))

def _sample_trunc_normal(shape, dtype, mean, std, a, b):
    """Draw a tensor of ``shape`` from N(mean, std^2) restricted to ``[a, b]``.

    Uses inverse-CDF sampling: a uniform draw between the CDF values of the
    two cutoffs is mapped back through the inverse error function.
    """
    def _norm_cdf(x):
        # Standard normal cumulative distribution function.
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    lower = _norm_cdf((a - mean) / std)
    upper = _norm_cdf((b - mean) / std)

    # Uniform in [2*lower - 1, 2*upper - 1), i.e. the domain erfinv expects.
    samples = tf.random.uniform(shape, minval=2.0 * lower - 1.0, maxval=2.0 * upper - 1.0, dtype=dtype)
    samples = tf.math.erfinv(samples)
    samples = samples * (std * math.sqrt(2.0)) + mean
    # Guard against values that land outside the cutoffs through rounding
    # (or +/-inf from erfinv at the interval ends).
    return tf.clip_by_value(samples, a, b)


def trunc_normal(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> None
    r"""Overwrite trainable weights with samples from a truncated normal
    distribution.

    Values follow :math:`\mathcal{N}(\text{mean}, \text{std}^2)` restricted to
    :math:`[a, b]`. The sampling works best when
    :math:`a \leq \text{mean} \leq b`.

    ``tf.random.truncated_normal`` is not used because it does not accept
    ``std``/``a``/``b`` keywords (its signature is
    ``shape, mean, stddev, dtype, seed, name``) and cannot truncate at
    arbitrary cutoffs. Samples are generated with the variable's own dtype.

    If ``tensor`` is a ``tf.Module``, every entry of its
    ``named_parameters()`` that is trainable and whose name does not contain
    ``'bias'`` is re-assigned. If it is a trainable ``tf.Variable``, that
    variable is re-assigned. Any other input is ignored.

    Args:
        tensor: a ``tf.Module`` or ``tf.Variable`` to initialize in place
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution (non-zero)
        a: the minimum cutoff value
        b: the maximum cutoff value

    Returns:
        None; the new values are written with ``assign``.
    """
    if isinstance(tensor, tf.Module):
        for name, weight in tensor.named_parameters():
            if weight.trainable == True and 'bias' not in name:
                weight.assign(_sample_trunc_normal(weight.shape, weight.dtype, mean, std, a, b))
    elif isinstance(tensor, tf.Variable) and tensor.trainable == True:
        tensor.assign(_sample_trunc_normal(tensor.shape, tensor.dtype, mean, std, a, b))



def get_initializer(initializer,**kwargs):
    """Resolve an initializer specification to a callable.

    Args:
        initializer: either a string, which is converted with ``camel2snake``
            and passed to ``get_function`` together with this module's path,
            or a plain function whose ``__module__`` is
            ``'trident.backend.tensorflow_initializers'``.
        **kwargs: keyword arguments to pre-bind onto the initializer.

    Returns:
        The resolved function, wrapped in ``functools.partial`` when any
        keyword arguments were supplied. ``None`` when ``initializer`` is
        neither a string nor a function from this module.
    """
    own_module = 'trident.backend.tensorflow_initializers'

    if isinstance(initializer, str):
        resolved = get_function(camel2snake(initializer), [own_module])
    elif inspect.isfunction(initializer) and getattr(initializer, '__module__', None) == own_module:
        resolved = initializer
    else:
        return None

    if kwargs:
        return partial(resolved, **kwargs)
    return resolved

