hypergan.trainers.alternating_trainer module
import tensorflow as tf
import numpy as np
import hyperchamber as hc
import inspect

from hypergan.trainers.base_trainer import BaseTrainer

TINY = 1e-12

class AlternatingTrainer(BaseTrainer):
    def _create(self):
        gan = self.gan
        config = self.config

        g_lr = config.g_learn_rate
        d_lr = config.d_learn_rate

        d_vars = self.d_vars or gan.discriminator.variables()
        g_vars = self.g_vars or (gan.encoder.variables() + gan.generator.variables())

        loss = self.loss or gan.loss
        d_loss, g_loss = loss.sample

        self.d_log = -tf.log(tf.abs(d_loss + TINY))

        # Learning rates are variables so they can be annealed at runtime.
        self.d_lr = tf.Variable(d_lr, dtype=tf.float32)
        self.g_lr = tf.Variable(g_lr, dtype=tf.float32)

        g_optimizer = self.build_optimizer(config, 'g_', config.g_trainer, self.g_lr, g_vars, g_loss)
        d_optimizer = self.build_optimizer(config, 'd_', config.d_trainer, self.d_lr, d_vars, d_loss)

        self.g_loss = g_loss
        self.d_loss = d_loss
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer

        # Optional WGAN-style weight clipping for the discriminator.
        if config.d_clipped_weights:
            self.clip = [tf.assign(d, tf.clip_by_value(d, -config.d_clipped_weights, config.d_clipped_weights)) for d in d_vars]
        else:
            self.clip = []

        if config.anneal_learning_rate:
            anneal_rate = config.anneal_rate or 0.5
            anneal_lower_bound = config.anneal_lower_bound or 2e-5
            self.anneal = [
                tf.assign(self.d_lr, tf.maximum(self.d_lr * anneal_rate, anneal_lower_bound)),
                tf.assign(self.g_lr, tf.maximum(self.g_lr * anneal_rate, anneal_lower_bound))
            ]

        return g_optimizer, d_optimizer

    def _step(self, feed_dict):
        gan = self.gan
        sess = gan.session
        config = self.config
        loss = self.loss or gan.loss
        metrics = loss.metrics
        d_loss, g_loss = loss.sample

        # Several discriminator updates (with optional clipping)...
        for i in range(config.d_update_steps or 1):
            sess.run([self.d_optimizer] + self.clip, feed_dict)

        # ...then one generator update, fetching the metric values in the same run.
        metric_values = sess.run([self.g_optimizer] + self.output_variables(metrics), feed_dict)[1:]

        if self.current_step % 100 == 0:
            print(self.output_string(metrics) % tuple([self.current_step] + metric_values))

        anneal_every = config.anneal_every or 100000
        if config.anneal_learning_rate and self.current_step > 0 and self.current_step % anneal_every == 0:
            dlr, glr = sess.run(self.anneal)
            print("Lowering the learning rate to d:" + str(dlr) + ", g:" + str(glr))
Module variables
var TINY — small constant (1e-12) added before taking a log to avoid log(0).
Classes
class AlternatingTrainer
GANComponents are pluggable pieces within a GAN.
GAN objects are also GANComponents.
Ancestors (in MRO)
- AlternatingTrainer
- hypergan.trainers.base_trainer.BaseTrainer
- hypergan.gan_component.GANComponent
- builtins.object
Methods
def __init__(self, gan, config, d_vars=None, g_vars=None, loss=None)

Initializes a GAN component based on a `gan` and a `config` dictionary.
Different components require different config variables.
A `ValidationException` is raised if the GAN component configuration fails to validate.

def __init__(self, gan, config, d_vars=None, g_vars=None, loss=None):
    GANComponent.__init__(self, gan, config)
    self.create_called = False
    self.current_step = 0
    self.g_vars = g_vars
    self.d_vars = d_vars
    self.loss = loss
def biases(self)

Biases of the GAN component.

def biases(self):
    """
    Biases of the GAN component.
    """
    return self.ops.biases
def build_optimizer(self, config, prefix, trainer_config, learning_rate, var_list, loss)

def build_optimizer(self, config, prefix, trainer_config, learning_rate, var_list, loss):
    with tf.variable_scope(prefix):
        defn = {k[2:]: v for k, v in config.items() if k[2:] in inspect.getargspec(trainer_config).args and k.startswith(prefix)}
        optimizer = trainer_config(learning_rate, **defn)
        if config.clipped_gradients:
            apply_gradients = self.capped_optimizer(optimizer, config.clipped_gradients, loss, var_list)
        else:
            apply_gradients = optimizer.minimize(loss, var_list=var_list)
    return apply_gradients
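The dict comprehension above strips the `d_`/`g_` prefix from config keys and forwards only those that match the optimizer constructor's argument names. A plain-Python illustration of the same filtering (the config values are hypothetical, and `getfullargspec` stands in for the deprecated `getargspec` used in the source):

    import inspect

    def fake_trainer(learning_rate, beta1=0.9, beta2=0.999):
        # Stand-in for an optimizer constructor such as tf.train.AdamOptimizer.
        return (learning_rate, beta1, beta2)

    config = {"d_trainer": fake_trainer, "d_beta1": 0.5, "g_beta1": 0.9, "d_decay": 1}
    prefix = "d_"
    args = inspect.getfullargspec(fake_trainer).args
    defn = {k[2:]: v for k, v in config.items() if k[2:] in args and k.startswith(prefix)}
    print(defn)  # {'beta1': 0.5}: 'g_beta1' fails the prefix test, 'd_decay' the args test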
def capped_optimizer(optimizer, cap, loss, var_list)

def capped_optimizer(optimizer, cap, loss, var_list):
    gvs = optimizer.compute_gradients(loss, var_list=var_list)
    def create_cap(grad, var):
        if grad is None:
            print("Warning: No gradient for variable ", var.name)
            return None
        return (tf.clip_by_value(grad, -cap, cap), var)
    capped_gvs = [create_cap(grad, var) for grad, var in gvs]
    capped_gvs = [x for x in capped_gvs if x is not None]
    return optimizer.apply_gradients(capped_gvs)
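Because `capped_optimizer` takes no `self`, it can be applied to any TF1 optimizer directly. A hedged usage sketch with a stand-in variable and loss:

    import tensorflow as tf
    from hypergan.trainers.alternating_trainer import AlternatingTrainer

    w = tf.Variable(0.0)
    loss = tf.square(w - 3.0)
    opt = tf.train.GradientDescentOptimizer(0.1)
    # Clips each gradient value into [-1.0, 1.0] before applying the update.
    train_op = AlternatingTrainer.capped_optimizer(opt, 1.0, loss, [w])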
def create(self)

def create(self):
    self.create_called = True
    return self._create()
def create_ops(self, config)

Create the ops object as `self.ops`. Also looks up config.

def create_ops(self, config):
    """
    Create the ops object as `self.ops`. Also looks up config
    """
    if self.gan is None:
        return
    if self.gan.ops_backend is None:
        return
    self.ops = self.gan.ops_backend(config=self.config, device=self.gan.device)
    self.config = self.gan.ops.lookup(config)
def fully_connected_from_list(self, nets)

def fully_connected_from_list(self, nets):
    results = []
    ops = self.ops
    for net, net2 in nets:
        net = ops.concat([net, net2], axis=3)
        shape = ops.shape(net)
        bs = shape[0]
        net = ops.reshape(net, [bs, -1])
        features = ops.shape(net)[1]
        net = ops.linear(net, features)
        #net = self.layer_regularizer(net)
        net = ops.lookup('lrelu')(net)
        #net = ops.linear(net, features)
        net = ops.reshape(net, shape)
        results.append(net)
    return results
def layer_regularizer(self, net)

def layer_regularizer(self, net):
    symbol = self.config.layer_regularizer
    op = self.gan.ops.lookup(symbol)
    if op:
        net = op(self, net)
    return net
def output_string(self, metrics)

def output_string(self, metrics):
    output = "%2d: "
    for name in sorted(metrics.keys()):
        output += " " + name
        output += " %.2f"
    return output
def output_variables(self, metrics)

def output_variables(self, metrics):
    gan = self.gan
    sess = gan.session
    return [metrics[k] for k in sorted(metrics.keys())]
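`output_string` and `output_variables` work as a pair: the former builds a `%`-format template over the sorted metric names, the latter fetches the matching tensors in the same sorted order, and `_step` joins them. A small pure-Python demonstration of the formatting (the metric values are made up):

    metrics = {"d_loss": 0.69, "g_loss": 1.23}  # tensors in practice
    template = "%2d: "
    for name in sorted(metrics.keys()):
        template += " " + name + " %.2f"
    values = [metrics[k] for k in sorted(metrics.keys())]
    print(template % tuple([100] + values))  # 100:  d_loss 0.69 g_loss 1.23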
def permute(self, nets, k)

def permute(self, nets, k):
    return list(itertools.permutations(nets, k))
def relation_layer(self, net)

def relation_layer(self, net):
    ops = self.ops
    #hack
    shape = ops.shape(net)
    input_size = shape[1]*shape[2]*shape[3]
    netlist = self.split_by_width_height(net)
    permutations = self.permute(netlist, 2)
    permutations = self.fully_connected_from_list(permutations)
    net = ops.concat(permutations, axis=3)
    #hack
    bs = ops.shape(net)[0]
    net = ops.reshape(net, [bs, -1])
    net = ops.linear(net, input_size)
    net = ops.reshape(net, shape)
    return net
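In shape terms, `relation_layer` runs a shared fully connected transform over every ordered pair of spatial positions, then projects back to the input shape. A rough walkthrough for a hypothetical `[bs, h, w, c]` input:

    # net:                      [bs, h, w, c]
    # split_by_width_height  -> h*w tensors of shape [bs, 1, 1, c]
    # permute(netlist, 2)    -> (h*w) * (h*w - 1) ordered pairs
    # fully_connected_from_list: each pair is concatenated on the channel
    #   axis ([bs, 1, 1, 2c]), flattened, passed through linear + lrelu,
    #   and reshaped back to [bs, 1, 1, 2c]
    # concat(axis=3)         -> [bs, 1, 1, n_pairs * 2c]
    # reshape + linear       -> [bs, h*w*c], then reshaped to [bs, h, w, c]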
def required(self)

Returns a list of required config strings; a `ValidationException` will be thrown if any are missing.

Example:

    class MyComponent(GANComponent):
        def required(self):
            "learn rate is required"
            return ["learn_rate"]

def required(self):
    return "d_trainer g_trainer d_learn_rate g_learn_rate".split()
def reuse(self, net)

def reuse(self, net):
    self.ops.reuse()
    net = self.build(net)
    self.ops.stop_reuse()
    return net
def split_batch(self, net, count=2)

Discriminators return stacked results (on axis 0).
This splits the results. Returns [d_real, d_fake].

def split_batch(self, net, count=2):
    """
    Discriminators return stacked results (on axis 0).
    This splits the results. Returns [d_real, d_fake]
    """
    ops = self.ops or self.gan.ops
    s = ops.shape(net)
    bs = s[0]
    nets = []
    net = ops.reshape(net, [bs, -1])
    start = [0 for x in ops.shape(net)]
    for i in range(count):
        size = [bs//count] + [x for x in ops.shape(net)[1:]]
        nets.append(ops.slice(net, start, size))
        start[0] += bs//count
    return nets
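The discriminator concatenates real and generated activations along the batch axis, so the first half of the batch is real and the second half fake. A numpy sketch of the slicing that `split_batch` performs (a hypothetical 4-row batch):

    import numpy as np

    net = np.arange(8).reshape(4, 2)   # stacked: rows 0-1 real, rows 2-3 fake
    d_real, d_fake = net[:2], net[2:]  # what split_batch(net, count=2) returns
    print(d_real.shape, d_fake.shape)  # (2, 2) (2, 2)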
def split_by_width_height(self, net)

def split_by_width_height(self, net):
    elems = []
    ops = self.gan.ops
    shape = ops.shape(net)
    bs = shape[0]
    height = shape[1]
    width = shape[2]
    for i in range(width):
        for j in range(height):
            elems.append(ops.slice(net, [0, i, j, 0], [bs, 1, 1, -1]))
    return elems
def step(self, feed_dict={})

def step(self, feed_dict={}):
    if not self.create_called:
        self.create()
    step = self._step(feed_dict)
    self.current_step += 1
    return step
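`step` is the public entry point: on the first call it builds the optimizers via `create`, then delegates to `_step` and advances the step counter. A hedged usage sketch, assuming a constructed trainer (the surrounding GAN setup is omitted and the names are illustrative):

    trainer = AlternatingTrainer(gan, config)  # gan and config built elsewhere
    for _ in range(10000):
        trainer.step()  # lazily creates the ops once, then trains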
def validate(self)

Validates a GANComponent. Returns an array of error messages. Empty array `[]` means success.

def validate(self):
    """
    Validates a GANComponent. Return an array of error messages. Empty array `[]` means success.
    """
    errors = []
    required = self.required()
    for argument in required:
        if self.config.__getattr__(argument) is None:
            errors.append("`" + argument + "` required")
    if self.gan is None:
        errors.append("GANComponent constructed without GAN")
    return errors
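`required` and `validate` work together: each required key that resolves to `None` on the config yields an error message. A stand-in illustration (the minimal config class below is hypothetical, mimicking hyperchamber's None-for-missing lookup):

    class FakeConfig:
        def __init__(self, **kw):
            self._kw = kw
        def __getattr__(self, name):
            return self._kw.get(name)  # missing keys resolve to None

    config = FakeConfig(g_trainer="adam")
    required = "d_trainer g_trainer d_learn_rate g_learn_rate".split()
    errors = ["`" + a + "` required" for a in required if getattr(config, a) is None]
    print(errors)  # ['`d_trainer` required', '`d_learn_rate` required', '`g_learn_rate` required']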
def variables(self)

All variables associated with this component.

def variables(self):
    """
    All variables associated with this component.
    """
    return self.ops.variables()
def weights(self)

The weights of the GAN component.

def weights(self):
    """
    The weights of the GAN component.
    """
    return self.ops.weights