# Copyright (c) 2021, Google Inc.
# All rights reserved.
# 
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
# 
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
# 
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
# 
# 3. Neither the name of Google Inc. nor the names of its contributors
#    may be used to endorse or promote products derived from this software without
#    specific prior written permission.
# 
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Implementation of multiheaded attention and self-attention layers."""
import math
from typing import Any, Dict, Optional, Union, Iterable
import tensorflow as tf


class Attention(tf.keras.layers.Layer):
  """Multi-headed attention layer.

  Projects queries, keys and values with per-head linear maps, applies scaled
  dot-product attention (optionally restricted to a band around the diagonal)
  and recombines the heads with a final linear projection.
  """

  def __init__(
      self,
      hidden_size: int,
      num_heads: int,
      attention_dropout: float,
      attn_win_size: Optional[int] = None,
  ):
    """Initialize Attention.

    Args:
      hidden_size: int, output dim of hidden layer.
      num_heads: int, number of heads to repeat the same attention structure.
      attention_dropout: float, dropout rate inside attention for training.
      attn_win_size: Optional[int], local sliding window attention length. Use
        None for full attention. Any falsy value (including 0) is treated the
        same as None.

    Raises:
      ValueError: if hidden_size is not divisible by num_heads.
    """
    if hidden_size % num_heads:
      raise ValueError(
          "Hidden size ({}) must be divisible by the number of heads ({})."
          .format(hidden_size, num_heads)
      )

    super().__init__()
    self.hidden_size = hidden_size
    self.num_heads = num_heads
    self.attention_dropout = attention_dropout
    self.attn_win_size = attn_win_size

  def build(self, input_shape: Union[tf.TensorShape, Iterable[tf.TensorShape]]):
    """Builds the projection sub-layers and the static attention mask.

    Args:
      input_shape: shape of the first call argument, expected to be
        [batch_size, input_length, hidden_size]. The last two dimensions are
        read here, so they must be statically known.
    """
    size_per_head = self.hidden_size // self.num_heads
    shape_dims = input_shape.as_list()
    input_hidden_size = shape_dims[-1]

    def _glorot_initializer(fan_in, fan_out):
      limit = math.sqrt(6.0 / (fan_in + fan_out))
      return tf.keras.initializers.RandomUniform(minval=-limit, maxval=limit)

    def _input_projection(name):
      # A fresh initializer instance is created for every projection so that no
      # initializer object is shared between kernels.
      return tf.keras.layers.experimental.EinsumDense(
          "BTE,ENH->BTNH",
          output_shape=(None, self.num_heads, size_per_head),
          kernel_initializer=_glorot_initializer(
              input_hidden_size, self.hidden_size
          ),
          bias_axes=None,
          name=name,
      )

    # Layers for linearly projecting the queries, keys, and values. The einsum
    # also splits the result into heads.
    self.query_dense_layer = _input_projection("query")
    self.key_dense_layer = _input_projection("key")
    self.value_dense_layer = _input_projection("value")

    # Output projection; the einsum merges the heads back into hidden_size.
    self.output_dense_layer = tf.keras.layers.experimental.EinsumDense(
        "BTNH,NHE->BTE",
        output_shape=(None, self.hidden_size),
        kernel_initializer=_glorot_initializer(
            self.hidden_size, self.hidden_size
        ),
        bias_axes=None,
        name="output_transform",
    )

    # input_shape = [batch_size, input_length, hidden_size]
    max_length = shape_dims[1]

    band = tf.ones([1, 1, max_length, max_length])
    if self.attn_win_size:
      # Keep attn_win_size sub- and super-diagonals; everything else is zeroed.
      band = tf.linalg.band_part(band, self.attn_win_size, self.attn_win_size)
    # Boolean mask: True inside the band (or everywhere for full attention),
    # False outside.
    self.attn_mask = band > 0.0

    super().build(input_shape)

  def get_config(self) -> Dict[str, Any]:
    """Returns the constructor arguments of this layer as a dict.

    Only the arguments accepted by __init__ are included; the base layer
    config is not merged in.
    """
    return {
        "hidden_size": self.hidden_size,
        "num_heads": self.num_heads,
        "attention_dropout": self.attention_dropout,
        "attn_win_size": self.attn_win_size,
    }

  def call(
      self,
      query_input: tf.Tensor,
      source_input: tf.Tensor,
      bias: tf.Tensor,
      training: bool,
      cache: Optional[Dict[str, tf.Tensor]] = None,
      decode_loop_step: Optional[int] = None,
  ) -> Dict[str, tf.Tensor]:
    """Apply attention mechanism to query_input and source_input.

    Args:
      query_input: A tensor with shape [batch_size, length_query, hidden_size].
      source_input: A tensor with shape [batch_size, length_source,
        hidden_size].
      bias: A tensor with shape [batch_size, 1, length_query, length_source],
        the attention bias that will be added to the result of the dot product.
      training: A bool, whether in training mode or not.
      cache: (Used during prediction) A dictionary with tensors containing
        results of previous attentions. The dictionary must have the items:
        {"k": tensor with shape [batch_size, i, heads, dim_per_head], "v":
        tensor with shape [batch_size, i, heads, dim_per_head]} where i is the
        current decoded length for non-padded decode, or max sequence length for
        padded decode. The "k" and "v" entries are overwritten in place with
        the combined keys and values.
      decode_loop_step: An integer, step number of the decoding loop. Used only
        for autoregressive inference on TPU.

    Returns:
      Dictionary with the following (key:value) pairs:
        "main_output": Attention layer output with shape [batch_size,
        length_query, hidden_size]. Used as input to the feed_forward_network.
        "attention_scores": Attention map weights (after softmax, and after
        dropout when training) with shape [batch_size, num_heads, length_query,
        length_source] - auxiliary output.
    """
    # Linearly project the query, key and value using different learned
    # projections. Splitting heads is automatically done during the linear
    # projections --> [batch_size, length, num_heads, dim_per_head].
    query = self.query_dense_layer(query_input)
    key = self.key_dense_layer(source_input)
    value = self.value_dense_layer(source_input)

    if cache is not None:
      # Combine cached keys and values with new keys and values. The cached
      # tensors are cast to the dtype of the freshly projected tensors in both
      # branches so the combination does not fail on a dtype mismatch.
      cached_key = tf.cast(cache["k"], key.dtype)
      cached_value = tf.cast(cache["v"], value.dtype)
      if decode_loop_step is not None:
        # Padded decode: the cache has a fixed length, so the new key/value is
        # added into the slot selected by a one-hot vector over the cache's
        # time axis instead of being concatenated.
        cache_k_len = cache["k"].shape.as_list()[1]
        key_slot = tf.reshape(
            tf.one_hot(decode_loop_step, cache_k_len, dtype=key.dtype),
            [1, cache_k_len, 1, 1],
        )
        key = cached_key + key * key_slot
        cache_v_len = cache["v"].shape.as_list()[1]
        value_slot = tf.reshape(
            tf.one_hot(decode_loop_step, cache_v_len, dtype=value.dtype),
            [1, cache_v_len, 1, 1],
        )
        value = cached_value + value * value_slot
      else:
        key = tf.concat([cached_key, key], axis=1)
        value = tf.concat([cached_value, value], axis=1)

      # Update cache
      cache["k"] = key
      cache["v"] = value

    # Scale query to prevent the dot product between query and key from growing
    # too large.
    depth = self.hidden_size // self.num_heads
    query *= depth**-0.5

    # Calculate dot product attention. T indexes key positions and F indexes
    # query positions, so logits are [batch, heads, length_query, length_key].
    logits = tf.einsum("BTNH,BFNH->BNFT", key, query)
    logits += bias
    # False values in the mask will be set to a large negative number in the
    # logits. The attention scores for elements outside the band will be close
    # to 0 after softmax. The mask was sized in build() from the first input's
    # length, so it has to be broadcast-compatible with the logits here.
    logits = tf.where(self.attn_mask, logits, -1e9)
    # Note that softmax internally performs math operations using float32
    # for numeric stability. When training with float16, we keep the input
    # and output in float16 for better performance.
    weights = tf.nn.softmax(logits, name="attention_weights")
    if training:
      weights = tf.nn.dropout(weights, rate=self.attention_dropout)
    attention_output = tf.einsum("BNFT,BTNH->BFNH", weights, value)

    # Run the outputs through another linear projection layer. Recombining heads
    # is automatically done --> [batch_size, length, hidden_size]
    attention_output = self.output_dense_layer(attention_output)

    return dict(main_output=attention_output, attention_scores=weights)


class SelfAttention(Attention):
  """Multiheaded self-attention layer.

  Specialization of Attention in which the same tensor is used both as the
  query input and as the source of keys and values.
  """

  def call(
      self,
      query_input: tf.Tensor,
      bias: tf.Tensor,
      training: bool,
      cache: Optional[Dict[str, tf.Tensor]] = None,
      decode_loop_step: Optional[int] = None,
  ) -> Dict[str, tf.Tensor]:
    """Applies attention with query_input attending to itself.

    Args:
      query_input: tensor passed to the parent layer as both the query input
        and the source input.
      bias: attention bias forwarded unchanged to the parent layer.
      training: A bool, whether in training mode or not.
      cache: optional dictionary of cached keys and values, forwarded unchanged
        to the parent layer.
      decode_loop_step: optional decoding step, forwarded unchanged to the
        parent layer.

    Returns:
      The dictionary returned by the parent layer's call.
    """
    return super().call(
        query_input=query_input,
        source_input=query_input,
        bias=bias,
        training=training,
        cache=cache,
        decode_loop_step=decode_loop_step,
    )
