# Copyright 2018 The KaiJIN Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ShuffleNet V2 model definitions, adapted from torchvision.models."""
import torch
import torch.nn as nn
from tw.utils.checkpoint import load_state_dict_from_url

# Pretrained checkpoint URLs keyed by the architecture name that the factory
# functions below pass to ``_shufflenetv2``. A ``None`` entry means no
# pretrained weights are available for that width: requesting
# ``pretrained=True`` for it raises NotImplementedError.
model_urls = {
    'shufflenetv2_x0.5': 'https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth',
    'shufflenetv2_x1.0': 'https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth',
    'shufflenetv2_x1.5': None,
    'shufflenetv2_x2.0': None,
}


def channel_shuffle(x, groups):
  """Interleave the channels of a 4-D tensor across ``groups`` groups.

  The channel axis is viewed as ``groups`` contiguous groups, which are then
  interleaved. For ``groups=2`` and channels ``[a0, a1, b0, b1]`` the result
  is ``[a0, b0, a1, b1]``.

  Args:
    x (Tensor): input of shape (N, C, H, W); ``C`` must be divisible by
      ``groups`` or the ``view`` below fails.
    groups (int): number of channel groups.

  Returns:
    Tensor: tensor of the same shape as ``x`` with its channels permuted.
  """
  # Read the shape directly from ``x``. The legacy ``x.data`` attribute
  # bypasses autograd tracking and is not needed just to query the shape.
  batchsize, num_channels, height, width = x.size()
  channels_per_group = num_channels // groups

  # (N, C, H, W) -> (N, groups, C // groups, H, W)
  x = x.view(batchsize, groups, channels_per_group, height, width)

  # Swap the group axis with the per-group axis. ``contiguous`` is required
  # because ``view`` cannot operate on the transposed (non-contiguous) layout.
  x = torch.transpose(x, 1, 2).contiguous()

  # Flatten back to (N, C, H, W).
  x = x.view(batchsize, -1, height, width)

  return x


class InvertedResidual(nn.Module):
  """ShuffleNet V2 unit, in its basic (stride 1) and downsampling variants.

  ``stride == 1``: the input is split in half along the channel axis. The
  first half passes through unchanged and the second half goes through
  ``branch2``.

  ``stride > 1``: the full input goes through both ``branch1`` and
  ``branch2``.

  In both cases the two results are concatenated along the channel axis and
  channel-shuffled with 2 groups.

  Args:
    inp (int): number of input channels.
    oup (int): number of output channels; each branch emits ``oup // 2``.
    stride (int): stride of the depthwise 3x3 convolutions, in ``[1, 3]``.

  Raises:
    ValueError: if ``stride`` is outside ``[1, 3]``, or if ``stride == 1``
      and ``inp != 2 * (oup // 2)``.
  """

  def __init__(self, inp, oup, stride):
    super(InvertedResidual, self).__init__()

    if not (1 <= stride <= 3):
      raise ValueError('illegal stride value')
    self.stride = stride

    branch_features = oup // 2
    # The stride-1 unit concatenates an untouched half of its input with the
    # output of branch2, so the input must carry exactly 2 * branch_features
    # channels. Raise rather than ``assert`` so the check is not stripped
    # under ``python -O`` (which would defer the failure to ``forward``).
    if self.stride == 1 and inp != 2 * branch_features:
      raise ValueError(
          'stride-1 unit requires inp == 2 * (oup // 2), '
          'got inp={} and oup={}'.format(inp, oup))

    if self.stride > 1:
      # Shortcut branch used only by downsampling units:
      # depthwise 3x3 (strided) -> BN -> pointwise 1x1 -> BN -> ReLU.
      self.branch1 = nn.Sequential(
          self.depthwise_conv(inp, inp, kernel_size=3,
                              stride=self.stride, padding=1),
          nn.BatchNorm2d(inp),
          nn.Conv2d(inp, branch_features, kernel_size=1,
                    stride=1, padding=0, bias=False),
          nn.BatchNorm2d(branch_features),
          nn.ReLU(inplace=True),
      )

    # Main branch: pointwise -> BN -> ReLU -> depthwise (strided) -> BN ->
    # pointwise -> BN -> ReLU. With stride 1 it only sees half the channels.
    self.branch2 = nn.Sequential(
        nn.Conv2d(inp if (self.stride > 1) else branch_features,
                  branch_features, kernel_size=1, stride=1, padding=0, bias=False),
        nn.BatchNorm2d(branch_features),
        nn.ReLU(inplace=True),
        self.depthwise_conv(branch_features, branch_features,
                            kernel_size=3, stride=self.stride, padding=1),
        nn.BatchNorm2d(branch_features),
        nn.Conv2d(branch_features, branch_features,
                  kernel_size=1, stride=1, padding=0, bias=False),
        nn.BatchNorm2d(branch_features),
        nn.ReLU(inplace=True),
    )

  @staticmethod
  def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
    """Return a Conv2d with ``groups=i`` (one filter group per input channel)."""
    return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)

  def forward(self, x):
    """Apply the unit to ``x`` (N, inp, H, W) and channel-shuffle the result."""
    if self.stride == 1:
      # Identity half + transformed half.
      x1, x2 = x.chunk(2, dim=1)
      out = torch.cat((x1, self.branch2(x2)), dim=1)
    else:
      out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)

    out = channel_shuffle(out, 2)

    return out


class ShuffleNetV2(nn.Module):
  """ShuffleNet V2 image classifier.

  Layout:

  - ``conv1``: 3x3 stride-2 stem conv, followed by a 3x3 stride-2 ``maxpool``.
  - ``stage2`` .. ``stage4``: stacks of InvertedResidual units. The first unit
    of each stage has stride 2; the rest have stride 1.
  - ``conv5``: 1x1 conv.
  - Global average pooling, then the linear classifier ``fc``.

  The class attributes MEAN / STD / SIZE / SCALE / CROP are preprocessing
  constants and are not used inside this class. They are presumably read by
  data pipelines elsewhere; verify against callers.

  Args:
    stages_repeats (sequence of int): number of units in stage2..stage4
      (exactly 3 entries).
    stages_out_channels (sequence of int): output channels of conv1,
      stage2, stage3, stage4 and conv5 (exactly 5 entries).
    num_classes (int): size of the classifier output.

  Raises:
    ValueError: if either sequence has the wrong length.
  """

  MEAN = [0.485, 0.456, 0.406]
  STD = [0.229, 0.224, 0.225]
  SIZE = [224, 224]
  SCALE = 255
  CROP = 0.875

  def __init__(self, stages_repeats, stages_out_channels, num_classes=1000):
    super(ShuffleNetV2, self).__init__()

    if len(stages_repeats) != 3:
      raise ValueError('expected stages_repeats as list of 3 positive ints')
    if len(stages_out_channels) != 5:
      raise ValueError(
          'expected stages_out_channels as list of 5 positive ints')
    self._stage_out_channels = stages_out_channels

    # Stem: the network consumes 3-channel images.
    stem_ch = stages_out_channels[0]
    self.conv1 = nn.Sequential(
        nn.Conv2d(3, stem_ch, 3, 2, 1, bias=False),
        nn.BatchNorm2d(stem_ch),
        nn.ReLU(inplace=True),
    )
    self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    # Stages 2..4. The first unit downsamples; the remaining units keep shape.
    in_ch = stem_ch
    stage_specs = zip(stages_repeats, stages_out_channels[1:4])
    for idx, (num_units, out_ch) in enumerate(stage_specs, start=2):
      units = [InvertedResidual(in_ch, out_ch, 2)]
      units.extend(
          InvertedResidual(out_ch, out_ch, 1) for _ in range(num_units - 1))
      setattr(self, 'stage{}'.format(idx), nn.Sequential(*units))
      in_ch = out_ch

    # Head: 1x1 conv, then the linear classifier on pooled features.
    head_ch = stages_out_channels[-1]
    self.conv5 = nn.Sequential(
        nn.Conv2d(in_ch, head_ch, 1, 1, 0, bias=False),
        nn.BatchNorm2d(head_ch),
        nn.ReLU(inplace=True),
    )
    self.fc = nn.Linear(head_ch, num_classes)

  def forward(self, x):
    """Return class logits of shape (N, num_classes) for an NCHW input."""
    feats = self.maxpool(self.conv1(x))
    for stage in (self.stage2, self.stage3, self.stage4):
      feats = stage(feats)
    # Global average pool over the spatial axes (H, W).
    pooled = self.conv5(feats).mean([2, 3])
    return self.fc(pooled)


def _shufflenetv2(arch, pretrained, *args, **kwargs):
  """Build a ShuffleNetV2 and optionally load pretrained weights into it.

  Args:
    arch (str): key into ``model_urls`` naming the checkpoint.
    pretrained (bool): if truthy, load the checkpoint registered for ``arch``.
    *args, **kwargs: forwarded unchanged to ``ShuffleNetV2``.

  Returns:
    ShuffleNetV2: the constructed model.

  Raises:
    NotImplementedError: if ``pretrained`` is set and ``arch`` maps to None.
    KeyError: if ``pretrained`` is set and ``arch`` is not in ``model_urls``.
  """
  net = ShuffleNetV2(*args, **kwargs)
  if not pretrained:
    return net

  url = model_urls[arch]
  if url is None:
    raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))

  # NOTE(review): the return value is ignored, so this presumably loads the
  # weights into ``net`` in place; confirm in tw.utils.checkpoint.
  load_state_dict_from_url(net, url)
  return net


def shufflenet_v2_x0_5(pretrained=False, **kwargs):
  """
  Constructs a ShuffleNetV2 with 0.5x output channels, as described in
  `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
  <https://arxiv.org/abs/1807.11164>`_.

  Stage repeats are [4, 8, 4]; stage output channels are
  [24, 48, 96, 192, 1024].

  Args:
      pretrained (bool): If True, returns a model pre-trained on ImageNet,
        loaded from ``model_urls['shufflenetv2_x0.5']``.
      **kwargs: forwarded to ``ShuffleNetV2`` (e.g. ``num_classes``).
  """
  return _shufflenetv2('shufflenetv2_x0.5', pretrained, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs)


def shufflenet_v2_x1_0(pretrained=False, **kwargs):
  """
  Constructs a ShuffleNetV2 with 1.0x output channels, as described in
  `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
  <https://arxiv.org/abs/1807.11164>`_.

  Stage repeats are [4, 8, 4]; stage output channels are
  [24, 116, 232, 464, 1024].

  Args:
      pretrained (bool): If True, returns a model pre-trained on ImageNet,
        loaded from ``model_urls['shufflenetv2_x1.0']``.
      **kwargs: forwarded to ``ShuffleNetV2`` (e.g. ``num_classes``).
  """
  return _shufflenetv2('shufflenetv2_x1.0', pretrained, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)


def shufflenet_v2_x1_5(pretrained=False, **kwargs):
  """
  Constructs a ShuffleNetV2 with 1.5x output channels, as described in
  `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
  <https://arxiv.org/abs/1807.11164>`_.

  Stage repeats are [4, 8, 4]; stage output channels are
  [24, 176, 352, 704, 1024].

  Args:
      pretrained (bool): If True, attempt to load pretrained weights.
        ``model_urls['shufflenetv2_x1.5']`` is currently None, so this raises.
      **kwargs: forwarded to ``ShuffleNetV2`` (e.g. ``num_classes``).

  Raises:
      NotImplementedError: if ``pretrained`` is True (no weights available).
  """
  return _shufflenetv2('shufflenetv2_x1.5', pretrained, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs)


def shufflenet_v2_x2_0(pretrained=False, **kwargs):
  """
  Constructs a ShuffleNetV2 with 2.0x output channels, as described in
  `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
  <https://arxiv.org/abs/1807.11164>`_.

  Stage repeats are [4, 8, 4]; stage output channels are
  [24, 244, 488, 976, 2048].

  Args:
      pretrained (bool): If True, attempt to load pretrained weights.
        ``model_urls['shufflenetv2_x2.0']`` is currently None, so this raises.
      **kwargs: forwarded to ``ShuffleNetV2`` (e.g. ``num_classes``).

  Raises:
      NotImplementedError: if ``pretrained`` is True (no weights available).
  """
  return _shufflenetv2('shufflenetv2_x2.0', pretrained, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs)
