Top

hypergan.samplers.began_sampler module

from hypergan.samplers.base_sampler import BaseSampler

import tensorflow as tf
import numpy as np

class BeganSampler(BaseSampler):
    """Sampler that tiles inputs, discriminator reconstructions and generator
    samples into a single comparison grid."""

    def __init__(self, gan):
        BaseSampler.__init__(self, gan)
        # Input batch and encoder-sample batch, fetched once on the first
        # ``sample`` call and reused afterwards.
        self.x_v = None
        self.z_v = None
        self.created = False

    def sample(self, path, save_samples=False):
        """Render a 3-row comparison grid and hand it to ``self.plot``.

        Rows, top to bottom: the cached inputs ``x``, the discriminator's
        ``reconstruction`` output (fed the cached ``x``/``z``), and the
        generator's ``sample`` output. Each row shows batch items
        ``width`` .. ``2*width - 1`` (``width = 8``), so the batch must hold
        at least 16 items; otherwise an IndexError is raised.

        The first call fetches one batch of ``x`` and encoder samples ``z``
        and caches them on ``self``. Every later call feeds those same values
        back in, so successive grids are directly comparable.

        Args:
            path: file name forwarded to ``self.plot``.
            save_samples: forwarded to ``self.plot`` as its save flag.

        Returns:
            A one-element list describing the image produced at ``path``.
        """
        gan = self.gan
        x_t = gan.inputs.x
        g_t = gan.generator.sample
        z_t = gan.encoder.sample
        rx_t = gan.discriminator.reconstruction
        sess = gan.session

        if not self.created:
            # Freeze one batch of inputs/latents so every grid uses the same data.
            self.x_v, self.z_v = sess.run([x_t, z_t])
            self.created = True

        rx_v, g_v = sess.run([rx_t, g_t], {x_t: self.x_v, z_t: self.z_v})

        width = 8
        # Items [width, 2*width) of each batch form one row. Indexing (rather
        # than slicing) keeps the IndexError for batches that are too small.
        picks = range(width, 2 * width)
        rows = [[batch[k] for k in picks] for batch in (self.x_v, rx_v, g_v)]
        images = np.vstack([np.hstack(row) for row in rows])

        self.plot(images, path, save_samples)
        return [{'image': path, 'label': 'tiled x sample'}]

Classes

class BeganSampler

class BeganSampler(BaseSampler):
    """Sampler that tiles inputs, discriminator reconstructions and generator
    samples into a single comparison grid."""

    def __init__(self, gan):
        BaseSampler.__init__(self, gan)
        # Input batch and encoder-sample batch, fetched once on the first
        # ``sample`` call and reused afterwards.
        self.x_v = None
        self.z_v = None
        self.created = False

    def sample(self, path, save_samples=False):
        """Render a 3-row comparison grid and hand it to ``self.plot``.

        Rows, top to bottom: the cached inputs ``x``, the discriminator's
        ``reconstruction`` output (fed the cached ``x``/``z``), and the
        generator's ``sample`` output. Each row shows batch items
        ``width`` .. ``2*width - 1`` (``width = 8``), so the batch must hold
        at least 16 items; otherwise an IndexError is raised.

        The first call fetches one batch of ``x`` and encoder samples ``z``
        and caches them on ``self``. Every later call feeds those same values
        back in, so successive grids are directly comparable.

        Args:
            path: file name forwarded to ``self.plot``.
            save_samples: forwarded to ``self.plot`` as its save flag.

        Returns:
            A one-element list describing the image produced at ``path``.
        """
        gan = self.gan
        x_t = gan.inputs.x
        g_t = gan.generator.sample
        z_t = gan.encoder.sample
        rx_t = gan.discriminator.reconstruction
        sess = gan.session

        if not self.created:
            # Freeze one batch of inputs/latents so every grid uses the same data.
            self.x_v, self.z_v = sess.run([x_t, z_t])
            self.created = True

        rx_v, g_v = sess.run([rx_t, g_t], {x_t: self.x_v, z_t: self.z_v})

        width = 8
        # Items [width, 2*width) of each batch form one row. Indexing (rather
        # than slicing) keeps the IndexError for batches that are too small.
        picks = range(width, 2 * width)
        rows = [[batch[k] for k in picks] for batch in (self.x_v, rx_v, g_v)]
        images = np.vstack([np.hstack(row) for row in rows])

        self.plot(images, path, save_samples)
        return [{'image': path, 'label': 'tiled x sample'}]

Ancestors (in MRO)

  • BeganSampler
  • hypergan.samplers.base_sampler.BaseSampler
  • builtins.object

Methods

def __init__(self, gan)

Initialize the sampler for `gan`; the cached input batch (`x_v`) and encoder-sample batch (`z_v`) start empty, and `created` starts as False.

def __init__(self, gan):
    """Set up the sampler for ``gan`` with no cached sampling batch yet."""
    BaseSampler.__init__(self, gan)
    # The x/z batch is fetched lazily on the first ``sample`` call.
    self.x_v = self.z_v = None
    self.created = False

def plot(self, image, filename, save_sample)

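Plot an image: clip it to [-1, 1], stretch its value range onto 0..255 as uint8, optionally save it to `filename`, and pass it to the global viewer.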

def plot(self, image, filename, save_sample):
    """ Plot an image.

    Clips ``image`` to [-1, 1], squeezes singleton axes, then linearly
    stretches its actual min..max range onto 0..255 as ``uint8``. A constant
    image (zero range) is rendered all black instead of dividing by zero.
    When ``save_sample`` is true the result is written to ``filename``; a
    failure there only prints a warning. The uint8 image is always passed
    to ``GlobalViewer.update``.
    """
    image = np.clip(image, -1, 1)
    image = np.squeeze(image)
    # Scale to 0..255.
    imin, imax = image.min(), image.max()
    span = imax - imin
    if span > 0:
        # +.5 so the truncating uint8 cast rounds to the nearest level.
        image = (image - imin) * 255. / span + .5
    else:
        # Constant (or NaN-containing) image: the stretch would be 0/0 = NaN,
        # whose uint8 cast is undefined. Render it black instead.
        image = np.zeros_like(image)
    image = image.astype(np.uint8)
    if save_sample:
        try:
            Image.fromarray(image).save(filename)
        except Exception as e:
            # Saving is best-effort: report and keep going.
            print("Warning: could not sample to ", filename, ".  Please check permissions and make sure the path exists")
            print(e)
    GlobalViewer.update(image)

def sample(self, path, save_samples=False)

def sample(self, path, save_samples=False):
    """Render a 3-row comparison grid and hand it to ``self.plot``.

    Rows, top to bottom: the cached inputs ``x``, the discriminator's
    ``reconstruction`` output (fed the cached ``x``/``z``), and the
    generator's ``sample`` output. Each row shows batch items
    ``width`` .. ``2*width - 1`` (``width = 8``), so the batch must hold at
    least 16 items; otherwise an IndexError is raised.

    The first call fetches one batch of ``x`` and encoder samples ``z`` and
    caches them on ``self``. Every later call feeds those same values back
    in, so successive grids are directly comparable.

    Args:
        path: file name forwarded to ``self.plot``.
        save_samples: forwarded to ``self.plot`` as its save flag.

    Returns:
        A one-element list describing the image produced at ``path``.
    """
    gan = self.gan
    x_t = gan.inputs.x
    g_t = gan.generator.sample
    z_t = gan.encoder.sample
    rx_t = gan.discriminator.reconstruction
    sess = gan.session

    if not self.created:
        # Freeze one batch of inputs/latents so every grid uses the same data.
        self.x_v, self.z_v = sess.run([x_t, z_t])
        self.created = True

    rx_v, g_v = sess.run([rx_t, g_t], {x_t: self.x_v, z_t: self.z_v})

    width = 8
    # Items [width, 2*width) of each batch form one row. Indexing (rather
    # than slicing) keeps the IndexError for batches that are too small.
    picks = range(width, 2 * width)
    rows = [[batch[k] for k in picks] for batch in (self.x_v, rx_v, g_v)]
    images = np.vstack([np.hstack(row) for row in rows])

    self.plot(images, path, save_samples)
    return [{'image': path, 'label': 'tiled x sample'}]

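Instance variables

var created — True once the fixed x/z batch has been fetched and cached by `sample`.

var x_v — cached input batch (None until the first `sample` call).

var z_v — cached encoder-sample batch (None until the first `sample` call).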