hypergan.generators.resize_conv_generator module
import tensorflow as tf
import numpy as np
import hyperchamber as hc
from hypergan.generators.common import *
from .base_generator import BaseGenerator
class ResizeConvGenerator(BaseGenerator):
    """
    Generator that grows a feature map up to the GAN's output size.

    Each stage either resizes the feature map and applies `config.block`
    (default `standard_block`), or applies `ops.deconv2d` when `config.block`
    is the string 'deconv'.
    """

    def required(self):
        """Return the config keys that must be present for this generator."""
        return "final_depth activation depth_increase".split()

    def depths(self, initial_width=4):
        """
        Return the layer depths, deepest first.

        Counts how many times `initial_width` must be doubled to reach
        `gan.width()` and returns one depth per doubling: the last entry is
        `config.final_depth` and each earlier entry is larger by
        `config.depth_increase`.  Returns an empty list when `initial_width`
        is already at least the target width.

        Raises:
            ValueError: if `initial_width` is non-positive and below the
                target width, since doubling it would never reach the target.
        """
        config = self.config
        target_width = self.gan.width()
        if initial_width <= 0 and initial_width < target_width:
            raise ValueError(
                "initial_width must be positive to reach the target width, got %r"
                % (initial_width,)
            )

        doublings = 0
        width = initial_width
        while width < target_width:
            width *= 2
            doublings += 1

        # Same arithmetic as before: start one increase below final_depth so
        # that the first doubling lands exactly on final_depth.
        base_depth = config.final_depth - config.depth_increase
        return [
            base_depth + step * config.depth_increase
            for step in range(doublings, 0, -1)
        ]

    def build(self, net):
        """
        Build the generator graph on top of `net` and return the sample.

        With `config.skip_linear` set, `net` is used as the starting feature
        map (optionally concatenated with a linear projection of its first
        `config.concat_linear` flattened values) and passed through
        `1 + config.extra_layers` 3x3 convolutions.  Otherwise `net` is
        flattened and projected by a linear layer to
        `config.initial_dimensions` (default `[4, 4]`).

        One stage is then applied per entry of `self.depths(...)[1:]`,
        followed by a final stage that emits `gan.channels()` channels at
        `gan.height()` x `gan.width()`.  The result is stored on
        `self.sample` and returned.
        """
        gan = self.gan
        ops = self.ops
        config = self.config

        activation = ops.lookup(config.activation)
        final_activation = ops.lookup(config.final_activation)
        block = config.block or standard_block

        if config.skip_linear:
            net = self.layer_filter(net)
            if config.concat_linear:
                size = ops.shape(net)[1]*ops.shape(net)[2]*config.concat_linear_filters
                net2 = tf.reshape(net, [ops.shape(net)[0], -1])
                net2 = tf.slice(net2, [0, 0], [ops.shape(net)[0], config.concat_linear])
                net2 = ops.linear(net2, size)
                net2 = tf.reshape(
                    net2,
                    [ops.shape(net)[0], ops.shape(net)[1], ops.shape(net)[2], config.concat_linear_filters],
                )
                net2 = self.layer_regularizer(net2)
                # Use the looked-up activation like every other layer here;
                # the raw config value is not guaranteed to be callable.
                net2 = activation(net2)
                net = tf.concat([net, net2], axis=3)
            # NOTE(review): the nesting of the conv layers below (outside the
            # `concat_linear` branch) was reconstructed from an unindented
            # listing -- confirm against the original source.
            net = ops.conv2d(net, 3, 3, 1, 1, ops.shape(net)[3]//(config.extra_layers_reduction or 1))
            for _ in range(config.extra_layers or 0):
                net = self.layer_regularizer(net)
                net = activation(net)
                net = ops.conv2d(net, 3, 3, 1, 1, ops.shape(net)[3]//(config.extra_layers_reduction or 1))
        else:
            net = ops.reshape(net, [ops.shape(net)[0], -1])
            primes = config.initial_dimensions or [4, 4]
            depths = self.depths(primes[0])
            initial_depth = depths[0]
            new_shape = [ops.shape(net)[0], primes[0], primes[1], initial_depth]
            net = ops.linear(net, initial_depth*primes[0]*primes[1])
            net = ops.reshape(net, new_shape)

        shape = ops.shape(net)
        depths = self.depths(initial_width=shape[1])
        print("[generator] Initial depth", shape)

        if config.relation_layer:
            net = self.layer_regularizer(net)
            net = activation(net)
            net = self.relation_layer(net)
            print("[generator] relational layer", net)

        net = self.layer_filter(net)
        # depths[0] describes the starting feature map; the rest are stages.
        for depth in depths[1:]:
            s = ops.shape(net)
            # Double the spatial size, but never exceed the output size.
            resize = [min(s[1]*2, gan.height()), min(s[2]*2, gan.width())]
            net = self.layer_regularizer(net)
            net = activation(net)
            if block != 'deconv':
                net = ops.resize_images(net, resize, config.resize_image_type or 1)
                net = block(self, net, depth, filter=3)
            else:
                net = ops.deconv2d(net, 5, 5, 2, 2, depth)
            size = resize[0]*resize[1]*depth
            print("[generator] layer", net, size)

        net = self.layer_regularizer(net)
        net = activation(net)
        resize = [gan.height(), gan.width()]
        if block != 'deconv':
            net = ops.resize_images(net, resize, config.resize_image_type or 1)
            net = self.layer_filter(net)
            net = block(self, net, gan.channels(), filter=config.final_filter or 3)
        else:
            net = ops.deconv2d(net, 5, 5, 2, 2, gan.channels())

        if final_activation:
            net = self.layer_regularizer(net)
            net = final_activation(net)

        self.sample = net
        return self.sample
Classes
class ResizeConvGenerator
GANComponents are pluggable pieces within a GAN.
GAN objects are also GANComponents.
class ResizeConvGenerator(BaseGenerator):
    """
    Generator that grows a feature map up to the GAN's output size.

    Each stage either resizes the feature map and applies `config.block`
    (default `standard_block`), or applies `ops.deconv2d` when `config.block`
    is the string 'deconv'.
    """

    def required(self):
        """Return the config keys that must be present for this generator."""
        return "final_depth activation depth_increase".split()

    def depths(self, initial_width=4):
        """
        Return the layer depths, deepest first.

        Counts how many times `initial_width` must be doubled to reach
        `gan.width()` and returns one depth per doubling: the last entry is
        `config.final_depth` and each earlier entry is larger by
        `config.depth_increase`.  Returns an empty list when `initial_width`
        is already at least the target width.

        Raises:
            ValueError: if `initial_width` is non-positive and below the
                target width, since doubling it would never reach the target.
        """
        config = self.config
        target_width = self.gan.width()
        if initial_width <= 0 and initial_width < target_width:
            raise ValueError(
                "initial_width must be positive to reach the target width, got %r"
                % (initial_width,)
            )

        doublings = 0
        width = initial_width
        while width < target_width:
            width *= 2
            doublings += 1

        # Same arithmetic as before: start one increase below final_depth so
        # that the first doubling lands exactly on final_depth.
        base_depth = config.final_depth - config.depth_increase
        return [
            base_depth + step * config.depth_increase
            for step in range(doublings, 0, -1)
        ]

    def build(self, net):
        """
        Build the generator graph on top of `net` and return the sample.

        With `config.skip_linear` set, `net` is used as the starting feature
        map (optionally concatenated with a linear projection of its first
        `config.concat_linear` flattened values) and passed through
        `1 + config.extra_layers` 3x3 convolutions.  Otherwise `net` is
        flattened and projected by a linear layer to
        `config.initial_dimensions` (default `[4, 4]`).

        One stage is then applied per entry of `self.depths(...)[1:]`,
        followed by a final stage that emits `gan.channels()` channels at
        `gan.height()` x `gan.width()`.  The result is stored on
        `self.sample` and returned.
        """
        gan = self.gan
        ops = self.ops
        config = self.config

        activation = ops.lookup(config.activation)
        final_activation = ops.lookup(config.final_activation)
        block = config.block or standard_block

        if config.skip_linear:
            net = self.layer_filter(net)
            if config.concat_linear:
                size = ops.shape(net)[1]*ops.shape(net)[2]*config.concat_linear_filters
                net2 = tf.reshape(net, [ops.shape(net)[0], -1])
                net2 = tf.slice(net2, [0, 0], [ops.shape(net)[0], config.concat_linear])
                net2 = ops.linear(net2, size)
                net2 = tf.reshape(
                    net2,
                    [ops.shape(net)[0], ops.shape(net)[1], ops.shape(net)[2], config.concat_linear_filters],
                )
                net2 = self.layer_regularizer(net2)
                # Use the looked-up activation like every other layer here;
                # the raw config value is not guaranteed to be callable.
                net2 = activation(net2)
                net = tf.concat([net, net2], axis=3)
            # NOTE(review): the nesting of the conv layers below (outside the
            # `concat_linear` branch) was reconstructed from an unindented
            # listing -- confirm against the original source.
            net = ops.conv2d(net, 3, 3, 1, 1, ops.shape(net)[3]//(config.extra_layers_reduction or 1))
            for _ in range(config.extra_layers or 0):
                net = self.layer_regularizer(net)
                net = activation(net)
                net = ops.conv2d(net, 3, 3, 1, 1, ops.shape(net)[3]//(config.extra_layers_reduction or 1))
        else:
            net = ops.reshape(net, [ops.shape(net)[0], -1])
            primes = config.initial_dimensions or [4, 4]
            depths = self.depths(primes[0])
            initial_depth = depths[0]
            new_shape = [ops.shape(net)[0], primes[0], primes[1], initial_depth]
            net = ops.linear(net, initial_depth*primes[0]*primes[1])
            net = ops.reshape(net, new_shape)

        shape = ops.shape(net)
        depths = self.depths(initial_width=shape[1])
        print("[generator] Initial depth", shape)

        if config.relation_layer:
            net = self.layer_regularizer(net)
            net = activation(net)
            net = self.relation_layer(net)
            print("[generator] relational layer", net)

        net = self.layer_filter(net)
        # depths[0] describes the starting feature map; the rest are stages.
        for depth in depths[1:]:
            s = ops.shape(net)
            # Double the spatial size, but never exceed the output size.
            resize = [min(s[1]*2, gan.height()), min(s[2]*2, gan.width())]
            net = self.layer_regularizer(net)
            net = activation(net)
            if block != 'deconv':
                net = ops.resize_images(net, resize, config.resize_image_type or 1)
                net = block(self, net, depth, filter=3)
            else:
                net = ops.deconv2d(net, 5, 5, 2, 2, depth)
            size = resize[0]*resize[1]*depth
            print("[generator] layer", net, size)

        net = self.layer_regularizer(net)
        net = activation(net)
        resize = [gan.height(), gan.width()]
        if block != 'deconv':
            net = ops.resize_images(net, resize, config.resize_image_type or 1)
            net = self.layer_filter(net)
            net = block(self, net, gan.channels(), filter=config.final_filter or 3)
        else:
            net = ops.deconv2d(net, 5, 5, 2, 2, gan.channels())

        if final_activation:
            net = self.layer_regularizer(net)
            net = final_activation(net)

        self.sample = net
        return self.sample
Ancestors (in MRO)
- ResizeConvGenerator
- hypergan.generators.base_generator.BaseGenerator
- hypergan.gan_component.GANComponent
- builtins.object
Static methods
def __init__(
self, gan, config)
Initializes a GAN component based on a `gan` and a `config` dictionary.
Different components require different config variables.
A `ValidationException` is raised if the GAN component configuration fails to validate.
def __init__(self, gan, config):
    """
    Set up a GAN component from a `gan` and a `config` dictionary.

    The config is wrapped in `hc.Config` and validated.  A
    `ValidationException` listing every reported problem is raised when
    validation returns errors; otherwise the ops are created via
    `create_ops`.
    """
    self.gan = gan
    self.config = hc.Config(config)

    problems = self.validate()
    if problems != []:
        details = "\n".join(problems)
        raise ValidationException(self.__class__.__name__ + ": " + details)

    self.create_ops(config)
def biases(
self)
Biases of the GAN component.
def biases(self):
    """
    Return the biases held by this component's ops object.
    """
    component_ops = self.ops
    return component_ops.biases
def build(
self, net)
def build(self, net):
    """
    Build the generator graph on top of `net` and return the sample.

    With `config.skip_linear` set, `net` is used as the starting feature
    map (optionally concatenated with a linear projection of its first
    `config.concat_linear` flattened values) and passed through
    `1 + config.extra_layers` 3x3 convolutions.  Otherwise `net` is
    flattened and projected by a linear layer to
    `config.initial_dimensions` (default `[4, 4]`).

    One stage is then applied per entry of `self.depths(...)[1:]`,
    followed by a final stage that emits `gan.channels()` channels at
    `gan.height()` x `gan.width()`.  The result is stored on
    `self.sample` and returned.
    """
    gan = self.gan
    ops = self.ops
    config = self.config

    activation = ops.lookup(config.activation)
    final_activation = ops.lookup(config.final_activation)
    block = config.block or standard_block

    if config.skip_linear:
        net = self.layer_filter(net)
        if config.concat_linear:
            size = ops.shape(net)[1]*ops.shape(net)[2]*config.concat_linear_filters
            net2 = tf.reshape(net, [ops.shape(net)[0], -1])
            net2 = tf.slice(net2, [0, 0], [ops.shape(net)[0], config.concat_linear])
            net2 = ops.linear(net2, size)
            net2 = tf.reshape(
                net2,
                [ops.shape(net)[0], ops.shape(net)[1], ops.shape(net)[2], config.concat_linear_filters],
            )
            net2 = self.layer_regularizer(net2)
            # Use the looked-up activation like every other layer here;
            # the raw config value is not guaranteed to be callable.
            net2 = activation(net2)
            net = tf.concat([net, net2], axis=3)
        # NOTE(review): the nesting of the conv layers below (outside the
        # `concat_linear` branch) was reconstructed from an unindented
        # listing -- confirm against the original source.
        net = ops.conv2d(net, 3, 3, 1, 1, ops.shape(net)[3]//(config.extra_layers_reduction or 1))
        for _ in range(config.extra_layers or 0):
            net = self.layer_regularizer(net)
            net = activation(net)
            net = ops.conv2d(net, 3, 3, 1, 1, ops.shape(net)[3]//(config.extra_layers_reduction or 1))
    else:
        net = ops.reshape(net, [ops.shape(net)[0], -1])
        primes = config.initial_dimensions or [4, 4]
        depths = self.depths(primes[0])
        initial_depth = depths[0]
        new_shape = [ops.shape(net)[0], primes[0], primes[1], initial_depth]
        net = ops.linear(net, initial_depth*primes[0]*primes[1])
        net = ops.reshape(net, new_shape)

    shape = ops.shape(net)
    depths = self.depths(initial_width=shape[1])
    print("[generator] Initial depth", shape)

    if config.relation_layer:
        net = self.layer_regularizer(net)
        net = activation(net)
        net = self.relation_layer(net)
        print("[generator] relational layer", net)

    net = self.layer_filter(net)
    # depths[0] describes the starting feature map; the rest are stages.
    for depth in depths[1:]:
        s = ops.shape(net)
        # Double the spatial size, but never exceed the output size.
        resize = [min(s[1]*2, gan.height()), min(s[2]*2, gan.width())]
        net = self.layer_regularizer(net)
        net = activation(net)
        if block != 'deconv':
            net = ops.resize_images(net, resize, config.resize_image_type or 1)
            net = block(self, net, depth, filter=3)
        else:
            net = ops.deconv2d(net, 5, 5, 2, 2, depth)
        size = resize[0]*resize[1]*depth
        print("[generator] layer", net, size)

    net = self.layer_regularizer(net)
    net = activation(net)
    resize = [gan.height(), gan.width()]
    if block != 'deconv':
        net = ops.resize_images(net, resize, config.resize_image_type or 1)
        net = self.layer_filter(net)
        net = block(self, net, gan.channels(), filter=config.final_filter or 3)
    else:
        net = ops.deconv2d(net, 5, 5, 2, 2, gan.channels())

    if final_activation:
        net = self.layer_regularizer(net)
        net = final_activation(net)

    self.sample = net
    return self.sample
def create(
self, sample=None)
def create(self, sample=None):
    """
    Build the generator from `sample` and return the result of `build`.

    When `sample` is None the GAN encoder's sample (`gan.encoder.sample`)
    is used as the input.
    """
    if sample is None:
        sample = self.gan.encoder.sample
    return self.build(sample)
def create_ops(
self, config)
Create the ops object as `self.ops`. Also resolves `config` through `gan.ops.lookup` and stores the result as `self.config`.
def create_ops(self, config):
    """
    Instantiate the ops backend as `self.ops` and resolve `config`.

    Does nothing when there is no GAN or the GAN has no `ops_backend`.
    Otherwise the backend is constructed with the current `self.config`
    and the GAN's device, and `self.config` is replaced by the result of
    `gan.ops.lookup(config)`.
    """
    gan = self.gan
    if gan is None or gan.ops_backend is None:
        return

    self.ops = gan.ops_backend(config=self.config, device=gan.device)
    self.config = gan.ops.lookup(config)
def depths(
self, initial_width=4)
def depths(self, initial_width=4):
    """
    Return the layer depths, deepest first.

    Counts how many times `initial_width` must be doubled to reach
    `gan.width()` and returns one depth per doubling: the last entry is
    `config.final_depth` and each earlier entry is larger by
    `config.depth_increase`.  Returns an empty list when `initial_width`
    is already at least the target width.

    Raises:
        ValueError: if `initial_width` is non-positive and below the
            target width, since doubling it would never reach the target.
    """
    config = self.config
    target_width = self.gan.width()
    if initial_width <= 0 and initial_width < target_width:
        raise ValueError(
            "initial_width must be positive to reach the target width, got %r"
            % (initial_width,)
        )

    doublings = 0
    width = initial_width
    while width < target_width:
        width *= 2
        doublings += 1

    # Same arithmetic as before: start one increase below final_depth so
    # that the first doubling lands exactly on final_depth.
    base_depth = config.final_depth - config.depth_increase
    return [
        base_depth + step * config.depth_increase
        for step in range(doublings, 0, -1)
    ]
def fully_connected_from_list(
self, nets)
def fully_connected_from_list(self, nets):
    """
    Apply a linear layer plus 'lrelu' to each pair of tensors in `nets`.

    Each pair is concatenated on axis 3, flattened per batch entry, passed
    through a linear layer that keeps the feature count, activated with the
    op looked up as 'lrelu', and reshaped back to the concatenated shape.
    Returns the processed tensors in input order.
    """
    ops = self.ops
    outputs = []
    for first, second in nets:
        joined = ops.concat([first, second], axis=3)
        joined_shape = ops.shape(joined)
        flat = ops.reshape(joined, [joined_shape[0], -1])
        feature_count = ops.shape(flat)[1]
        dense = ops.linear(flat, feature_count)
        activated = ops.lookup('lrelu')(dense)
        outputs.append(ops.reshape(activated, joined_shape))
    return outputs
def layer_filter(
self, net)
def layer_filter(self, net):
    """
    Concatenate the output of `config.layer_filter` onto `net`, if any.

    Returns `net` unchanged when no layer filter is configured or when the
    filter returns None; otherwise returns `net` concatenated with the
    filter output on axis 3.
    """
    ops = self.ops
    gan = self.gan
    config = self.config

    if not config.layer_filter:
        return net

    print("[base generator] applying layer filter", config['layer_filter'])
    extra = config.layer_filter(gan, self.config, net)
    if extra is None:
        return net
    return ops.concat(axis=3, values=[net, extra])
def layer_regularizer(
self, net)
def layer_regularizer(self, net):
    """
    Apply the op named by `config.layer_regularizer` to `net`.

    The symbol is resolved through `gan.ops.lookup`; when it resolves to a
    falsy value `net` is returned unchanged.
    """
    regularizer = self.gan.ops.lookup(self.config.layer_regularizer)
    if not regularizer:
        return net
    return regularizer(self, net)
def permute(
self, nets, k)
def permute(self, nets, k):
    """
    Return every ordered arrangement of `k` items from `nets` as a list
    of tuples.
    """
    arrangements = itertools.permutations(nets, k)
    return list(arrangements)
def relation_layer(
self, net)
def relation_layer(self, net):
    """
    Mix every ordered pair of spatial cells of `net` and project back.

    `net` is split into per-cell slices, every ordered pair of slices is
    processed by `fully_connected_from_list`, the results are concatenated
    on axis 3, flattened, and passed through a linear layer back to the
    original feature count.  The output has the same shape as `net`.
    """
    ops = self.ops
    original_shape = ops.shape(net)
    flat_size = original_shape[1] * original_shape[2] * original_shape[3]

    cells = self.split_by_width_height(net)
    pairs = self.permute(cells, 2)
    mixed = self.fully_connected_from_list(pairs)
    net = ops.concat(mixed, axis=3)

    batch_size = ops.shape(net)[0]
    net = ops.reshape(net, [batch_size, -1])
    net = ops.linear(net, flat_size)
    return ops.reshape(net, original_shape)
def required(
self)
Return a list of required config strings. A `ValidationException` will be thrown if any are missing.
Example:
```python
class MyComponent(GANComponent):
    def required(self):
        "learn rate is required"
        return ["learn_rate"]
```
def required(self):
    """Return the config keys that must be present for this generator."""
    return ["final_depth", "activation", "depth_increase"]
def reuse(
self, net)
def reuse(self, net):
    """
    Rebuild the component on `net` with ops reuse switched on.

    Calls `ops.reuse()`, runs `build(net)` and returns its result.
    `ops.stop_reuse()` is always called afterwards, even if `build`
    raises, so a failed build cannot leave the ops in reuse mode.
    """
    self.ops.reuse()
    try:
        return self.build(net)
    finally:
        self.ops.stop_reuse()
def split_batch(
self, net, count=2)
Discriminators return stacked results (on axis 0).
This splits the results. Returns [d_real, d_fake]
def split_batch(self, net, count=2):
    """
    Split `net` along axis 0 into `count` equally sized pieces.

    Discriminators return stacked results (on axis 0); this undoes the
    stacking.  `net` is flattened to two dimensions first.  With the
    default `count` the result is `[d_real, d_fake]`.
    """
    ops = self.ops or self.gan.ops
    batch_size = ops.shape(net)[0]
    net = ops.reshape(net, [batch_size, -1])

    offset = [0 for _ in ops.shape(net)]
    pieces = []
    for _ in range(count):
        extent = [batch_size // count] + list(ops.shape(net)[1:])
        pieces.append(ops.slice(net, offset, extent))
        offset[0] += batch_size // count
    return pieces
def split_by_width_height(
self, net)
def split_by_width_height(self, net):
    """
    Split `net` into one slice per spatial position.

    `net` is indexed as (batch, height, width, channels), following the
    way `shape[1]` and `shape[2]` are named here.  Returns a list of
    `ops.slice` results of size `[batch, 1, 1, -1]`, ordered row by row.
    """
    ops = self.gan.ops
    shape = ops.shape(net)
    bs = shape[0]
    height = shape[1]
    width = shape[2]
    # Axis 1 must be indexed by a height index and axis 2 by a width index.
    # The previous loop bounds were swapped, which is only correct for
    # square feature maps (where the output order is unchanged).
    return [
        ops.slice(net, [0, row, col, 0], [bs, 1, 1, -1])
        for row in range(height)
        for col in range(width)
    ]
def validate(
self)
Validates a GANComponent. Returns an array of error messages; an empty array `[]` means success.
def validate(self):
    """
    Validates a GANComponent. Return an array of error messages. Empty array `[]` means success.

    Reports every key from `required()` whose config value is None, and
    reports a missing GAN when `self.gan` is None.
    """
    # Identity check: a config value with a custom `__eq__` must not be
    # mistaken for a missing one.
    errors = [
        "`" + argument + "` required"
        for argument in self.required()
        if self.config.__getattr__(argument) is None
    ]
    if self.gan is None:
        errors.append("GANComponent constructed without GAN")
    return errors
def variables(
self)
All variables associated with this component.
def variables(self):
    """
    Return all variables associated with this component, as reported by
    its ops object.
    """
    component_ops = self.ops
    return component_ops.variables()
def weights(
self)
The weights of the GAN component.
def weights(self):
    """
    Return the weights held by this component's ops object.
    """
    component_ops = self.ops
    return component_ops.weights