Top

hypergan.samplers.autoencode_sampler module

from hypergan.samplers.base_sampler import BaseSampler
import tensorflow as tf
import numpy as np

class AutoencodeSampler(BaseSampler):
    """Sampler that renders latent-space blends between encoded inputs.

    A batch of inputs is fetched once and cached.  On every call the
    cached batch is encoded, the encodings inside each row of
    ``samples_per_row`` entries are blended between the row's first and
    last encoding, and the blended batch is run through the generator.
    """

    def __init__(self, gan, samples_per_row=8):
        """Store ``gan`` and ``samples_per_row`` via ``BaseSampler``.

        ``z`` doubles as the "first call" flag for ``_sample``; ``y`` and
        ``x`` are only initialised here and are not read by this class.
        """
        BaseSampler.__init__(self, gan, samples_per_row)
        self.z = None
        self.y = None
        self.x = None

    def _sample(self):
        """Build one batch of interpolated samples.

        Returns:
            dict with a single key ``'generator'`` holding the generator
            output, in which the first and last entry of every row have
            been overwritten with the corresponding cached inputs.

        NOTE: ``self.z[last]`` is indexed for every row start, so the
        batch size is expected to be a multiple of ``samples_per_row``.
        """
        gan = self.gan
        inputs_t = gan.inputs.x
        z_t = gan.encoder.sample

        if self.z is None:
            # First call only: pull a batch of inputs and keep it, so
            # later calls re-encode the very same inputs.
            print("GAN IS", gan, gan.encoder)
            self.input = gan.session.run(inputs_t)
        self.z = gan.session.run(z_t, feed_dict={inputs_t: self.input})

        row_starts = range(0, np.shape(self.z)[0], self.samples_per_row)
        for i in row_starts:
            last = i + self.samples_per_row - 1

            # Interior slots i+2 .. last-2 get a linear blend of the two
            # row endpoints; the weight of the last encoding is
            # (j - i) / (samples_per_row - 2).
            for j in range(i + 2, last - 1):
                percent = (j - i) / float(last - (i + 1))
                self.z[j] = self.z[i] * (1.0 - percent) + self.z[last] * percent
            # The slots next to each endpoint repeat that endpoint's
            # encoding, i.e. they show its plain reconstruction.
            self.z[i + 1] = self.z[i]
            self.z[last - 1] = self.z[last]

        output = gan.session.run(gan.generator.sample, feed_dict={z_t: self.z})
        for i in row_starts:
            last = i + self.samples_per_row - 1
            # Show the original inputs at both ends of every row.
            output[i] = self.input[i]
            output[last] = self.input[last]

        g = tf.get_default_graph()
        with g.as_default():
            # Pin the graph-level random seed before handing back the batch.
            tf.set_random_seed(1)
            return {
                'generator': output
            }

Classes

class AutoencodeSampler

class AutoencodeSampler(BaseSampler):
    """Sampler that renders latent-space blends between encoded inputs.

    A batch of inputs is fetched once and cached.  On every call the
    cached batch is encoded, the encodings inside each row of
    ``samples_per_row`` entries are blended between the row's first and
    last encoding, and the blended batch is run through the generator.
    """

    def __init__(self, gan, samples_per_row=8):
        """Store ``gan`` and ``samples_per_row`` via ``BaseSampler``.

        ``z`` doubles as the "first call" flag for ``_sample``; ``y`` and
        ``x`` are only initialised here and are not read by this class.
        """
        BaseSampler.__init__(self, gan, samples_per_row)
        self.z = None
        self.y = None
        self.x = None

    def _sample(self):
        """Build one batch of interpolated samples.

        Returns:
            dict with a single key ``'generator'`` holding the generator
            output, in which the first and last entry of every row have
            been overwritten with the corresponding cached inputs.

        NOTE: ``self.z[last]`` is indexed for every row start, so the
        batch size is expected to be a multiple of ``samples_per_row``.
        """
        gan = self.gan
        inputs_t = gan.inputs.x
        z_t = gan.encoder.sample

        if self.z is None:
            # First call only: pull a batch of inputs and keep it, so
            # later calls re-encode the very same inputs.
            print("GAN IS", gan, gan.encoder)
            self.input = gan.session.run(inputs_t)
        self.z = gan.session.run(z_t, feed_dict={inputs_t: self.input})

        row_starts = range(0, np.shape(self.z)[0], self.samples_per_row)
        for i in row_starts:
            last = i + self.samples_per_row - 1

            # Interior slots i+2 .. last-2 get a linear blend of the two
            # row endpoints; the weight of the last encoding is
            # (j - i) / (samples_per_row - 2).
            for j in range(i + 2, last - 1):
                percent = (j - i) / float(last - (i + 1))
                self.z[j] = self.z[i] * (1.0 - percent) + self.z[last] * percent
            # The slots next to each endpoint repeat that endpoint's
            # encoding, i.e. they show its plain reconstruction.
            self.z[i + 1] = self.z[i]
            self.z[last - 1] = self.z[last]

        output = gan.session.run(gan.generator.sample, feed_dict={z_t: self.z})
        for i in row_starts:
            last = i + self.samples_per_row - 1
            # Show the original inputs at both ends of every row.
            output[i] = self.input[i]
            output[last] = self.input[last]

        g = tf.get_default_graph()
        with g.as_default():
            # Pin the graph-level random seed before handing back the batch.
            tf.set_random_seed(1)
            return {
                'generator': output
            }

Ancestors (in MRO)

Static methods

def __init__(

self, gan, samples_per_row=8)

Initialize the sampler for `gan` with `samples_per_row` samples per output row, and set the `z`, `y` and `x` attributes to `None`.

def __init__(self, gan, samples_per_row=8):
    """Set up the sampler through ``BaseSampler`` and clear its state.

    Args:
        gan: forwarded unchanged to ``BaseSampler.__init__``.
        samples_per_row: forwarded unchanged to ``BaseSampler.__init__``
            (defaults to 8).
    """
    BaseSampler.__init__(self, gan, samples_per_row)
    # Nothing has been sampled yet, so all three attributes start as None.
    self.z = self.y = self.x = None

def plot(

self, image, filename, save_sample)

Plot an image.

def plot(self, image, filename, save_sample):
    """Plot an image, optionally saving it to ``filename`` first.

    The values are clamped to [-1, 1], singleton dimensions are squeezed
    out, and the result is min-max stretched to unsigned 8-bit values
    before being handed to ``GlobalViewer.update``.

    Args:
        image: array-like of pixel values.
        filename: destination used when ``save_sample`` is truthy.
        save_sample: when truthy, the image is also written to
            ``filename``; a failure to save is reported on stdout and
            otherwise ignored.
    """
    image = np.minimum(image, 1)
    image = np.maximum(image, -1)
    image = np.squeeze(image)
    # Scale to 0..255.
    imin, imax = image.min(), image.max()
    span = imax - imin
    if span == 0:
        # A constant image has no range to stretch; dividing by zero would
        # yield NaNs whose uint8 cast is undefined.  With a unit span every
        # pixel maps deterministically to 0.
        span = 1
    image = (image - imin) * 255. / span + .5
    image = image.astype(np.uint8)
    if save_sample:
        try:
            Image.fromarray(image).save(filename)
        except Exception as e:
            # Saving is best-effort: warn and still update the viewer.
            print("Warning: could not sample to ", filename, ".  Please check permissions and make sure the path exists")
            print(e)
    GlobalViewer.update(image)

def sample(

self, path, save_samples)

def sample(self, path, save_samples):
    """Produce one sample batch, tile it into a grid and plot it.

    The ``'generator'`` entry returned by ``self._sample()`` is cut into
    rows of ``min(gan.batch_size(), self.samples_per_row)`` items; each
    row is joined horizontally, the rows are stacked vertically, and the
    resulting grid is passed to ``self.plot``.

    Args:
        path: passed to ``self.plot`` and echoed back in the result.
        save_samples: passed through to ``self.plot``.

    Returns:
        A one-element list ``[{'image': path, 'label': 'sample'}]``.
    """
    gan = self.gan
    with gan.session.as_default():
        images = self._sample()['generator']
        row_len = min(gan.batch_size(), self.samples_per_row)
        row_count = gan.batch_size() // row_len

        rows = []
        for row in range(row_count):
            begin = row * row_len
            rows.append(np.hstack(images[begin:begin + row_len]))

        grid = np.vstack(rows)
        self.plot(grid, path, save_samples)
        return [{'image': path, 'label': 'sample'}]

Instance variables

var x

var y

var z