hypergan.losses.lamb_gan_loss module

import tensorflow as tf
import hyperchamber as hc
import numpy as np

from hypergan.losses.standard_loss import StandardLoss
from hypergan.losses.least_squares_loss import LeastSquaresLoss
from hypergan.losses.wasserstein_loss import WassersteinLoss
from hypergan.losses.base_loss import BaseLoss


class LambGanLoss(BaseLoss):

    def required(self):
        return "label_smooth".split()

    def _create(self, d_real, d_fake):
        config = self.config

        alpha = config.alpha
        beta = config.beta
        wgan_loss_d, wgan_loss_g = WassersteinLoss._create(self, d_real, d_fake)
        lsgan_loss_d, lsgan_loss_g = LeastSquaresLoss._create(self, d_real, d_fake)
        standard_loss_d, standard_loss_g = StandardLoss._create(self, d_real, d_fake)

        # Cap the combined weight of the Wasserstein and least-squares
        # terms at 1; the remainder goes to the standard GAN loss.
        total = min(alpha + beta, 1)

        d_loss = wgan_loss_d * alpha + lsgan_loss_d * beta + (1 - total) * standard_loss_d
        g_loss = wgan_loss_g * alpha + lsgan_loss_g * beta + (1 - total) * standard_loss_g

        return [d_loss, g_loss]
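
LambGanLoss blends three GAN losses: Wasserstein, least-squares, and the standard sigmoid cross-entropy loss. config.alpha weights the Wasserstein term, config.beta the least-squares term, and whatever weight remains up to 1 goes to the standard loss. A minimal sketch of the weighting arithmetic (the config values are illustrative only):

alpha, beta = 0.3, 0.5             # illustrative config values
total = min(alpha + beta, 1)       # combined weight of the first two terms, capped at 1
weights = (alpha, beta, 1 - total)
# -> (0.3, 0.5, 0.2): 30% Wasserstein, 50% least-squares, 20% standard
assert abs(sum(weights) - 1.0) < 1e-12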

Classes

class LambGanLoss

GANComponents are pluggable pieces within a GAN.

GAN objects are also GANComponents.

class LambGanLoss(BaseLoss):

    def required(self):
        return "label_smooth".split()

    def _create(self, d_real, d_fake):
        config = self.config

        alpha = config.alpha
        beta = config.beta
        wgan_loss_d, wgan_loss_g = WassersteinLoss._create(self, d_real, d_fake)
        lsgan_loss_d, lsgan_loss_g = LeastSquaresLoss._create(self, d_real, d_fake)
        standard_loss_d, standard_loss_g = StandardLoss._create(self, d_real, d_fake)

        # Cap the combined weight of the Wasserstein and least-squares
        # terms at 1; the remainder goes to the standard GAN loss.
        total = min(alpha + beta, 1)

        d_loss = wgan_loss_d * alpha + lsgan_loss_d * beta + (1 - total) * standard_loss_d
        g_loss = wgan_loss_g * alpha + lsgan_loss_g * beta + (1 - total) * standard_loss_g

        return [d_loss, g_loss]

Ancestors (in MRO)

  • LambGanLoss
  • hypergan.losses.base_loss.BaseLoss
  • hypergan.gan_component.GANComponent
  • builtins.object

Methods

def __init__(self, gan, config, discriminator=None, generator=None)

Initializes a gan component based on a gan and a config dictionary.

Different components require different config variables.

A ValidationException is raised if the GAN component configuration fails to validate.

def __init__(self, gan, config, discriminator=None, generator=None):
    GANComponent.__init__(self, gan, config)
    self.metrics = {}
    self.sample = None
    self.ops = None
    self.discriminator = discriminator
    self.generator = generator

def biases(self)

Biases of the GAN component.

def biases(self):
    """
        Biases of the GAN component.
    """
    return self.ops.biases

def create(self, split=2)

def create(self, split=2):
    gan = self.gan
    config = self.config
    ops = self.gan.ops
    
    if self.discriminator is None:
        net = gan.discriminator.sample
    else:
        net = self.discriminator.sample
    if split == 2:
        d_real, d_fake = self.split_batch(net, split)
        d_loss, g_loss = self._create(d_real, d_fake)
    elif split == 3:
        d_real, d_fake, d_fake2 = self.split_batch(net, split)
        d_loss, g_loss = self._create(d_real, d_fake)
        d_loss2, g_loss2 = self._create(d_real, d_fake2)
        g_loss += g_loss2
        d_loss += d_loss2
        #does this double the signal of d_real?
    if d_loss is not None:
        d_loss = ops.squash(d_loss, tf.reduce_mean) # linear doesn't work with this, so we can't pass config.reduce
        self.metrics['d_loss'] = d_loss
        if config.minibatch:
            d_loss += self.minibatch(net)
        if config.gradient_penalty:
            gp = self.gradient_penalty()
            self.metrics['gradient_penalty'] = gp
            print("Gradient penalty applied")
            d_loss += gp
    if g_loss is not None:
        g_loss = ops.squash(g_loss, tf.reduce_mean)
        self.metrics['g_loss'] = g_loss
    self.metrics = self.metrics or {} # guard: fall back to an empty metrics dict
    self.sample = [d_loss, g_loss]
    self.d_loss = d_loss
    self.g_loss = g_loss
    return self.sample

def create_ops(self, config)

Create the ops object as `self.ops`. Also looks up the config through the ops backend.

def create_ops(self, config):
    """
    Create the ops object as `self.ops`.  Also looks up config
    """
    if self.gan is None:
        return
    if self.gan.ops_backend is None:
        return
    self.ops = self.gan.ops_backend(config=self.config, device=self.gan.device)
    self.config = self.gan.ops.lookup(config)

def fully_connected_from_list(self, nets)

def fully_connected_from_list(self, nets):
    results = []
    ops = self.ops
    for net, net2 in nets:
        net = ops.concat([net, net2], axis=3)
        shape = ops.shape(net)
        bs = shape[0]
        net = ops.reshape(net, [bs, -1])
        features = ops.shape(net)[1]
        net = ops.linear(net, features)
        #net = self.layer_regularizer(net)
        net = ops.lookup('lrelu')(net)
        #net = ops.linear(net, features)
        net = ops.reshape(net, shape)
        results.append(net)
    return results

def gradient_penalty(self)

def gradient_penalty(self):
    config = self.config
    gan = self.gan
    gradient_penalty = config.gradient_penalty
    if hasattr(gan.inputs, 'gradient_penalty_label'):
        x = gan.inputs.gradient_penalty_label
    else:
        x = gan.inputs.x
    generator = self.generator or gan.generator
    g = generator.sample
    discriminator = self.discriminator or gan.discriminator
    shape = [1 for t in g.get_shape()]
    shape[0] = gan.batch_size()
    uniform_noise = tf.random_uniform(shape=shape, minval=0., maxval=1.)
    print("[gradient penalty] applying x:", x, "g:", g, "noise:", uniform_noise)
    interpolates = x + uniform_noise * (g - x)
    reused_d = discriminator.reuse(interpolates)
    gradients = tf.gradients(reused_d, [interpolates])[0]
    penalty = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=1))
    penalty = tf.reduce_mean(tf.square(penalty - 1.))
    return float(gradient_penalty) * penalty
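
This is the WGAN-GP gradient penalty (Gulrajani et al.): sample a random point on the line between each real and generated sample, and penalize the squared deviation of the discriminator's gradient norm from 1 at that point. A minimal TensorFlow 1.x sketch with a stand-in discriminator; the shapes and the stand-in function are illustrative only, and the norm here is taken over all non-batch axes where the code above reduces over axis 1:

import tensorflow as tf

x = tf.placeholder(tf.float32, [8, 64, 64, 3])   # real samples (illustrative shape)
g = tf.placeholder(tf.float32, [8, 64, 64, 3])   # generated samples
eps = tf.random_uniform([8, 1, 1, 1], 0., 1.)    # one draw per batch element
interpolates = x + eps * (g - x)                 # random point on the line x -> g
d_out = tf.reduce_sum(tf.square(interpolates))   # stand-in for discriminator.reuse(...)
grads = tf.gradients(d_out, [interpolates])[0]
norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
penalty = tf.reduce_mean(tf.square(norm - 1.))   # pushes the gradient norm toward 1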

def layer_regularizer(self, net)

def layer_regularizer(self, net):
    symbol = self.config.layer_regularizer
    op = self.gan.ops.lookup(symbol)
    if op:
        net = op(self, net)
    return net

def minibatch(self, net)

def minibatch(self, net):
    discriminator = self.discriminator or self.gan.discriminator
    ops = discriminator.ops
    config = self.config
    batch_size = ops.shape(net)[0]
    single_batch_size = batch_size//2
    n_kernels = config.minibatch_kernels or 300
    dim_per_kernel = config.dim_per_kernel or 50
    print("[discriminator] minibatch from", net, "to", n_kernels*dim_per_kernel)
    x = ops.linear(net, n_kernels * dim_per_kernel)
    activation = tf.reshape(x, (batch_size, n_kernels, dim_per_kernel))
    big = np.zeros((batch_size, batch_size))
    big += np.eye(batch_size)
    big = tf.expand_dims(big, 1)
    big = tf.cast(big,dtype=ops.dtype)
    abs_dif = tf.reduce_sum(tf.abs(tf.expand_dims(activation,3) - tf.expand_dims(tf.transpose(activation, [1, 2, 0]), 0)), 2)
    mask = 1. - big
    masked = tf.exp(-abs_dif) * mask
    def half(tens, second):
        m, n, _ = tens.get_shape()
        m = int(m)
        n = int(n)
        return tf.slice(tens, [0, 0, second * single_batch_size], [m, n, single_batch_size])
    f1 = tf.reduce_sum(half(masked, 0), 2) / tf.reduce_sum(half(mask, 0))
    f2 = tf.reduce_sum(half(masked, 1), 2) / tf.reduce_sum(half(mask, 1))
    return ops.squash(ops.concat([f1, f2]))
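
This is minibatch discrimination (Salimans et al., "Improved Techniques for Training GANs"): each sample is projected onto n_kernels kernels, pairwise L1 distances are taken across the batch, and exp(-distance) similarities (excluding self-comparisons) become extra features that let the discriminator detect a collapsing generator. A NumPy sketch of the distance computation; the batch and kernel sizes are illustrative only:

import numpy as np

activation = np.random.randn(4, 2, 3)        # [batch, kernels, dims] toy projection
diffs = np.abs(activation[:, :, :, None] - activation.transpose(1, 2, 0)[None])
abs_dif = diffs.sum(axis=2)                  # [batch, kernels, batch] L1 distances
mask = 1. - np.eye(4)[:, None, :]            # zero out self-comparisons
features = (np.exp(-abs_dif) * mask).sum(2)  # [batch, kernels] similarity features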

def permute(self, nets, k)

def permute(self, nets, k):
    # all ordered k-tuples of nets (requires `import itertools` in scope)
    return list(itertools.permutations(nets, k))
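
A quick usage illustration; the pairs are ordered, so each combination appears in both orders:

import itertools

list(itertools.permutations(["a", "b", "c"], 2))
# [('a', 'b'), ('a', 'c'), ('b', 'a'), ('b', 'c'), ('c', 'a'), ('c', 'b')]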

def relation_layer(self, net)

def relation_layer(self, net):
    ops = self.ops
    #hack
    shape = ops.shape(net)
    input_size = shape[1]*shape[2]*shape[3]
    netlist = self.split_by_width_height(net)
    permutations = self.permute(netlist, 2)
    permutations = self.fully_connected_from_list(permutations)
    net = ops.concat(permutations, axis=3)
    #hack
    bs = ops.shape(net)[0]
    net = ops.reshape(net, [bs, -1])
    net = ops.linear(net, input_size)
    net = ops.reshape(net, shape)
    return net
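
The layer relates every ordered pair of spatial positions through a shared fully connected transform, so its cost grows quadratically with the number of positions. An illustrative count:

h, w = 4, 4                          # illustrative feature-map size
positions = h * w
pairs = positions * (positions - 1)  # ordered pairs from permute(netlist, 2)
assert pairs == 240                  # 16 positions -> 240 pairwise passes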

def required(self)

Returns a list of required config strings. A ValidationException is thrown if any are missing.

Example:

    class MyComponent(GANComponent):
        def required(self):
            "learn rate is required"
            return ["learn_rate"]

def required(self):
    return "label_smooth".split()

def reuse(self, net)

def reuse(self, net):
    self.ops.reuse()
    net = self.build(net)
    self.ops.stop_reuse()
    return net

def sigmoid_kl_with_logits(self, logits, targets)

def sigmoid_kl_with_logits(self, logits, targets):
    # broadcasts the same target value across the whole batch
    # this is implemented so awkwardly because tensorflow lacks an x log x op
    assert isinstance(targets, float)
    if targets in [0., 1.]:
        entropy = 0.
    else:
        entropy = - targets * np.log(targets) - (1. - targets) * np.log(1. - targets)
    return tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=tf.ones_like(logits) * targets) - entropy
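
Subtracting the target entropy turns sigmoid cross-entropy into a proper KL divergence, KL(targets || sigmoid(logits)), which is what label smoothing calls for. A NumPy check of that identity; the logit and target values are illustrative only:

import numpy as np

l, t = 0.5, 0.9                                  # one logit, one smoothed target
p = 1. / (1. + np.exp(-l))                       # sigmoid(l)
ce = -t * np.log(p) - (1. - t) * np.log(1. - p)  # sigmoid cross-entropy
h = -t * np.log(t) - (1. - t) * np.log(1. - t)   # entropy of the target
kl = t * np.log(t / p) + (1. - t) * np.log((1. - t) / (1. - p))
assert abs((ce - h) - kl) < 1e-12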

def split_batch(self, net, count=2)

Discriminators return stacked results (on axis 0).

This splits the results. Returns [d_real, d_fake]

def split_batch(self, net, count=2):
    """ 
    Discriminators return stacked results (on axis 0).  
    
    This splits the results.  Returns [d_real, d_fake]
    """
    ops = self.ops or self.gan.ops
    s = ops.shape(net)
    bs = s[0]
    nets = []
    net = ops.reshape(net, [bs, -1])
    start = [0 for x in ops.shape(net)]
    for i in range(count):
        size = [bs//count] + [x for x in ops.shape(net)[1:]]
        nets.append(ops.slice(net, start, size))
        start[0] += bs//count
    return nets
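
An illustration of the stacked-batch convention in NumPy; shapes are illustrative only. The discriminator sees real and fake inputs concatenated on axis 0, so its output splits back into equal halves:

import numpy as np

net = np.arange(8, dtype=np.float32).reshape(8, 1)  # stacked scores, batch of 8
d_real, d_fake = net[:4], net[4:]                   # what split_batch(net, 2) yields
assert d_real.shape == d_fake.shape == (4, 1)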

def split_by_width_height(self, net)

def split_by_width_height(self, net):
    elems = []
    ops = self.gan.ops
    shape = ops.shape(net)
    bs = shape[0]
    height = shape[1]
    width = shape[2]
    for i in range(width):
        for j in range(height):
            elems.append(ops.slice(net, [0, i, j, 0], [bs, 1, 1, -1]))
    return elems
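
A toy NumPy version of the slicing; shapes are illustrative only. Each spatial position of a [batch, height, width, channels] map becomes its own [batch, 1, 1, channels] element:

import numpy as np

net = np.random.randn(8, 2, 2, 16)
elems = [net[:, i:i+1, j:j+1, :] for i in range(2) for j in range(2)]
assert len(elems) == 4 and elems[0].shape == (8, 1, 1, 16)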

def validate(self)

Validates a GANComponent. Returns an array of error messages; an empty array [] means success.

def validate(self):
    """
    Validates a GANComponent.  Returns an array of error messages. An empty array `[]` means success.
    """
    errors = []
    required = self.required()
    for argument in required:
        if self.config.__getattr__(argument) is None:
            errors.append("`" + argument + "` required")
    if self.gan is None:
        errors.append("GANComponent constructed without GAN")
    return errors
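
A minimal stand-in that mirrors the validation contract above (not the hypergan classes; the names here are hypothetical):

class Component:
    def __init__(self, config):
        self.config = config

    def required(self):
        return ["label_smooth"]

    def validate(self):
        return ["`%s` required" % key for key in self.required()
                if self.config.get(key) is None]

assert Component({}).validate() == ["`label_smooth` required"]
assert Component({"label_smooth": 0.2}).validate() == []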

def variables(self)

All variables associated with this component.

def variables(self):
    """
        All variables associated with this component.
    """
    return self.ops.variables()

def weights(self)

The weights of the GAN component.

def weights(self):
    """
        The weights of the GAN component.
    """
    return self.ops.weights