from tensor2tensor.models.transformer import Transformer, features_to_nonpadding
from  tensor2tensor.utils import beam_search, registry
from tensor2tensor.layers import common_layers, common_attention
from tensorflow.python.util import nest
import tensorflow as tf

@registry.register_model
class ScoreFormer(Transformer):
    """Transformer variant that computes the log-probability of a sentence.

    Rather than generating new tokens, the decode path feeds the given
    target tokens through the decoder one step at a time and sums the
    log-probabilities the model assigns to them.
    """

    def __init__(self, *args, **kwargs):
        super(ScoreFormer, self).__init__(*args, **kwargs)
        # Overwrite the names set by the base constructor with "transformer".
        # NOTE(review): presumably so that variable scopes line up with
        # checkpoints trained with the stock Transformer -- confirm.
        self._name = "transformer"
        self._base_name = "transformer"

    def _fast_decode(self,
                     features,
                     decode_length,
                     beam_size=1,
                     top_beams=1,
                     alpha=1.0):
        """Scores the target sentence using incremental (cached) decoding.

        Runs the encoder when the problem has inputs, builds a per-step
        `symbols_to_logits_fn`, and delegates the scoring loop to the
        module-level `fast_decode`.

        Args:
          features: a map of string to model features. When the problem has
            inputs, both "inputs" and "targets" are read; otherwise the
            sequence to score is taken from "inputs" if present, else from
            "targets".
          decode_length: an integer. Added to the input length (or to the
            target length when there are no inputs) to obtain the maximum
            number of scoring steps. Replaced by 1 for class modalities.
          beam_size: forwarded to `fast_decode`.
          top_beams: forwarded to `fast_decode`.
          alpha: forwarded to `fast_decode`.

        Returns:
          The dict produced by `fast_decode`: {
              "outputs": the target ids that were scored,
              "scores": the accumulated log-probability per batch entry
          }

        Raises:
          NotImplementedError: If there are multiple data shards.
          ValueError: If the problem has inputs but `features` has no
            "targets" entry to score.
        """
        if self._num_datashards != 1:
            raise NotImplementedError(
                "Fast decoding only supports a single shard.")
        dp = self._data_parallelism
        hparams = self._hparams
        target_modality = self._problem_hparams.target_modality

        if self.has_input:
            inputs = features["inputs"]
            targets = features.get("targets")
            if targets is None:
                # Fail here with a clear message instead of a NoneType
                # subscript error deep inside the scoring loop.
                raise ValueError(
                    "features[\"targets\"] is required to score a sentence.")
            # Same normalization as the no-input branch below: the scoring
            # loop indexes targets as [batch, step] and carries int64 ids.
            targets = tf.to_int64(
                common_layers.expand_squeeze_to_nd(targets, 2))
            if target_modality.is_class_modality:
                decode_length = 1
            else:
                decode_length = common_layers.shape_list(inputs)[1] + decode_length

            # TODO(llion): Clean up this reshaping logic.
            inputs = tf.expand_dims(inputs, axis=1)
            if len(inputs.shape) < 5:
                inputs = tf.expand_dims(inputs, axis=4)
            s = common_layers.shape_list(inputs)
            batch_size = s[0]
            inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
            # _shard_features called to ensure that the variable names match
            inputs = self._shard_features({"inputs": inputs})["inputs"]
            input_modality = self._problem_hparams.input_modality["inputs"]
            with tf.variable_scope(input_modality.name):
                inputs = input_modality.bottom_sharded(inputs, dp)
            with tf.variable_scope("body"):
                encoder_output, encoder_decoder_attention_bias = dp(
                    self.encode, inputs, features["target_space_id"], hparams,
                    features=features)
            # Single shard (checked above), so unwrap the per-shard lists.
            encoder_output = encoder_output[0]
            encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]
        else:
            # The problem has no inputs.
            encoder_output = None
            encoder_decoder_attention_bias = None

            # The sequence to score lives in either features["inputs"] or
            # features["targets"].
            partial_targets = features.get("inputs")
            if partial_targets is None:
                partial_targets = features["targets"]
            partial_targets = common_layers.expand_squeeze_to_nd(
                partial_targets, 2)
            partial_targets = tf.to_int64(partial_targets)

            targets = partial_targets
            partial_targets_shape = common_layers.shape_list(partial_targets)
            partial_targets_length = partial_targets_shape[1]
            decode_length += partial_targets_length
            batch_size = partial_targets_shape[0]

        if hparams.pos == "timing":
            timing_signal = common_attention.get_timing_signal_1d(
                decode_length + 1, hparams.hidden_size)

        def preprocess_targets(targets, i):
            """Performs preprocessing steps on the targets to prepare for the decoder.

            This includes:
            - Embedding the ids.
            - Flattening to 3D tensor.
            - Optionally adding timing signals.

            Args:
              targets: inputs ids to the decoder. [batch_size, 1]
              i: scalar, Step number of the decoding loop.

            Returns:
              Processed targets [batch_size, 1, hidden_dim]
            """
            # _shard_features called to ensure that the variable names match
            targets = self._shard_features({"targets": targets})["targets"]
            with tf.variable_scope(target_modality.name):
                targets = target_modality.targets_bottom_sharded(targets, dp)[0]
            targets = common_layers.flatten4d3d(targets)

            # At step 0 the decoder input is replaced by zeros.
            # TODO(llion): Explain! Is this even needed?
            targets = tf.cond(
                tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets)

            if hparams.pos == "timing":
                targets += timing_signal[:, i:i + 1]
            return targets

        decoder_self_attention_bias = (
            common_attention.attention_bias_lower_triangle(decode_length))
        if hparams.proximity_bias:
            decoder_self_attention_bias += common_attention.attention_bias_proximal(
                decode_length)

        def symbols_to_logits_fn(ids, i, cache):
            """Go from ids to logits for next symbol."""
            # Only the most recent id is fed; earlier steps come from `cache`.
            ids = ids[:, -1:]
            targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
            targets = preprocess_targets(targets, i)

            # Row i of the causal bias, restricted to positions 0..i.
            bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

            with tf.variable_scope("body"):
                body_outputs = dp(
                    self.decode, targets, cache.get("encoder_output"),
                    cache.get("encoder_decoder_attention_bias"),
                    bias, hparams, cache,
                    nonpadding=features_to_nonpadding(features, "targets"))

            with tf.variable_scope(target_modality.name):
                logits = target_modality.top_sharded(body_outputs, None, dp)[0]

            ret = tf.squeeze(logits, axis=[1, 2, 3])
            return ret, cache

        ret = fast_decode(
            encoder_output=encoder_output,
            encoder_decoder_attention_bias=encoder_decoder_attention_bias,
            targets=targets,
            symbols_to_logits_fn=symbols_to_logits_fn,
            hparams=hparams,
            decode_length=decode_length,
            vocab_size=target_modality.top_dimensionality,
            beam_size=beam_size,
            top_beams=top_beams,
            alpha=alpha,
            batch_size=batch_size)
        return ret


def fast_decode(encoder_output,
                encoder_decoder_attention_bias,
                targets,
                symbols_to_logits_fn,
                hparams,
                decode_length,
                vocab_size,
                beam_size=1,
                top_beams=1,
                alpha=1.0,
                eos_id=beam_search.EOS_ID,
                batch_size=None):
    """Scores `targets` by feeding them step by step through the decoder.

    At every step the logits for the next symbol are computed from the
    previous target token, and the log-probability of the actual target
    token at that position is added to a running per-sequence total. No
    sampling, greedy choice or beam search takes place.

    Args:
      encoder_output: Output from encoder, or None if there is no encoder.
      encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder
        attention
      targets: the token ids to score. Indexed as `targets[:, step]`, and the
        ids are carried in an int64 `[batch, 1]` loop variable, so this must
        be an int64 tensor of shape [batch_size, length].
      symbols_to_logits_fn: Incremental decoding; function mapping triple
        `(ids, step, cache)` to symbol logits.
      hparams: run hyperparameters
      decode_length: an integer. Upper bound on the number of scoring steps;
        the loop also never runs past the end of `targets`.
      vocab_size: Output vocabulary size. Unused here.
      beam_size: Unused here; kept for interface compatibility.
      top_beams: Unused here; kept for interface compatibility.
      alpha: Unused here; kept for interface compatibility.
      eos_id: End-of-sequence symbol. A sequence stops accumulating
        log-probability once this id is seen in its targets.
      batch_size: an integer scalar - must be passed if there is no input

    Returns:
        A dict of results {
            "outputs": `targets`, returned unchanged
            "scores": float32 `Tensor` of shape [batch_size] holding the
                summed log-probability of each target sequence
        }

    Raises:
      ValueError: If `targets` is None, or if `batch_size` is None while
        `encoder_output` is None.
    """
    if targets is None:
        raise ValueError("fast_decode requires `targets` to score.")

    if encoder_output is not None:
        batch_size = common_layers.shape_list(encoder_output)[0]
    if batch_size is None:
        raise ValueError(
            "batch_size must be passed when encoder_output is None.")

    # Scoring must stop at the end of the targets; otherwise a sequence
    # without EOS would make `targets[:, i]` index out of range whenever
    # decode_length exceeds the target length.
    target_length = common_layers.shape_list(targets)[1]

    key_channels = hparams.attention_key_channels or hparams.hidden_size
    value_channels = hparams.attention_value_channels or hparams.hidden_size
    num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers

    # Per-layer key/value caches start out with zero time steps.
    cache = {
        "layer_%d" % layer: {
            "k": tf.zeros([batch_size, 0, key_channels]),
            "v": tf.zeros([batch_size, 0, value_channels]),
        }
        for layer in range(num_layers)
    }

    if encoder_output is not None:
        cache["encoder_output"] = encoder_output
        cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias

    def inner_loop(i, finished, next_id, cache, log_prob):
        """One step of scoring: add the log-prob of target token i."""
        logits, cache = symbols_to_logits_fn(next_id, i, cache)
        log_probs = beam_search.log_prob_from_logits(logits)

        # The next token is forced to be the reference target.
        next_id = targets[:, i]
        # `finished` is updated before accumulating, so the EOS token's own
        # log-probability is not added to the score.
        finished |= tf.equal(next_id, eos_id)

        # (row, token id) pairs selecting each sequence's target log-prob.
        log_prob_indices = tf.stack(
            [tf.range(tf.to_int64(batch_size)), next_id], axis=1)
        log_prob += tf.where(finished, tf.zeros_like(log_prob),
                             (tf.gather_nd(log_probs, log_prob_indices)))

        next_id = tf.expand_dims(next_id, axis=1)
        return i + 1, finished, next_id, cache, log_prob

    def is_not_finished(i, finished, *_):
        within_bounds = (i < decode_length) & (i < target_length)
        return within_bounds & tf.logical_not(tf.reduce_all(finished))

    finished = tf.fill([batch_size], False)
    next_id = tf.zeros([batch_size, 1], dtype=tf.int64)
    initial_log_prob = tf.zeros([batch_size], dtype=tf.float32)
    _, _, _, _, log_prob = tf.while_loop(
        is_not_finished,
        inner_loop, [
            tf.constant(0), finished, next_id, cache,
            initial_log_prob
        ],
        shape_invariants=[
            tf.TensorShape([]),
            tf.TensorShape([None]),
            tf.TensorShape([None, None]),
            nest.map_structure(
                beam_search.get_state_shape_invariants, cache),
            tf.TensorShape([None]),
        ])
    scores = log_prob

    return {"outputs": targets, "scores": scores}


@registry.register_model
class TransformerFnord(Transformer):
    """Transformer whose body result is paired with a constant extra dict."""

    def body(self, features):
        """Runs the parent body and pairs its result with {'training': 0.0}.

        Args:
          features: passed through unchanged to `Transformer.body`.

        Returns:
          A tuple `(parent_body_result, {'training': 0.0})`.
          NOTE(review): the second element is presumably consumed by the
          framework as an extra-losses dict -- confirm.
        """
        return super(TransformerFnord, self).body(features), {'training': 0.0}