# Copyright 2018 The KaiJIN Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import math
import torch.nn as nn
from torch.nn.modules.utils import _triple


class SpatioTemporalConv(nn.Module):
  r"""3D convolution unit: ``nn.Conv3d`` followed by ``nn.BatchNorm3d`` and ``nn.ReLU``.

  The convolution is a single (non-factored) 3D convolution applied jointly over
  the temporal and spatial axes of the input.

  Args:
      in_channels (int): Number of channels in the input tensor
      out_channels (int): Number of channels produced by the convolution
      kernel_size (int or tuple): Size of the convolving kernel
      stride (int or tuple, optional): Stride of the convolution. Default: 1
      padding (int or tuple, optional): Zero-padding added to the sides of the input. Default: 0
      bias (bool, optional): If ``True``, adds a learnable bias to the convolution. Default: ``False``
  """

  def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=False):
    super(SpatioTemporalConv, self).__init__()

    # Scalars are expanded to one value per axis, e.g. 1 -> (1, 1, 1);
    # iterables are passed through.
    kernel_size, stride, padding = (
        _triple(value) for value in (kernel_size, stride, padding))

    self.temporal_spatial_conv = nn.Conv3d(
        in_channels,
        out_channels,
        kernel_size,
        stride=stride,
        padding=padding,
        bias=bias)
    self.bn = nn.BatchNorm3d(out_channels)
    self.relu = nn.ReLU()

  def forward(self, x):
    """Return ``relu(bn(conv(x)))``."""
    features = self.temporal_spatial_conv(x)
    normalised = self.bn(features)
    return self.relu(normalised)


class SpatioTemporalResBlock(nn.Module):
  r"""Single residual block of the ResNet network, built from SpatioTemporalConv units.

  Residual branch: conv1 -> bn1 -> relu1 -> conv2 -> bn2.
  Shortcut branch: the input itself, or (when ``downsample`` is ``True``)
  a 1x1x1 stride-2 SpatioTemporalConv followed by a BatchNorm3d.
  The two branches are summed and passed through a final ReLU.

  NOTE(review): SpatioTemporalConv, as defined in this file, already applies
  BatchNorm3d and ReLU after its convolution, so every conv in this block is
  followed by two normalisation steps. That structure is left untouched here
  so that parameter names and numerics do not change.

      Args:
          in_channels (int): Number of channels in the input tensor.
          out_channels (int): Number of channels in the output produced by the block.
              Without downsampling the input is added to the residual as-is, so
              ``in_channels`` is expected to equal ``out_channels`` in that case.
          kernel_size (int or tuple): Size of the convolving kernels. Each axis is padded
              with ``k // 2`` so odd kernels keep the input size at stride 1.
          downsample (bool, optional): If ``True``, the output size is to be smaller than the input. Default: ``False``
      """

  def __init__(self, in_channels, out_channels, kernel_size, downsample=False):
    super(SpatioTemporalResBlock, self).__init__()

    # If downsample == True, the first conv of the block has stride = 2
    # to halve the residual output size, and the input x is passed
    # through a separate 1x1x1 conv with stride = 2 to also halve it.
    # No pooling layers are used inside the block.
    self.downsample = downsample

    # "SAME"-style padding, computed per axis so that kernel_size may be
    # either an int or a 3-tuple (``tuple // 2`` would raise TypeError).
    padding = tuple(k // 2 for k in _triple(kernel_size))

    if self.downsample:
      # shortcut branch: downsample the input x with stride = 2
      self.downsampleconv = SpatioTemporalConv(
          in_channels, out_channels, 1, stride=2)
      self.downsamplebn = nn.BatchNorm3d(out_channels)

      # residual branch: downsample with stride = 2 in the first conv
      self.conv1 = SpatioTemporalConv(
          in_channels, out_channels, kernel_size, padding=padding, stride=2)
    else:
      self.conv1 = SpatioTemporalConv(
          in_channels, out_channels, kernel_size, padding=padding)

    self.bn1 = nn.BatchNorm3d(out_channels)
    self.relu1 = nn.ReLU()

    # second conv keeps the size produced by conv1
    self.conv2 = SpatioTemporalConv(
        out_channels, out_channels, kernel_size, padding=padding)
    self.bn2 = nn.BatchNorm3d(out_channels)
    self.outrelu = nn.ReLU()

  def forward(self, x):
    """Return ``outrelu(shortcut(x) + residual(x))``."""
    res = self.relu1(self.bn1(self.conv1(x)))
    res = self.bn2(self.conv2(res))

    if self.downsample:
      # bring the shortcut to the same size as the residual branch
      x = self.downsamplebn(self.downsampleconv(x))

    return self.outrelu(x + res)


class SpatioTemporalResLayer(nn.Module):
  r"""One layer of the ResNet network: ``layer_size`` blocks applied in sequence.

  The first block maps ``in_channels`` to ``out_channels`` and receives the
  ``downsample`` flag; the remaining ``layer_size - 1`` blocks map
  ``out_channels`` to ``out_channels`` with the block type's default settings.

      Args:
          in_channels (int): Number of channels in the input tensor.
          out_channels (int): Number of channels in the output produced by the layer.
          kernel_size (int or tuple): Size of the convolving kernels.
          layer_size (int): Number of blocks to be stacked to form the layer
          block_type (Module, optional): Type of block that is to be used to form the layer. Default: SpatioTemporalResBlock.
          downsample (bool, optional): If ``True``, the first block in layer will implement downsampling. Default: ``False``
      """

  def __init__(self, in_channels, out_channels, kernel_size, layer_size, block_type=SpatioTemporalResBlock,
               downsample=False):

    super(SpatioTemporalResLayer, self).__init__()

    # leading block: the only one that may change channel count / downsample
    self.block1 = block_type(
        in_channels, out_channels, kernel_size, downsample)

    # trailing blocks: identical, built with the block type's defaults
    trailing = [
        block_type(out_channels, out_channels, kernel_size)
        for _ in range(layer_size - 1)
    ]
    self.blocks = nn.ModuleList(trailing)

  def forward(self, x):
    """Feed ``x`` through ``block1`` and then through every block in ``blocks``."""
    out = self.block1(x)
    for stage in self.blocks:
      out = stage(out)
    return out


class R3DNet(nn.Module):
  r"""ResNet feature extractor: a stem convolution followed by four residual layers
  and a global average pool, returning the result reshaped to ``(-1, 512)``.

      Args:
          layer_sizes (tuple): An indexable holding the number of blocks for each of the
              four residual layers (entries 0 to 3 are read).
          block_type (Module, optional): Type of block that is to be used to form the layers. Default: SpatioTemporalResBlock.
  """

  def __init__(self, layer_sizes, block_type=SpatioTemporalResBlock):
    super(R3DNet, self).__init__()

    # stem: 3 -> 64 channels, kernel 3x7x7, stride 1x2x2
    self.conv1 = SpatioTemporalConv(
        3, 64, [3, 7, 7], stride=[1, 2, 2], padding=[1, 3, 3])

    # first residual layer keeps 64 channels and does not downsample
    self.conv2 = SpatioTemporalResLayer(
        64, 64, 3, layer_sizes[0], block_type=block_type)

    # the remaining three layers each double the channel count and
    # downsample inside their first block
    self.conv3 = SpatioTemporalResLayer(
        64, 128, 3, layer_sizes[1], block_type=block_type, downsample=True)
    self.conv4 = SpatioTemporalResLayer(
        128, 256, 3, layer_sizes[2], block_type=block_type, downsample=True)
    self.conv5 = SpatioTemporalResLayer(
        256, 512, 3, layer_sizes[3], block_type=block_type, downsample=True)

    # global average pooling down to a single value per channel
    self.pool = nn.AdaptiveAvgPool3d(1)

  def forward(self, x):
    """Run the stem and the four residual layers, pool, and reshape to ``(-1, 512)``."""
    stages = (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5)
    for stage in stages:
      x = stage(x)

    pooled = self.pool(x)
    return pooled.view(-1, 512)


class R3DClassifier(nn.Module):
  r"""Complete ResNet classifier: an R3DNet feature extractor whose 512-dimensional
  output is passed through a Linear layer producing ``num_classes`` logits.

      Args:
          num_classes(int): Number of classes in the data
          layer_sizes (tuple): An iterable containing the number of blocks in each layer
          block_type (Module, optional): Type of block that is to be used to form the layers. Default: SpatioTemporalResBlock.
          pretrained (bool, optional): If ``True``, the name and size of every ``state_dict``
              entry are printed after initialisation; no weights are loaded by this class. Default: ``False``
      """

  def __init__(self, num_classes, layer_sizes, block_type=SpatioTemporalResBlock, pretrained=False):
    super(R3DClassifier, self).__init__()

    self.res3d = R3DNet(layer_sizes, block_type)
    self.linear = nn.Linear(512, num_classes)

    # custom initialisation of the conv / batch-norm parameters
    self.__init_weight()

    if pretrained:
      self.__load_pretrained_weights()

  def forward(self, x):
    """Return the class logits for ``x``."""
    features = self.res3d(x)
    return self.linear(features)

  def __load_pretrained_weights(self):
    """Print the name and size of each ``state_dict`` entry (nothing is loaded)."""
    for name, tensor in self.state_dict().items():
      print(name)
      print(tensor.size())

  def __init_weight(self):
    """Kaiming-normal init for Conv3d weights; BatchNorm3d weight = 1, bias = 0."""
    for module in self.modules():
      if isinstance(module, nn.Conv3d):
        nn.init.kaiming_normal_(module.weight)
        continue
      if isinstance(module, nn.BatchNorm3d):
        nn.init.constant_(module.weight, 1)
        nn.init.constant_(module.bias, 0)


def get_1x_lr_params(model):
  """
  Generator over the trainable parameters of the feature extractor.

  Yields every parameter of ``model.res3d`` whose ``requires_grad`` is true,
  in the order returned by ``model.res3d.parameters()``.
  """
  for param in model.res3d.parameters():
    if param.requires_grad:
      yield param


def get_10x_lr_params(model):
  """
  Generator over the trainable parameters of the fc layer.

  Yields every parameter of ``model.linear`` whose ``requires_grad`` is true,
  in the order returned by ``model.linear.parameters()``.
  """
  for param in model.linear.parameters():
    if param.requires_grad:
      yield param


def R3D18(num_classes, **kwargs):
  """Construct an R3D18 model: an R3DClassifier with layer sizes (2, 2, 2, 2).

  Extra keyword arguments are forwarded to R3DClassifier.
  """
  return R3DClassifier(num_classes, (2, 2, 2, 2), **kwargs)


def R3D34(num_classes, **kwargs):
  """Construct an R3D34 model: an R3DClassifier with layer sizes (3, 4, 6, 3).

  Extra keyword arguments are forwarded to R3DClassifier.
  """
  return R3DClassifier(num_classes, (3, 4, 6, 3), **kwargs)


if __name__ == "__main__":
  # Smoke test: push one random clip through the classifier and print the
  # size of the resulting logits tensor.
  import torch
  inputs = torch.rand(1, 3, 16, 112, 112)
  net = R3DClassifier(101, (2, 2, 2, 2), pretrained=True)

  # Call the module itself rather than .forward() so that the standard
  # nn.Module call path (including any registered hooks) is used.
  outputs = net(inputs)
  print(outputs.size())
