# Copyright 2021 The KaiJIN Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""spynet: optical flow network
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import tw

from mmcv.cnn import ConvModule


def flow_warp(x, flow, interpolation='bilinear', padding_mode='zeros', align_corners=True):
  """Warp an image or a feature map with optical flow.

  Every output pixel (row i, col j) is sampled from ``x`` at the location
  ``(j + flow[..., 0], i + flow[..., 1])`` using ``F.grid_sample``.

  Args:
      x (Tensor): Tensor with size (n, c, h, w).
      flow (Tensor): Tensor with size (n, h, w, 2). The last dimension is
          a two-channel, denoting the width and height relative offsets
          (in pixels). Note that the values are not normalized to [-1, 1].
      interpolation (str): Interpolation mode: 'nearest' or 'bilinear'.
          Default: 'bilinear'.
      padding_mode (str): Padding mode: 'zeros' or 'border' or 'reflection'.
          Default: 'zeros'.
      align_corners (bool): Whether align corners. Default: True.

  Returns:
      Tensor: Warped image or feature map with the same size as ``x``.

  Raises:
      ValueError: If the spatial size of ``x`` and ``flow`` differ.
  """
  if x.size()[-2:] != flow.size()[1:3]:
    raise ValueError(f'The spatial sizes of input ({x.size()[-2:]}) and '
                     f'flow ({flow.size()[1:3]}) are not the same.')
  _, _, h, w = x.size()

  # Build the (h, w, 2) pixel-coordinate grid by broadcasting two aranges.
  # This is equivalent to an 'ij'-indexed meshgrid but does not depend on the
  # `indexing` keyword of torch.meshgrid (missing in old releases, and
  # warning-producing when omitted in new ones). The grid is created directly
  # on the device of `x`; arange outputs never require grad.
  cols = torch.arange(0, w, device=x.device).view(1, w).expand(h, w)
  rows = torch.arange(0, h, device=x.device).view(h, 1).expand(h, w)
  grid = torch.stack((cols, rows), 2).type_as(x)  # (h, w, 2), last dim = (x, y)

  # absolute sampling positions in pixels, broadcast over the batch dimension
  grid_flow = grid + flow

  # scale sampling positions to [-1, 1]; max(., 1) guards a 1-pixel wide/high input
  grid_flow_x = 2.0 * grid_flow[:, :, :, 0] / max(w - 1, 1) - 1.0
  grid_flow_y = 2.0 * grid_flow[:, :, :, 1] / max(h - 1, 1) - 1.0
  grid_flow = torch.stack((grid_flow_x, grid_flow_y), dim=3)

  return F.grid_sample(x, grid_flow, mode=interpolation, padding_mode=padding_mode, align_corners=align_corners)


class SPyNet(nn.Module):
  """SPyNet network structure.

  A coarse-to-fine optical flow estimator: the inputs are turned into an
  image pyramid and one SPyNetBasicModule per pyramid level predicts a
  residual flow that refines the upsampled flow of the previous level.

  The difference to the SPyNet in [tof.py] is that
      1. more SPyNetBasicModule is used in this version, and
      2. no batch normalization is used in this version.

  Paper:
      Optical Flow Estimation using a Spatial Pyramid Network, CVPR, 2017

  Args:
      pretrained (str | None): location of the pre-trained SPyNet weights,
          handed to ``tw.checkpoint.load``. Pass None to skip loading.
          Default: the BasicVSR SPyNet checkpoint URL from OpenMMLab.

  Raises:
      TypeError: If ``pretrained`` is neither a str nor None.
  """

  def __init__(self, pretrained='https://download.openmmlab.com/mmediting/restorers/basicvsr/spynet_20210409-c6c1bd09.pth'):
    super().__init__()

    # validate the argument before building any layer so misuse fails fast
    if pretrained is not None and not isinstance(pretrained, str):
      raise TypeError(f'[pretrained] should be str or None, but got {type(pretrained)}.')

    # one refinement module per pyramid level (coarsest first)
    self.basic_module = nn.ModuleList([SPyNetBasicModule() for _ in range(6)])

    if pretrained is not None:
      ckpt = tw.checkpoint.load(pretrained)
      tw.checkpoint.load_matched_state_dict(self, ckpt)

    # per-channel normalization constants, shaped to broadcast over (n, 3, h, w);
    # registered after the checkpoint is loaded, as in the original ordering
    self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
    self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

  def compute_flow(self, ref, supp):
    """Compute flow from ref to supp.

    Note that in this function, the images are already resized to a
    multiple of 32.

    Args:
        ref (Tensor): Reference image with shape of (n, 3, h, w).
        supp (Tensor): Supporting image with shape of (n, 3, h, w).

    Returns:
        Tensor: Estimated optical flow: (n, 2, h, w).
    """
    n = ref.size(0)

    # normalize the input images
    ref = [(ref - self.mean) / self.std]
    supp = [(supp - self.mean) / self.std]

    # generate downsampled frames: one pyramid level per basic module, so
    # (number of modules - 1) halvings of the full-resolution frames
    for _ in range(len(self.basic_module) - 1):
      ref.append(F.avg_pool2d(input=ref[-1], kernel_size=2, stride=2, count_include_pad=False))
      supp.append(F.avg_pool2d(input=supp[-1], kernel_size=2, stride=2, count_include_pad=False))
    # coarsest level first
    ref = ref[::-1]
    supp = supp[::-1]

    # flow computation: start from zero flow at the coarsest resolution
    coarse_h, coarse_w = ref[0].shape[-2:]
    flow = ref[0].new_zeros(n, 2, coarse_h, coarse_w)
    for level, (module, ref_l, supp_l) in enumerate(zip(self.basic_module, ref, supp)):
      if level > 0:
        # doubling the resolution also doubles the displacement in pixels
        flow = F.interpolate(input=flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0

      # warp the supporting frame towards the reference with the current flow
      warped = flow_warp(supp_l, flow.permute(0, 2, 3, 1), padding_mode='border')

      # add the residue to the upsampled flow
      flow = flow + module(torch.cat([ref_l, warped, flow], 1))

    return flow

  def forward(self, ref, supp):
    """Forward function of SPyNet.

    This function computes the optical flow from ref to supp.

    Args:
        ref (Tensor): Reference image with shape of (n, 3, h, w).
        supp (Tensor): Supporting image with shape of (n, 3, h, w).

    Returns:
        Tensor: Estimated optical flow: (n, 2, h, w).
    """

    # upsize to a multiple of 32
    h, w = ref.shape[2:4]
    w_up = w if (w % 32) == 0 else 32 * (w // 32 + 1)
    h_up = h if (h % 32) == 0 else 32 * (h // 32 + 1)
    ref = F.interpolate(input=ref, size=(h_up, w_up), mode='bilinear', align_corners=False)
    supp = F.interpolate(input=supp, size=(h_up, w_up), mode='bilinear', align_corners=False)

    # compute flow, and resize back to the original resolution
    flow = F.interpolate(input=self.compute_flow(ref, supp), size=(h, w), mode='bilinear', align_corners=False)

    # adjust the flow values: displacements were estimated on the enlarged
    # frames, so rescale them to pixels of the original resolution
    flow[:, 0, :, :] *= float(w) / float(w_up)
    flow[:, 1, :, :] *= float(h) / float(h_up)

    return flow


class SPyNetBasicModule(nn.Module):
  """Single-level flow refinement block of SPyNet.

  A stack of five 7x7 convolutions (stride 1, padding 3, no normalization)
  with channel widths 8 -> 32 -> 64 -> 32 -> 16 -> 2. Every convolution
  except the last one is followed by a ReLU.

  Paper:
      Optical Flow Estimation using a Spatial Pyramid Network, CVPR, 2017
  """

  def __init__(self):
    super().__init__()

    # channel width before/after each convolution
    widths = (8, 32, 64, 32, 16, 2)
    final_idx = len(widths) - 2

    stages = []
    for idx, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
      # the last convolution emits the raw flow residual, hence no activation
      activation = None if idx == final_idx else dict(type='ReLU')
      stages.append(
          ConvModule(in_channels=c_in, out_channels=c_out, kernel_size=7, stride=1, padding=3, norm_cfg=None, act_cfg=activation))  # nopep8

    # positional Sequential keeps the sub-module names '0'..'4'
    self.basic_module = nn.Sequential(*stages)

  def forward(self, tensor_input):
    """Run the convolution stack on the concatenated input.

    Args:
        tensor_input (Tensor): Input tensor with shape (b, 8, h, w).
            8 channels contain:
            [reference image (3), neighbor image (3), initial flow (2)].

    Returns:
        Tensor: Refined flow with shape (b, 2, h, w)
    """
    refined = self.basic_module(tensor_input)
    return refined


if __name__ == "__main__":
  # Smoke test: build the network (loads the default checkpoint), run one
  # pair of random 256x256 frames through it and print the flow shape.
  net = SPyNet()
  net.eval()
  with torch.no_grad():
    ref_frame = torch.rand(1, 3, 256, 256)
    supp_frame = torch.rand(1, 3, 256, 256)
    estimated_flow = net(ref_frame, supp_frame)
    print(estimated_flow.shape)
