import io
from os import path
import os
import gensim
import numpy as np
import csv
import tensorflow as tf
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators.text_problems import Text2ClassProblem, VocabType
from tensor2tensor.layers.modalities import SymbolModality
from tensor2tensor.utils import registry

# Emotion class names; a label's position in this list is its class id.
emotions = ['anger', 'disgust', 'fear', 'joy', 'sad', 'surprise']


@registry.register_problem
class WassaSimple(Text2ClassProblem):
    """WASSA emotion classification read from CSV files in ``tmp_dir``.

    The training split reads ``train-v3.csv``; every other split reads
    ``trial-v3.csv``. When ``use_cleaned`` is True the comma-separated
    ``*-v3_cleaned.csv`` variants are read instead of the tab-separated
    originals. Each row is ``label, text`` where ``label`` is one of
    ``emotions``.
    """

    @property
    def use_cleaned(self):
        """Whether to read the comma-separated ``*_cleaned`` files."""
        return False

    @property
    def num_classes(self):
        return len(emotions)

    def class_labels(self, data_dir):
        """Return the label names; index i is the class id used in samples."""
        return emotions

    @property
    def is_generate_per_split(self):
        return True

    def generate_samples(self, data_dir, tmp_dir, dataset_split):
        """Yield ``{'inputs': text, 'label': class_id}`` for one split.

        Args:
          data_dir: unused here.
          tmp_dir: directory containing the raw CSV files.
          dataset_split: ``problem.DatasetSplit``; TRAIN reads the train
            file, anything else reads the trial file.

        Raises:
          KeyError: if a row's label is not in ``emotions``.
          ValueError: if a row does not have exactly two fields.
        """
        emotion_ix = {emotion: ix for ix, emotion in enumerate(emotions)}
        if dataset_split == problem.DatasetSplit.TRAIN:
            filename = "train"
        else:
            filename = "trial"

        filename += "-v3"

        if self.use_cleaned:
            filename += "_cleaned"
            sep = ','
        else:
            sep = '\t'
        filename += ".csv"
        # The delimiter must be passed explicitly: with csv.reader's default
        # (',') the tab-separated raw files get split on commas inside the
        # tweet text. newline='' is what the csv module expects of its files.
        with io.open(path.join(tmp_dir, filename), newline='') as f:
            for label, text in csv.reader(f, delimiter=sep):
                yield {'inputs': text, 'label': emotion_ix[label]}


# File name (inside data_dir) of the saved pretrained embedding matrix.
NPY_FILENAME = 'embedding.npy'

@registry.register_symbol_modality('pretrained')
class PretrainedSymbolModality(SymbolModality):
    """Symbol modality whose embedding table is a frozen, pre-saved matrix.

    The matrix is loaded from ``NPY_FILENAME`` inside
    ``model_hparams.data_dir`` (the file ``WassaEmbedding`` writes with
    ``np.save``) and wrapped in a non-trainable variable.
    """

    def _get_weights(self, hidden_dim=None):
        """Return a non-trainable ``embedding`` variable initialised from disk.

        Args:
          hidden_dim: expected embedding width, or None to accept whatever
            width the saved matrix has.

        Returns:
          The variable created by ``tf.get_variable('embedding', ...)`` with
          the saved matrix's shape.

        Raises:
          ValueError: if ``symbol_modality_num_shards`` is not 1, or if
            ``hidden_dim`` is given and differs from the matrix's width.
        """
        # Only a single, unsharded table is supported: the whole matrix
        # becomes one variable.
        num_shards = self._model_hparams.symbol_modality_num_shards
        if num_shards != 1:
            raise ValueError(
                "PretrainedSymbolModality requires "
                "symbol_modality_num_shards == 1, got %r" % (num_shards,))
        npy_file = path.join(self._model_hparams.data_dir, NPY_FILENAME)
        # Memory-map read-only instead of reading the whole file eagerly.
        weights = np.load(npy_file, mmap_mode='r')
        if hidden_dim is not None and hidden_dim != weights.shape[1]:
            raise ValueError(
                "hidden_dim %r does not match pretrained embedding width %r "
                "in %s" % (hidden_dim, weights.shape[1], npy_file))
        return tf.get_variable(
            'embedding',
            shape=weights.shape,
            trainable=False,
            initializer=tf.initializers.constant(weights))

@registry.register_problem
class WassaEmbedding(WassaSimple):
    """WASSA classification over a fixed token vocabulary with frozen vectors.

    Reads the comma-separated ``*_cleaned`` files. The vocabulary and the
    embedding matrix are both derived from the word2vec-format text file
    ``plumbus_embedding_304`` in ``tmp_dir``. The input modality is switched
    to the ``pretrained`` symbol modality registered in this module.
    """

    @property
    def use_cleaned(self):
        return True

    @property
    def oov_token(self):
        """Placeholder for out-of-vocabulary words.

        Presumably consumed by the base class's token encoder; verify
        against the installed tensor2tensor version.
        """
        return '_unk_'

    def hparams(self, hp, model_hparams):
        """Route ``inputs`` through ``symbol:pretrained`` instead of ``symbol``."""
        super(WassaEmbedding, self).hparams(hp, model_hparams)
        # Keep the second element of the existing entry (presumably the vocab
        # size) and only swap the modality name.
        hp.input_modality['inputs'] = (
            registry.Modalities.SYMBOL + ":pretrained",
            hp.input_modality['inputs'][1])

    @property
    def vocab_type(self):
        return VocabType.TOKEN

    def get_or_create_vocab(self, data_dir, tmp_dir, force_get=False):
        """Build the vocab file and embedding matrix if needed, then load.

        Both files are (re)built together from ``tmp_dir`` when either one
        is missing, so the vocab order always matches the matrix rows. When
        ``force_get`` is True nothing is built: the caller only wants an
        existing vocab and may not supply a usable ``tmp_dir``.

        Args:
          data_dir: directory holding the vocab file and ``NPY_FILENAME``.
          tmp_dir: directory holding ``plumbus_embedding_304``.
          force_get: passed through to the base implementation; also
            disables building here.

        Returns:
          Whatever the base class's ``get_or_create_vocab`` returns.
        """
        vocabfile = path.join(data_dir, self.vocab_filename)
        npy_file = path.join(data_dir, NPY_FILENAME)

        files_ready = path.exists(vocabfile) and path.exists(npy_file)
        if not force_get and not files_ready:
            self._write_vocab_and_embedding(tmp_dir, vocabfile, npy_file)
        return super(WassaEmbedding, self).get_or_create_vocab(
            data_dir, tmp_dir, force_get)

    def _write_vocab_and_embedding(self, tmp_dir, vocabfile, npy_file):
        """Write one word per line to ``vocabfile`` and the vectors to ``npy_file``.

        Words are written in gensim's index order, which is also the row
        order of ``embed.vectors``.
        """
        embed = gensim.models.KeyedVectors.load_word2vec_format(
            path.join(tmp_dir, "plumbus_embedding_304"), binary=False)
        # gensim >= 4 renamed ``index2entity`` to ``index_to_key``.
        words = getattr(embed, 'index_to_key', None)
        if words is None:
            words = embed.index2entity
        # Explicit UTF-8 so the file doesn't depend on the locale encoding;
        # errors='ignore' is kept so every word still gets exactly one line.
        with io.open(vocabfile, mode='w', encoding='utf-8',
                     errors='ignore') as f:
            for word in words:
                f.write(word + "\n")
        np.save(npy_file, embed.vectors)

