from __future__ import print_function, division, absolute_import

from transformer_scores import ScoreFormer
from tensor2tensor.models.transformer import Transformer
from tensor2tensor.utils import registry, trainer_lib
from tensor2tensor.data_generators.text_problems import Text2SelfProblem
from twitter import GigaTwitter
import tensorflow as tf
from boto3 import s3
from collections import namedtuple
import os
import io
import glob

def with_params(params, fn):
    """Build the GigaTwitter problem plus its hyper-parameters and hand both to ``fn``.

    Args:
        params: dict-like; ``params['hparams']`` names a hyper-parameter set
            registered with the tensor2tensor registry.
        fn: callable invoked as ``fn(problem, hparams)``.

    Returns:
        Whatever ``fn(problem, hparams)`` returns.
    """
    # registry.hparams() yields the registered factory; calling it builds the HParams object.
    make_hparams = registry.hparams(params['hparams'])
    hparams = make_hparams()
    twitter_problem = GigaTwitter()
    # Adds an empty-string ``data_dir`` hyper-parameter.
    # NOTE(review): presumably expected by downstream t2t code — confirm; add_hparam
    # raises if the chosen hparams set already defines ``data_dir``.
    hparams.add_hparam("data_dir", "")
    return fn(twitter_problem, hparams)

def model_fn(features, labels, mode, params):
    """Estimator ``model_fn``: dispatch to ScoreFormer for prediction, Transformer otherwise.

    Args:
        features: features supplied by the Estimator framework.
        labels: labels supplied by the Estimator framework.
        mode: a ``tf.estimator.ModeKeys`` value. The legacy literal ``'PREDICT'``
            is also treated as prediction for backward compatibility.
        params: dict-like; must contain the ``'hparams'`` key read by ``with_params``.

    Returns:
        The result of ``<model_class>.estimator_model_fn(hparams, features, labels, mode)``
        (presumably an ``EstimatorSpec`` — confirm against the t2t version in use).
    """
    # tf.estimator.ModeKeys.PREDICT is the string 'infer', not 'PREDICT'. Comparing
    # against the literal alone meant the Estimator's predict mode never selected
    # ScoreFormer. Accept both spellings so existing callers passing 'PREDICT' still work.
    is_predict = mode in (tf.estimator.ModeKeys.PREDICT, 'PREDICT')
    model_class = ScoreFormer if is_predict else Transformer

    def _build(problem, hparams):
        # ``problem`` is unused here; the model is built from hparams alone.
        return model_class.estimator_model_fn(hparams, features, labels, mode)

    return with_params(params, _build)


def train_input_fn(training_dir, params):
    """Return the problem's input pipeline for ``tf.estimator.ModeKeys.TRAIN``.

    Args:
        training_dir: forwarded as the third positional argument of
            ``problem.input_fn`` (presumably its data directory).
        params: dict-like; must contain ``'hparams'`` (see ``with_params``)
            and is also forwarded to ``problem.input_fn``.

    Returns:
        Whatever ``problem.input_fn`` returns for TRAIN mode.
    """
    def _train_pipeline(problem, hparams):
        train_mode = tf.estimator.ModeKeys.TRAIN
        return problem.input_fn(train_mode, hparams, training_dir, params)

    return with_params(params, _train_pipeline)

def eval_input_fn(training_dir, params):
    """Return the problem's input pipeline for ``tf.estimator.ModeKeys.EVAL``.

    Args:
        training_dir: forwarded as the third positional argument of
            ``problem.input_fn`` (presumably its data directory).
        params: dict-like; must contain ``'hparams'`` (see ``with_params``)
            and is also forwarded to ``problem.input_fn``.

    Returns:
        Whatever ``problem.input_fn`` returns for EVAL mode.
    """
    def _eval_pipeline(problem, hparams):
        eval_mode = tf.estimator.ModeKeys.EVAL
        return problem.input_fn(eval_mode, hparams, training_dir, params)

    return with_params(params, _eval_pipeline)

def serving_input_fn(params):
    """Return the problem's serving input function built from the configured hparams.

    Args:
        params: dict-like; must contain ``'hparams'`` (see ``with_params``).

    Returns:
        Whatever ``problem.serving_input_fn(hparams)`` returns.
    """
    def _serving(problem, hparams):
        return problem.serving_input_fn(hparams)

    return with_params(params, _serving)
    
