Top

hypergan.samplers.began_sampler module

from hypergan.samplers.base_sampler import BaseSampler

import tensorflow as tf
import numpy as np

class BeganSampler(BaseSampler):
    """Sampler that shows a fixed input batch, its reconstruction, and samples.

    The first call to ``_sample`` captures one batch of inputs (``x_v``) and
    latents (``z_v``); every later call reuses them so successive sample
    images are directly comparable.
    """

    def __init__(self, gan, samples_per_row):
        BaseSampler.__init__(self, gan, samples_per_row)
        # Cached input / latent batches, filled on the first _sample() call.
        self.x_v = None
        self.z_v = None
        self.created = False

    def _sample(self):
        """Run the graph and return ``{'generator': images}``.

        ``images`` is the axis-0 concatenation (``np.vstack``) of the cached
        input batch, ``gan.discriminator.reconstruction`` and
        ``gan.generator.sample``, both evaluated with the cached batches fed in.
        """
        gan = self.gan
        sess = gan.session
        x_t = gan.inputs.x
        z_t = gan.encoder.sample
        if not self.created:
            # Capture one batch once and reuse it on every subsequent call.
            self.x_v, self.z_v = sess.run([x_t, z_t])
            self.created = True
        g_t = gan.generator.sample
        rx_t = gan.discriminator.reconstruction
        # Feed the cached batches back in so the outputs correspond to them.
        rx_v, g_v = sess.run([rx_t, g_t], {x_t: self.x_v, z_t: self.z_v})
        images = np.vstack([self.x_v, rx_v, g_v])
        return {'generator': images}

    def sample(self, path, save_samples=False):
        """Tile the output of ``_sample`` into one grid image and plot it.

        Images are laid out ``samples_per_row`` per row (capped at the number
        of images); a short final row is padded with all-zero tiles so every
        row has the same width. Assumes each image is (H, W) or (H, W, C) —
        NOTE(review): confirm against the generator's output shape.

        Returns a one-element list describing the plotted image at ``path``.
        """
        images = self._sample()['generator']
        width = max(1, min(self.samples_per_row, len(images)))
        rows = []
        for start in range(0, len(images), width):
            row = list(images[start:start + width])
            # Pad the last row with blank tiles so np.vstack sees equal widths.
            row += [np.zeros_like(images[0])] * (width - len(row))
            rows.append(np.hstack(row))
        self.plot(np.vstack(rows), path, save_samples)
        return [{'image': path, 'label': 'tiled x sample'}]

Classes

class BeganSampler

class BeganSampler(BaseSampler):
    """Sampler that shows a fixed input batch, its reconstruction, and samples.

    The first call to ``_sample`` captures one batch of inputs (``x_v``) and
    latents (``z_v``); every later call reuses them so successive sample
    images are directly comparable.
    """

    def __init__(self, gan, samples_per_row):
        BaseSampler.__init__(self, gan, samples_per_row)
        # Cached input / latent batches, filled on the first _sample() call.
        self.x_v = None
        self.z_v = None
        self.created = False

    def _sample(self):
        """Run the graph and return ``{'generator': images}``.

        ``images`` is the axis-0 concatenation (``np.vstack``) of the cached
        input batch, ``gan.discriminator.reconstruction`` and
        ``gan.generator.sample``, both evaluated with the cached batches fed in.
        """
        gan = self.gan
        sess = gan.session
        x_t = gan.inputs.x
        z_t = gan.encoder.sample
        if not self.created:
            # Capture one batch once and reuse it on every subsequent call.
            self.x_v, self.z_v = sess.run([x_t, z_t])
            self.created = True
        g_t = gan.generator.sample
        rx_t = gan.discriminator.reconstruction
        # Feed the cached batches back in so the outputs correspond to them.
        rx_v, g_v = sess.run([rx_t, g_t], {x_t: self.x_v, z_t: self.z_v})
        images = np.vstack([self.x_v, rx_v, g_v])
        return {'generator': images}

    def sample(self, path, save_samples=False):
        """Tile the output of ``_sample`` into one grid image and plot it.

        Images are laid out ``samples_per_row`` per row (capped at the number
        of images); a short final row is padded with all-zero tiles so every
        row has the same width. Assumes each image is (H, W) or (H, W, C) —
        NOTE(review): confirm against the generator's output shape.

        Returns a one-element list describing the plotted image at ``path``.
        """
        images = self._sample()['generator']
        width = max(1, min(self.samples_per_row, len(images)))
        rows = []
        for start in range(0, len(images), width):
            row = list(images[start:start + width])
            # Pad the last row with blank tiles so np.vstack sees equal widths.
            row += [np.zeros_like(images[0])] * (width - len(row))
            rows.append(np.hstack(row))
        self.plot(np.vstack(rows), path, save_samples)
        return [{'image': path, 'label': 'tiled x sample'}]

Ancestors (in MRO)

  • BeganSampler
  • hypergan.samplers.base_sampler.BaseSampler
  • builtins.object

Methods

def __init__(self, gan, samples_per_row)

Store `gan` and `samples_per_row` via `BaseSampler.__init__`, then reset the cached input/latent batch (`x_v` and `z_v` set to None, `created` set to False).

def __init__(self, gan, samples_per_row):
    """Initialize the base sampler, then clear the cached x/z batch state."""
    BaseSampler.__init__(self, gan, samples_per_row)
    # Nothing has been captured yet; the first sample fills these in.
    self.created = False
    self.x_v = self.z_v = None

def plot(self, image, filename, save_sample)

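Plot an image: clamp it to [-1, 1], rescale it to 0..255 as uint8, optionally save it to `filename`, and pass it to `GlobalViewer.update`.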

def plot(self, image, filename, save_sample):
    """ Plot an image.

    Clamps ``image`` to [-1, 1], squeezes size-1 axes, min-max rescales it to
    0..255 as ``uint8``, optionally saves it to ``filename`` via
    ``Image.fromarray`` (presumably PIL), and always passes it to
    ``GlobalViewer.update``. Save failures are reported but not raised.
    """
    image = np.clip(image, -1, 1)
    image = np.squeeze(image)
    # Scale to 0..255; the +.5 rounds to nearest when astype truncates.
    imin, imax = image.min(), image.max()
    if imax > imin:
        image = (image - imin) * 255. / (imax - imin) + .5
    else:
        # Constant image: avoid the 0/0 (NaN) division and render it black.
        image = np.zeros_like(image)
    image = image.astype(np.uint8)
    if save_sample:
        try:
            Image.fromarray(image).save(filename)
        except Exception as e:
            # Deliberately best-effort: a failed save must not stop training.
            print("Warning: could not sample to ", filename, ".  Please check permissions and make sure the path exists")
            print(e)
    GlobalViewer.update(image)

def sample(self, path, save_samples=False)

def sample(self, path, save_samples=False):
    """Tile the output of ``_sample`` into one grid image and plot it.

    Images are laid out ``samples_per_row`` per row (capped at the number of
    images); a short final row is padded with all-zero tiles so every row has
    the same width. Assumes each image is (H, W) or (H, W, C) —
    NOTE(review): confirm against the generator's output shape.

    Returns a one-element list describing the plotted image at ``path``.
    """
    images = self._sample()['generator']
    width = max(1, min(self.samples_per_row, len(images)))
    rows = []
    for start in range(0, len(images), width):
        row = list(images[start:start + width])
        # Pad the last row with blank tiles so np.vstack sees equal widths.
        row += [np.zeros_like(images[0])] * (width - len(row))
        rows.append(np.hstack(row))
    self.plot(np.vstack(rows), path, save_samples)
    return [{'image': path, 'label': 'tiled x sample'}]

Instance variables

var created

var x_v

var z_v