# Copyright 2021 The KaiJIN Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import numpy as np
from math import ceil
from itertools import product as product
import torch


def nms_(dets, thresh):
  """Greedy non-maximum suppression on scored boxes (NumPy implementation).

  Courtesy of Ross Girshick
  [https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/nms/py_cpu_nms.py]

  Args:
      dets: 2-D numpy array; columns 0..3 are read as x1, y1, x2, y2 and
          column 4 as the score of each box.
      thresh: boxes whose IoU with an already kept box is strictly greater
          than this value are discarded.
  Return:
      1-D integer numpy array with the row indices of the kept boxes,
      ordered by descending score.
  """
  x1 = dets[:, 0]
  y1 = dets[:, 1]
  x2 = dets[:, 2]
  y2 = dets[:, 3]
  scores = dets[:, 4]

  areas = (x2 - x1) * (y2 - y1)
  # indices sorted from the highest score to the lowest
  order = scores.argsort()[::-1]

  keep = []
  while order.size > 0:
    # the best remaining box is always kept
    i = order[0]
    keep.append(int(i))

    # intersection of the kept box with every other remaining candidate
    rest = order[1:]
    xx1 = np.maximum(x1[i], x1[rest])
    yy1 = np.maximum(y1[i], y1[rest])
    xx2 = np.minimum(x2[i], x2[rest])
    yy2 = np.minimum(y2[i], y2[rest])

    w = np.maximum(0.0, xx2 - xx1)
    h = np.maximum(0.0, yy2 - yy1)
    inter = w * h
    ovr = inter / (areas[i] + areas[rest] - inter)

    # `inds` indexes into `rest`, hence the +1 to map back into `order`
    inds = np.where(ovr <= thresh)[0]
    order = order[inds + 1]

  # `np.int` was removed in NumPy 1.24; it was an alias of the builtin `int`
  return np.array(keep).astype(int)


def decode(loc, priors, variances):
  """Turn regression offsets back into corner-form boxes using the priors.

  This inverts the offset encoding applied at train time.
  Args:
      loc (tensor): location predictions for loc layers,
          Shape: [num_priors,4]
      priors (tensor): Prior boxes in center-offset form.
          Shape: [num_priors,4].
      variances: (list[float]) Variances of priorboxes; index 0 scales the
          center offsets, index 1 scales the size offsets.
  Return:
      tensor of shape [num_priors,4] holding the decoded boxes as
      (left, top, right, bottom).
  """
  prior_centers = priors[:, :2]
  prior_sizes = priors[:, 2:]

  # centers move linearly with the prior size, sizes scale exponentially
  centers = prior_centers + loc[:, :2] * variances[0] * prior_sizes
  sizes = prior_sizes * torch.exp(loc[:, 2:] * variances[1])

  # center/size form -> corner form
  top_left = centers - sizes / 2
  bottom_right = sizes + top_left
  return torch.cat((top_left, bottom_right), 1)


class PriorBox(object):
  """Builds the prior (anchor) boxes for one fixed input image size.

  Three feature levels are used, with strides 32, 64 and 128. The first
  level carries anchor sizes 32, 64 and 128, the others 256 and 512.
  Anchors of size 32 and 64 are densified: every cell emits a 4x4
  (respectively 2x2) grid of them instead of a single centred anchor.

  `image_size` must be indexable; index 0 is paired with the y axis and
  index 1 with the x axis.
  """

  def __init__(self, image_size=None):
    super().__init__()
    self.min_sizes = [[32, 64, 128], [256], [512]]
    self.steps = [32, 64, 128]
    self.clip = False
    self.image_size = image_size
    # per level: [cells along y, cells along x]
    self.feature_maps = []
    for stride in self.steps:
      cells_y = ceil(self.image_size[0] / stride)
      cells_x = ceil(self.image_size[1] / stride)
      self.feature_maps.append([cells_y, cells_x])

  def forward(self):
    """Return a [num_anchors, 4] tensor of (cx, cy, w, h) rows.

    All four values are divided by the image size. Rows are ordered by
    level, then cell (row-major), then anchor size, then dense offsets
    (y outer, x inner).
    """
    # fractional positions inside a cell; sizes not listed here get a
    # single anchor at the cell centre
    dense_offsets = {32: (0, 0.25, 0.5, 0.75), 64: (0, 0.5)}
    centre_only = (0.5,)

    flat = []
    for level, fmap in enumerate(self.feature_maps):
      stride = self.steps[level]
      level_sizes = self.min_sizes[level]
      for row, col in product(range(fmap[0]), range(fmap[1])):
        for size in level_sizes:
          box_w = size / self.image_size[1]
          box_h = size / self.image_size[0]
          offsets = dense_offsets.get(size, centre_only)
          xs = [(col + off) * stride / self.image_size[1] for off in offsets]
          ys = [(row + off) * stride / self.image_size[0] for off in offsets]
          for y in ys:
            for x in xs:
              flat.extend((x, y, box_w, box_h))

    # back to torch land
    boxes = torch.Tensor(flat).view(-1, 4)
    if self.clip:
      boxes.clamp_(max=1, min=0)
    return boxes
