Top

hypergan.trainers.multi_step_trainer module

import tensorflow as tf
import numpy as np
import hyperchamber as hc
import inspect

from hypergan.trainers.base_trainer import BaseTrainer

# Small numeric epsilon (1e-12).
# NOTE(review): not referenced by any code shown in this module — confirm it is still needed.
TINY = 1e-12

class MultiStepTrainer(BaseTrainer):
    """Trainer that runs one optimizer update per configured loss, in order.

    ``losses[i]`` is indexed as ``(kind, loss)``: entries whose kind is
    ``'generator'`` are optimized with the ``g_`` config prefix, ``g_trainer``
    and ``g_learn_rate``; all other entries use the ``d_`` equivalents.
    ``var_lists[i]`` is the variable list handed to the optimizer for
    ``losses[i]``. ``metrics[i]``, when truthy, is a dict whose values are
    fetched alongside that update and printed every 100 steps.
    """

    def __init__(self, gan, config, losses=None, var_lists=None, metrics=None):
        BaseTrainer.__init__(self, gan, config)
        # None defaults avoid sharing one mutable list between instances;
        # caller-supplied lists are stored by reference.
        self.losses = [] if losses is None else losses
        self.var_lists = [] if var_lists is None else var_lists
        self.metrics = metrics or [None for _ in self.losses]

    def _create(self):
        """Build one optimizer op per loss, plus the optional weight-clip ops.

        Stores the ops on ``self.optimizers`` / ``self.clip`` and the
        learning-rate variables on ``self.d_lr`` / ``self.g_lr``; returns None.
        """
        config = self.config

        self.d_lr = tf.Variable(config.d_learn_rate, dtype=tf.float32)
        self.g_lr = tf.Variable(config.g_learn_rate, dtype=tf.float32)

        optimizers = []
        for i, entry in enumerate(self.losses):
            kind, loss = entry[0], entry[1]
            var_list = self.var_lists[i]
            if kind == 'generator':
                optimizer = self.build_optimizer(config, 'g_', config.g_trainer, self.g_lr, var_list, loss)
            else:
                optimizer = self.build_optimizer(config, 'd_', config.d_trainer, self.d_lr, var_list, loss)
            optimizers.append(optimizer)
        self.optimizers = optimizers

        if config.d_clipped_weights:
            cap = config.d_clipped_weights
            # Clip every variable trained by a non-generator (discriminator) loss.
            self.clip = [tf.assign(d, tf.clip_by_value(d, -cap, cap)) for d in self._discriminator_vars()]
        else:
            self.clip = []
        # NOTE(review): self.clip is built here but _step never runs it; confirm
        # whether some caller is expected to run these ops.
        return None

    def _discriminator_vars(self):
        """Return the variables of all non-generator losses, in order, each listed once."""
        seen = set()
        d_vars = []
        for i, entry in enumerate(self.losses):
            if entry[0] == 'generator':
                continue
            # NOTE(review): a None var_list contributes nothing here — confirm intended.
            for var in self.var_lists[i] or []:
                # Deduplicate by identity; variables may not be hashable.
                if id(var) not in seen:
                    seen.add(id(var))
                    d_vars.append(var)
        return d_vars

    def _step(self, feed_dict):
        """Run every optimizer once, in loss order, feeding ``feed_dict``.

        For a loss with a metrics dict, its values are fetched together with
        the update and printed when ``self.current_step`` is a multiple of 100.
        """
        sess = self.gan.session
        for i, _ in enumerate(self.losses):
            optimizer = self.optimizers[i]
            metric = self.metrics[i]
            if metric:
                fetched = sess.run([optimizer] + self.output_variables(metric), feed_dict)
                metric_values = fetched[1:]
                if self.current_step % 100 == 0:
                    print("loss " + str(i) + "  " + self.output_string(metric) % tuple([self.current_step] + metric_values))
            else:
                sess.run(optimizer, feed_dict)

Module variables

var TINY

Classes

class MultiStepTrainer

GANComponents are pluggable pieces within a GAN.

GAN objects are also GANComponents.

class MultiStepTrainer(BaseTrainer):
    """Trainer that runs one optimizer update per configured loss, in order.

    ``losses[i]`` is indexed as ``(kind, loss)``: entries whose kind is
    ``'generator'`` are optimized with the ``g_`` config prefix, ``g_trainer``
    and ``g_learn_rate``; all other entries use the ``d_`` equivalents.
    ``var_lists[i]`` is the variable list handed to the optimizer for
    ``losses[i]``. ``metrics[i]``, when truthy, is a dict whose values are
    fetched alongside that update and printed every 100 steps.
    """

    def __init__(self, gan, config, losses=None, var_lists=None, metrics=None):
        BaseTrainer.__init__(self, gan, config)
        # None defaults avoid sharing one mutable list between instances;
        # caller-supplied lists are stored by reference.
        self.losses = [] if losses is None else losses
        self.var_lists = [] if var_lists is None else var_lists
        self.metrics = metrics or [None for _ in self.losses]

    def _create(self):
        """Build one optimizer op per loss, plus the optional weight-clip ops.

        Stores the ops on ``self.optimizers`` / ``self.clip`` and the
        learning-rate variables on ``self.d_lr`` / ``self.g_lr``; returns None.
        """
        config = self.config

        self.d_lr = tf.Variable(config.d_learn_rate, dtype=tf.float32)
        self.g_lr = tf.Variable(config.g_learn_rate, dtype=tf.float32)

        optimizers = []
        for i, entry in enumerate(self.losses):
            kind, loss = entry[0], entry[1]
            var_list = self.var_lists[i]
            if kind == 'generator':
                optimizer = self.build_optimizer(config, 'g_', config.g_trainer, self.g_lr, var_list, loss)
            else:
                optimizer = self.build_optimizer(config, 'd_', config.d_trainer, self.d_lr, var_list, loss)
            optimizers.append(optimizer)
        self.optimizers = optimizers

        if config.d_clipped_weights:
            cap = config.d_clipped_weights
            # Clip every variable trained by a non-generator (discriminator) loss.
            self.clip = [tf.assign(d, tf.clip_by_value(d, -cap, cap)) for d in self._discriminator_vars()]
        else:
            self.clip = []
        # NOTE(review): self.clip is built here but _step never runs it; confirm
        # whether some caller is expected to run these ops.
        return None

    def _discriminator_vars(self):
        """Return the variables of all non-generator losses, in order, each listed once."""
        seen = set()
        d_vars = []
        for i, entry in enumerate(self.losses):
            if entry[0] == 'generator':
                continue
            # NOTE(review): a None var_list contributes nothing here — confirm intended.
            for var in self.var_lists[i] or []:
                # Deduplicate by identity; variables may not be hashable.
                if id(var) not in seen:
                    seen.add(id(var))
                    d_vars.append(var)
        return d_vars

    def _step(self, feed_dict):
        """Run every optimizer once, in loss order, feeding ``feed_dict``.

        For a loss with a metrics dict, its values are fetched together with
        the update and printed when ``self.current_step`` is a multiple of 100.
        """
        sess = self.gan.session
        for i, _ in enumerate(self.losses):
            optimizer = self.optimizers[i]
            metric = self.metrics[i]
            if metric:
                fetched = sess.run([optimizer] + self.output_variables(metric), feed_dict)
                metric_values = fetched[1:]
                if self.current_step % 100 == 0:
                    print("loss " + str(i) + "  " + self.output_string(metric) % tuple([self.current_step] + metric_values))
            else:
                sess.run(optimizer, feed_dict)

Ancestors (in MRO)

  • MultiStepTrainer
  • hypergan.trainers.base_trainer.BaseTrainer
  • hypergan.gan_component.GANComponent
  • builtins.object

Static methods

def __init__(

self, gan, config, losses=[], var_lists=[], metrics=None)

Initializes a gan component based on a gan and a config dictionary.

Different components require different config variables.

A ValidationException is raised if the GAN component configuration fails to validate.

def __init__(self, gan, config, losses=None, var_lists=None, metrics=None):
    """Initialize the trainer with parallel ``losses`` / ``var_lists`` / ``metrics`` lists.

    ``losses`` and ``var_lists`` default to fresh empty lists (``None`` avoids a
    shared mutable default); ``metrics`` defaults to one ``None`` per loss.
    Caller-supplied lists are stored by reference.
    """
    BaseTrainer.__init__(self, gan, config)
    self.losses = [] if losses is None else losses
    self.var_lists = [] if var_lists is None else var_lists
    self.metrics = metrics or [None for _ in self.losses]

def biases(

self)

Biases of the GAN component.

def biases(self):
    """
        Return the biases of the GAN component, as held by ``self.ops.biases``.
    """
    ops = self.ops
    return ops.biases

def build_optimizer(

self, config, prefix, trainer_config, learning_rate, var_list, loss)

def build_optimizer(self, config, prefix, trainer_config, learning_rate, var_list, loss):
    """Create the op that trains ``var_list`` on ``loss``, inside variable scope ``prefix``.

    Config keys that start with ``prefix`` and, once the prefix is stripped,
    name an argument of ``trainer_config`` are forwarded to it as keyword
    arguments (e.g. ``d_beta1`` -> ``beta1``). If ``config.clipped_gradients``
    is truthy the gradients are clipped via ``self.capped_optimizer``;
    otherwise ``optimizer.minimize`` is used.

    Returns the update op.
    """
    # inspect.getargspec was removed in Python 3.11; getfullargspec exposes the
    # same ``args`` list. Compute it once instead of once per config key.
    accepted = set(inspect.getfullargspec(trainer_config).args)
    cut = len(prefix)
    with tf.variable_scope(prefix):
        defn = {k[cut:]: v for k, v in config.items() if k.startswith(prefix) and k[cut:] in accepted}
        optimizer = trainer_config(learning_rate, **defn)
        if config.clipped_gradients:
            apply_gradients = self.capped_optimizer(optimizer, config.clipped_gradients, loss, var_list)
        else:
            apply_gradients = optimizer.minimize(loss, var_list=var_list)
    return apply_gradients

def capped_optimizer(

optimizer, cap, loss, var_list)

def capped_optimizer(optimizer, cap, loss, var_list):
    """Apply ``optimizer`` to ``loss`` with every gradient clipped to ``[-cap, cap]``.

    Variables for which ``compute_gradients`` yields no gradient are reported
    with a warning and left out of the update.

    Returns the op produced by ``optimizer.apply_gradients``.
    """
    capped_gvs = []
    for grad, var in optimizer.compute_gradients(loss, var_list=var_list):
        # Identity test: ``grad == None`` can trigger element-wise comparison
        # on array/tensor-like gradients instead of a plain bool.
        if grad is None:
            print("Warning: No gradient for variable ", var.name)
            continue
        capped_gvs.append((tf.clip_by_value(grad, -cap, cap), var))
    return optimizer.apply_gradients(capped_gvs)

def create(

self)

def create(self):
    """Flag the component as created, then build it via ``self._create()`` and return that result."""
    self.create_called = True
    built = self._create()
    return built

def create_ops(

self, config)

Create the ops object as self.ops. Also looks up config

def create_ops(self, config):
    """
    Create the ops object as `self.ops`, then replace `self.config` with
    `config` resolved through the GAN's `ops.lookup`.

    Does nothing when there is no GAN or the GAN has no ops backend.
    """
    gan = self.gan
    if gan is None or gan.ops_backend is None:
        return
    self.ops = gan.ops_backend(config=self.config, device=gan.device)
    self.config = gan.ops.lookup(config)

def fully_connected_from_list(

self, nets)

def fully_connected_from_list(self, nets):
    """Transform each ``(net, net2)`` pair with a dense layer and return the results.

    Each pair is concatenated on axis 3, flattened to ``[batch, -1]``, passed
    through ``ops.linear`` (sized to the flattened width) and the ``lrelu``
    activation, then reshaped back to the concatenated shape.
    """
    ops = self.ops

    def _mix_pair(left, right):
        joined = ops.concat([left, right], axis=3)
        joined_shape = ops.shape(joined)
        flat = ops.reshape(joined, [joined_shape[0], -1])
        width = ops.shape(flat)[1]
        hidden = ops.linear(flat, width)
        hidden = ops.lookup('lrelu')(hidden)
        return ops.reshape(hidden, joined_shape)

    return [_mix_pair(left, right) for left, right in nets]

def layer_regularizer(

self, net)

def layer_regularizer(self, net):
    """Apply the op named by ``self.config.layer_regularizer`` to ``net``.

    The name is resolved with ``self.gan.ops.lookup``; the op is called as
    ``op(self, net)``. If the lookup result is falsy, ``net`` is returned as-is.
    """
    name = self.config.layer_regularizer
    regularizer = self.gan.ops.lookup(name)
    if not regularizer:
        return net
    return regularizer(self, net)

def output_string(

self, metrics)

def output_string(self, metrics):
    """Return a %-format template for one metrics log line.

    The template begins with a ``%2d`` step slot, followed by `` name %.2f``
    for each metric name in sorted order (the same order in which
    ``output_variables`` returns values).
    """
    # "%2d" without a backslash: "\%" is an invalid escape sequence and left a
    # literal backslash in every printed line.
    parts = ["%2d: "]
    parts.extend(" " + name + " %.2f" for name in sorted(metrics))
    return "".join(parts)

def output_variables(

self, metrics)

def output_variables(self, metrics):
    """Return the values of the ``metrics`` dict, ordered by sorted key."""
    return [metrics[name] for name in sorted(metrics)]

def permute(

self, nets, k)

def permute(self, nets, k):
    """Return every ordered arrangement of ``k`` elements of ``nets`` as a list of tuples."""
    # itertools is not among this module's top-level imports; import it locally
    # so this method does not raise NameError.
    import itertools

    return list(itertools.permutations(nets, k))

def relation_layer(

self, net)

def relation_layer(self, net):
    """Relational block over the spatial positions of ``net``.

    Splits ``net`` with ``split_by_width_height``, forms ordered pairs of the
    pieces via ``permute(..., 2)``, transforms them with
    ``fully_connected_from_list``, concatenates the results on axis 3, and
    projects back to ``net``'s shape with a linear layer.
    """
    ops = self.ops
    # HACK: indexing shape[1..3] assumes a rank-4 input.
    original_shape = ops.shape(net)
    flat_size = original_shape[1] * original_shape[2] * original_shape[3]
    pairs = self.permute(self.split_by_width_height(net), 2)
    merged = ops.concat(self.fully_connected_from_list(pairs), axis=3)
    # HACK: flatten per batch item before the projection.
    batch = ops.shape(merged)[0]
    projected = ops.linear(ops.reshape(merged, [batch, -1]), flat_size)
    return ops.reshape(projected, original_shape)

def required(

self)

Return a list of required config keys; a ValidationException is raised if any of them are missing.

Example:

```python
class MyComponent(GANComponent):
    def required(self):
        "learn rate is required"
        return ["learn_rate"]
```

def required(self):
    """Return the config keys this trainer requires."""
    return ["d_trainer", "g_trainer", "d_learn_rate", "g_learn_rate"]

def reuse(

self, net)

def reuse(self, net):
    """Call ``self.build(net)`` with ops reuse switched on, and return its result.

    ``self.ops.stop_reuse()`` runs even if ``build`` raises, so a failed build
    does not leave the ops object in reuse mode.
    """
    self.ops.reuse()
    try:
        return self.build(net)
    finally:
        self.ops.stop_reuse()

def split_batch(

self, net, count=2)

Discriminators return stacked results (on axis 0).

This splits the results. Returns [d_real, d_fake]

def split_batch(self, net, count=2):
    """
    Discriminators return stacked results (on axis 0).

    Flatten ``net`` to ``[batch, -1]`` and cut it into ``count`` equal slices
    along axis 0. With the default ``count=2`` this returns ``[d_real, d_fake]``.
    """
    ops = self.ops or self.gan.ops
    batch = ops.shape(net)[0]
    chunk = batch // count
    flat = ops.reshape(net, [batch, -1])
    rank = len(ops.shape(flat))
    pieces = []
    for index in range(count):
        begin = [0] * rank
        begin[0] = index * chunk
        size = [chunk] + list(ops.shape(flat)[1:])
        pieces.append(ops.slice(flat, begin, size))
    return pieces

def split_by_width_height(

self, net)

def split_by_width_height(self, net):
    """Slice ``net`` into one piece per spatial position.

    Axis 1 of ``net`` is treated as height and axis 2 as width (per
    ``self.gan.ops.shape``). Each piece is ``ops.slice(net, [0, i, j, 0],
    [bs, 1, 1, -1])``. Pieces are returned row by row: the axis-1 index is
    outer and the axis-2 index is inner.
    """
    ops = self.gan.ops
    shape = ops.shape(net)
    bs, height, width = shape[0], shape[1], shape[2]
    elems = []
    # Bound each index by the size of the axis it slices. Looping the axis-1
    # index over `width` broke on non-square inputs.
    for i in range(height):
        for j in range(width):
            elems.append(ops.slice(net, [0, i, j, 0], [bs, 1, 1, -1]))
    return elems

def step(

self, feed_dict={})

def step(self, feed_dict=None):
    """Run one training step, calling ``create()`` first if it has not run yet.

    ``feed_dict`` (default: a new empty dict per call, avoiding a shared
    mutable default) is passed to ``self._step``. ``self.current_step`` is
    incremented after the step, and ``_step``'s result is returned.
    """
    if feed_dict is None:
        feed_dict = {}
    if not self.create_called:
        self.create()
    result = self._step(feed_dict)
    self.current_step += 1
    return result

def validate(

self)

Validates a GANComponent. Return an array of error messages. Empty array [] means success.

def validate(self):
    """
    Validates a GANComponent.  Return an array of error messages. Empty array `[]` means success.

    A required key is reported as missing when `self.config.__getattr__(key)`
    returns None; a component without a GAN is also reported.
    """
    errors = []
    for argument in self.required():
        # NOTE(review): presumably the config's __getattr__ returns None for
        # unknown keys rather than raising — confirm against the config class.
        # Identity test: `== None` compares element-wise on array-valued entries.
        if self.config.__getattr__(argument) is None:
            errors.append("`" + argument + "` required")
    if self.gan is None:
        errors.append("GANComponent constructed without GAN")
    return errors

def variables(

self)

All variables associated with this component.

def variables(self):
    """
        Return every variable associated with this component, as reported by `self.ops.variables()`.
    """
    ops = self.ops
    return ops.variables()

def weights(

self)

The weights of the GAN component.

def weights(self):
    """
        Return the weights of the GAN component, as held by `self.ops.weights`.
    """
    ops = self.ops
    return ops.weights

Instance variables

var losses

var metrics

var var_lists