hypergan.samplers.autoencode_sampler module
from hypergan.samplers.base_sampler import BaseSampler
import tensorflow as tf
import numpy as np
class AutoencodeSampler(BaseSampler):
    """Sampler showing encoder reconstructions and latent-space interpolations.

    Each row of ``samples_per_row`` images is laid out as: the real input,
    the generator's reconstruction of it, an evenly spaced linear blend in
    latent space, the reconstruction of the row's last input, and that last
    real input.
    """

    def __init__(self, gan, samples_per_row=8):
        BaseSampler.__init__(self, gan, samples_per_row)
        # Encoder codes for a fixed batch; filled lazily by _sample and
        # reused on later calls so the grid stays the same between samples.
        self.z = None
        # Not read by _sample; kept as-is (NOTE(review): confirm whether
        # BaseSampler or callers rely on them).
        self.y = None
        self.x = None

    def _sample(self):
        """Return reconstructions and interpolations for a cached batch.

        On the first call a batch is drawn from ``gan.inputs.x`` and encoded
        with ``gan.encoder.sample``; the inputs (``self.input``) and codes
        (``self.z``) are cached and the codes are edited in place.

        Returns:
            dict: ``{'generator': output}`` where ``output`` is the generator
            output for ``self.z`` with the first and last image of each
            complete row replaced by the corresponding real input.
        """
        gan = self.gan
        inputs_t = gan.inputs.x
        z_t = gan.encoder.sample
        if self.z is None:
            self.input = gan.session.run(inputs_t)
            self.z = gan.session.run(z_t, feed_dict={inputs_t: self.input})

        row = self.samples_per_row
        # Only complete rows are edited: a trailing partial row would make
        # `last` index past the end of the batch.
        row_starts = range(0, np.shape(self.z)[0] - row + 1, row)

        # Rows narrower than 3 have no room between the endpoint columns
        # (and width 1 would write z[i + 1] out of range).
        if row >= 3:
            for i in row_starts:
                last = i + row - 1
                # Column i+1 holds z[i] (t=0) and column last-1 holds z[last]
                # (t=1); the columns between are evenly spaced blends.
                for j in range(i + 2, last - 1):
                    percent = (j - (i + 1)) / float(last - i - 2)
                    self.z[j] = self.z[i] * (1.0 - percent) + self.z[last] * percent
                self.z[i + 1] = self.z[i]
                self.z[last - 1] = self.z[last]

        output = gan.session.run(gan.generator.sample, feed_dict={z_t: self.z})
        for i in row_starts:
            last = i + row - 1
            # Show the real inputs at both ends of the row for comparison.
            output[i] = self.input[i]
            output[last] = self.input[last]

        # NOTE(review): seeds the default graph after sampling; presumably
        # meant to affect later ops - confirm intent.
        g = tf.get_default_graph()
        with g.as_default():
            tf.set_random_seed(1)
        return {
            'generator': output
        }
Classes
class AutoencodeSampler
class AutoencodeSampler(BaseSampler):
    """Sampler showing encoder reconstructions and latent-space interpolations.

    Each row of ``samples_per_row`` images is laid out as: the real input,
    the generator's reconstruction of it, an evenly spaced linear blend in
    latent space, the reconstruction of the row's last input, and that last
    real input.
    """

    def __init__(self, gan, samples_per_row=8):
        BaseSampler.__init__(self, gan, samples_per_row)
        # Encoder codes for a fixed batch; filled lazily by _sample and
        # reused on later calls so the grid stays the same between samples.
        self.z = None
        # Not read by _sample; kept as-is (NOTE(review): confirm whether
        # BaseSampler or callers rely on them).
        self.y = None
        self.x = None

    def _sample(self):
        """Return reconstructions and interpolations for a cached batch.

        On the first call a batch is drawn from ``gan.inputs.x`` and encoded
        with ``gan.encoder.sample``; the inputs (``self.input``) and codes
        (``self.z``) are cached and the codes are edited in place.

        Returns:
            dict: ``{'generator': output}`` where ``output`` is the generator
            output for ``self.z`` with the first and last image of each
            complete row replaced by the corresponding real input.
        """
        gan = self.gan
        inputs_t = gan.inputs.x
        z_t = gan.encoder.sample
        if self.z is None:
            self.input = gan.session.run(inputs_t)
            self.z = gan.session.run(z_t, feed_dict={inputs_t: self.input})

        row = self.samples_per_row
        # Only complete rows are edited: a trailing partial row would make
        # `last` index past the end of the batch.
        row_starts = range(0, np.shape(self.z)[0] - row + 1, row)

        # Rows narrower than 3 have no room between the endpoint columns
        # (and width 1 would write z[i + 1] out of range).
        if row >= 3:
            for i in row_starts:
                last = i + row - 1
                # Column i+1 holds z[i] (t=0) and column last-1 holds z[last]
                # (t=1); the columns between are evenly spaced blends.
                for j in range(i + 2, last - 1):
                    percent = (j - (i + 1)) / float(last - i - 2)
                    self.z[j] = self.z[i] * (1.0 - percent) + self.z[last] * percent
                self.z[i + 1] = self.z[i]
                self.z[last - 1] = self.z[last]

        output = gan.session.run(gan.generator.sample, feed_dict={z_t: self.z})
        for i in row_starts:
            last = i + row - 1
            # Show the real inputs at both ends of the row for comparison.
            output[i] = self.input[i]
            output[last] = self.input[last]

        # NOTE(review): seeds the default graph after sampling; presumably
        # meant to affect later ops - confirm intent.
        g = tf.get_default_graph()
        with g.as_default():
            tf.set_random_seed(1)
        return {
            'generator': output
        }
Ancestors (in MRO)
- AutoencodeSampler
- hypergan.samplers.base_sampler.BaseSampler
- builtins.object
Methods
def __init__(
self, gan, samples_per_row=8)
Initialize the sampler by passing `gan` and `samples_per_row` to `BaseSampler.__init__`; the cached `x`, `y` and `z` attributes start as `None`.
def __init__(self, gan, samples_per_row=8):
    """Set up the sampler and clear its cached state.

    Forwards ``gan`` and ``samples_per_row`` to ``BaseSampler.__init__`` and
    resets the cached ``z``, ``y`` and ``x`` attributes to ``None``.
    """
    BaseSampler.__init__(self, gan, samples_per_row)
    # Chained targets are assigned left to right: z, then y, then x.
    self.z = self.y = self.x = None
def plot(
self, image, filename, save_sample)
Plot an image.
def plot(self, image, filename, save_sample):
    """Plot an image.

    Clips ``image`` to [-1, 1], squeezes singleton dimensions and rescales
    it to 0..255 as ``uint8``. When ``save_sample`` is true the result is
    written to ``filename`` (failures only print a warning). The result is
    always passed to ``GlobalViewer.update``.

    Args:
        image: array-like image data.
        filename: destination path used when ``save_sample`` is true.
        save_sample: whether to write the image to ``filename``.
    """
    image = np.clip(image, -1, 1)
    image = np.squeeze(image)
    # Scale to 0..255.
    imin, imax = image.min(), image.max()
    span = imax - imin
    if span > 0:
        image = (image - imin) * 255. / span + .5
    else:
        # A constant image would divide by zero and produce NaN/inf before
        # the uint8 cast; render it as all zeros instead.
        image = np.zeros_like(image, dtype=np.float64)
    image = image.astype(np.uint8)
    if save_sample:
        try:
            Image.fromarray(image).save(filename)
        except Exception as e:
            # Deliberately best-effort: saving failures must not stop sampling.
            print("Warning: could not sample to ", filename, ". Please check permissions and make sure the path exists")
            print(e)
    GlobalViewer.update(image)
def sample(
self, path, save_samples)
def sample(self, path, save_samples):
    """Render one sample grid, plot it, and describe the result.

    Calls ``self._sample()`` inside the GAN session and tiles the
    ``'generator'`` images into rows of ``min(batch_size, samples_per_row)``.
    Images beyond the last complete row are dropped. The grid is passed to
    ``self.plot``.

    Returns:
        list: a single ``{'image': path, 'label': 'sample'}`` entry.
    """
    gan = self.gan
    with gan.session.as_default():
        sample = self._sample()
    data = sample['generator']
    width = min(gan.batch_size(), self.samples_per_row)
    stacks = [np.hstack(data[i*width:i*width+width]) for i in range(gan.batch_size()//width)]
    sample_data = np.vstack(stacks)
    self.plot(sample_data, path, save_samples)
    return [{'image': path, 'label': 'sample'}]
Instance variables
var x
var y
var z