hypergan.losses.softmax_loss module
import tensorflow as tf
import hyperchamber as hc
from hypergan.losses.base_loss import BaseLoss
class SoftmaxLoss(BaseLoss):
    def _create(self, d_real, d_fake):
        gan = self.gan
        config = self.config
        ops = self.gan.ops

        ln_zb = tf.reduce_sum(tf.exp(-d_real), axis=1) + tf.reduce_sum(tf.exp(-d_fake), axis=1)
        ln_zb = tf.log(ln_zb)

        d_loss = tf.reduce_mean(d_real, axis=1) + ln_zb
        g_loss = tf.reduce_mean(d_fake, axis=1) + tf.reduce_mean(d_real, axis=1) + ln_zb

        d_loss = ops.reshape(d_loss, [])
        g_loss = ops.reshape(g_loss, [])

        return [d_loss, g_loss]
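For reference, the quantities above are computed per sample, with j indexing the discriminator's output units: the log-partition term ln Z normalizes the real and fake energies jointly, which resembles the Softmax GAN formulation. As a sketch in LaTeX (readability only), writing d_r and d_f for the discriminator outputs on a real/fake pair and n for the number of output units:

    \ln Z = \ln\Big(\sum_j e^{-d_{r,j}} + \sum_j e^{-d_{f,j}}\Big)
    L_D = \tfrac{1}{n}\sum_j d_{r,j} + \ln Z
    L_G = \tfrac{1}{n}\sum_j d_{f,j} + \tfrac{1}{n}\sum_j d_{r,j} + \ln Z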
Classes
class SoftmaxLoss
GANComponents are pluggable pieces within a GAN.
GAN objects are also GANComponents.
class SoftmaxLoss(BaseLoss):
    def _create(self, d_real, d_fake):
        gan = self.gan
        config = self.config
        ops = self.gan.ops

        ln_zb = tf.reduce_sum(tf.exp(-d_real), axis=1) + tf.reduce_sum(tf.exp(-d_fake), axis=1)
        ln_zb = tf.log(ln_zb)

        d_loss = tf.reduce_mean(d_real, axis=1) + ln_zb
        g_loss = tf.reduce_mean(d_fake, axis=1) + tf.reduce_mean(d_real, axis=1) + ln_zb

        d_loss = ops.reshape(d_loss, [])
        g_loss = ops.reshape(g_loss, [])

        return [d_loss, g_loss]
Ancestors (in MRO)
- SoftmaxLoss
- hypergan.losses.base_loss.BaseLoss
- hypergan.gan_component.GANComponent
- builtins.object
Methods
def __init__(self, gan, config, discriminator=None, generator=None)

Initializes a GAN component based on a gan and a config dictionary. Different components require different config variables. A ValidationException is raised if the GAN component configuration fails to validate.

def __init__(self, gan, config, discriminator=None, generator=None):
    GANComponent.__init__(self, gan, config)
    self.metrics = {}
    self.sample = None
    self.ops = None
    self.discriminator = discriminator
    self.generator = generator
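A minimal usage sketch (hypothetical, not taken from this module's docs): it assumes an already-constructed hypergan GAN object and that hc.Config wraps a plain dict, as elsewhere in hypergan.

    import hyperchamber as hc
    from hypergan.losses.softmax_loss import SoftmaxLoss

    # hypothetical config: SoftmaxLoss declares no required keys, but
    # create() reads the optional minibatch / gradient_penalty flags
    config = hc.Config({"minibatch": False, "gradient_penalty": False})

    loss = SoftmaxLoss(gan, config)        # gan: an existing hypergan GAN instance
    d_loss, g_loss = loss.create(split=2)  # splits the stacked discriminator batch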
def biases(self)

Biases of the GAN component.

def biases(self):
    """
    Biases of the GAN component.
    """
    return self.ops.biases
def create(self, split=2)

def create(self, split=2):
    gan = self.gan
    config = self.config
    ops = self.gan.ops
    if self.discriminator is None:
        net = gan.discriminator.sample
    else:
        net = self.discriminator.sample

    if split == 2:
        d_real, d_fake = self.split_batch(net, split)
        d_loss, g_loss = self._create(d_real, d_fake)
    elif split == 3:
        d_real, d_fake, d_fake2 = self.split_batch(net, split)
        d_loss, g_loss = self._create(d_real, d_fake)
        d_loss2, g_loss2 = self._create(d_real, d_fake2)
        g_loss += g_loss2
        d_loss += d_loss2
        #does this double the signal of d_real?

    if d_loss is not None:
        d_loss = ops.squash(d_loss, tf.reduce_mean) #linear doesn't work with this, so we cant pass config.reduce
        self.metrics['d_loss'] = d_loss
        if config.minibatch:
            d_loss += self.minibatch(net)
        if config.gradient_penalty:
            gp = self.gradient_penalty()
            self.metrics['gradient_penalty'] = gp
            print("Gradient penalty applied")
            d_loss += gp

    if g_loss is not None:
        g_loss = ops.squash(g_loss, tf.reduce_mean)
        self.metrics['g_loss'] = g_loss

    self.metrics = self.metrics or sample_metrics
    self.sample = [d_loss, g_loss]
    self.d_loss = d_loss
    self.g_loss = g_loss

    return self.sample
def create_ops(self, config)

Create the ops object as self.ops. Also looks up config.

def create_ops(self, config):
    """
    Create the ops object as `self.ops`. Also looks up config
    """
    if self.gan is None:
        return
    if self.gan.ops_backend is None:
        return
    self.ops = self.gan.ops_backend(config=self.config, device=self.gan.device)
    self.config = self.gan.ops.lookup(config)
def fully_connected_from_list(self, nets)

def fully_connected_from_list(self, nets):
    results = []
    ops = self.ops
    for net, net2 in nets:
        net = ops.concat([net, net2], axis=3)
        shape = ops.shape(net)
        bs = shape[0]
        net = ops.reshape(net, [bs, -1])
        features = ops.shape(net)[1]
        net = ops.linear(net, features)
        #net = self.layer_regularizer(net)
        net = ops.lookup('lrelu')(net)
        #net = ops.linear(net, features)
        net = ops.reshape(net, shape)
        results.append(net)
    return results
def gradient_penalty(self)

def gradient_penalty(self):
    config = self.config
    gan = self.gan
    gradient_penalty = config.gradient_penalty
    if hasattr(gan.inputs, 'gradient_penalty_label'):
        x = gan.inputs.gradient_penalty_label
    else:
        x = gan.inputs.x
    generator = self.generator or gan.generator
    g = generator.sample
    discriminator = self.discriminator or gan.discriminator
    shape = [1 for t in g.get_shape()]
    shape[0] = gan.batch_size()
    uniform_noise = tf.random_uniform(shape=shape, minval=0., maxval=1.)
    print("[gradient penalty] applying x:", x, "g:", g, "noise:", uniform_noise)
    interpolates = x + uniform_noise * (g - x)
    reused_d = discriminator.reuse(interpolates)
    gradients = tf.gradients(reused_d, [interpolates])[0]
    penalty = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=1))
    penalty = tf.reduce_mean(tf.square(penalty - 1.))
    return float(gradient_penalty) * penalty
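This follows the usual interpolated gradient-penalty recipe: sample a point between each real and generated example and penalize the squared deviation of the discriminator's gradient norm from 1. As a sketch in equation form, with per-sample noise drawn uniformly and the weight taken from config.gradient_penalty:

    \hat{x} = x + \epsilon\,(g - x), \qquad \epsilon \sim U(0, 1)
    \text{penalty} = \lambda\,\mathbb{E}\big[(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1)^2\big]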
def layer_regularizer(self, net)

def layer_regularizer(self, net):
    symbol = self.config.layer_regularizer
    op = self.gan.ops.lookup(symbol)
    if op:
        net = op(self, net)
    return net
def minibatch(self, net)

def minibatch(self, net):
    discriminator = self.discriminator or self.gan.discriminator
    ops = discriminator.ops
    config = self.config
    batch_size = ops.shape(net)[0]
    single_batch_size = batch_size//2
    n_kernels = config.minibatch_kernels or 300
    dim_per_kernel = config.dim_per_kernel or 50
    print("[discriminator] minibatch from", net, "to", n_kernels*dim_per_kernel)
    x = ops.linear(net, n_kernels * dim_per_kernel)
    activation = tf.reshape(x, (batch_size, n_kernels, dim_per_kernel))
    big = np.zeros((batch_size, batch_size))
    big += np.eye(batch_size)
    big = tf.expand_dims(big, 1)
    big = tf.cast(big, dtype=ops.dtype)
    abs_dif = tf.reduce_sum(tf.abs(tf.expand_dims(activation, 3) - tf.expand_dims(tf.transpose(activation, [1, 2, 0]), 0)), 2)
    mask = 1. - big
    masked = tf.exp(-abs_dif) * mask

    def half(tens, second):
        m, n, _ = tens.get_shape()
        m = int(m)
        n = int(n)
        return tf.slice(tens, [0, 0, second * single_batch_size], [m, n, single_batch_size])

    f1 = tf.reduce_sum(half(masked, 0), 2) / tf.reduce_sum(half(mask, 0))
    f2 = tf.reduce_sum(half(masked, 1), 2) / tf.reduce_sum(half(mask, 1))
    return ops.squash(ops.concat([f1, f2]))
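This appears to be minibatch discrimination in the style of "Improved Techniques for Training GANs": each sample is projected to n_kernels x dim_per_kernel features M_i, and similarity to the other samples in each half of the batch is accumulated with an exponentiated L1 distance, with the diagonal masked out. Roughly:

    c(x_i, x_j) = \exp\big(-\lVert M_i - M_j \rVert_{L_1}\big), \qquad
    o(x_i) = \sum_{j \neq i} c(x_i, x_j)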
def permute(self, nets, k)

def permute(self, nets, k):
    return list(itertools.permutations(nets, k))
def relation_layer(self, net)

def relation_layer(self, net):
    ops = self.ops
    #hack
    shape = ops.shape(net)
    input_size = shape[1]*shape[2]*shape[3]
    netlist = self.split_by_width_height(net)
    permutations = self.permute(netlist, 2)
    permutations = self.fully_connected_from_list(permutations)
    net = ops.concat(permutations, axis=3)
    #hack
    bs = ops.shape(net)[0]
    net = ops.reshape(net, [bs, -1])
    net = ops.linear(net, input_size)
    net = ops.reshape(net, shape)
    return net
def required(self)

Return a list of required config strings. A ValidationException will be thrown if any are missing.

Example:

    class MyComponent(GANComponent):
        def required(self):
            "learn rate is required"
            return ["learn_rate"]

def required(self):
    """
    Return a list of required config strings. A `ValidationException` will be thrown if any are missing.

    Example:
    ```python
    class MyComponent(GANComponent):
        def required(self):
            "learn rate is required"
            return ["learn_rate"]
    ```
    """
    return []
def reuse(self, net)

def reuse(self, net):
    self.ops.reuse()
    net = self.build(net)
    self.ops.stop_reuse()
    return net
def sigmoid_kl_with_logits(self, logits, targets)

def sigmoid_kl_with_logits(self, logits, targets):
    # broadcasts the same target value across the whole batch
    # this is implemented so awkwardly because tensorflow lacks an x log x op
    assert isinstance(targets, float)
    if targets in [0., 1.]:
        entropy = 0.
    else:
        entropy = - targets * np.log(targets) - (1. - targets) * np.log(1. - targets)
    return tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=tf.ones_like(logits) * targets) - entropy
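The returned tensor is the KL divergence between a Bernoulli(targets) label distribution and the Bernoulli defined by sigmoid(logits), written as cross-entropy minus the constant target entropy:

    \mathrm{KL}\big(t \,\Vert\, \sigma(\ell)\big)
      = \big[-t\ln\sigma(\ell) - (1-t)\ln(1-\sigma(\ell))\big] - \big[-t\ln t - (1-t)\ln(1-t)\big]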
def split_batch(self, net, count=2)

Discriminators return stacked results (on axis 0). This splits the results. Returns [d_real, d_fake].

def split_batch(self, net, count=2):
    """
    Discriminators return stacked results (on axis 0).
    This splits the results. Returns [d_real, d_fake]
    """
    ops = self.ops or self.gan.ops
    s = ops.shape(net)
    bs = s[0]
    nets = []
    net = ops.reshape(net, [bs, -1])
    start = [0 for x in ops.shape(net)]
    for i in range(count):
        size = [bs//count] + [x for x in ops.shape(net)[1:]]
        nets.append(ops.slice(net, start, size))
        start[0] += bs//count
    return nets
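A small sketch of the intended use (hypothetical shapes): if the discriminator stacked 32 real samples on top of 32 generated samples along axis 0, then

    # net has batch dimension 64; the call reshapes it to [64, -1] and
    # slices it into two [32, -1] halves
    d_real, d_fake = self.split_batch(net, count=2)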
def split_by_width_height(self, net)

def split_by_width_height(self, net):
    elems = []
    ops = self.gan.ops
    shape = ops.shape(net)
    bs = shape[0]
    height = shape[1]
    width = shape[2]
    for i in range(width):
        for j in range(height):
            elems.append(ops.slice(net, [0, i, j, 0], [bs, 1, 1, -1]))
    return elems
def validate(self)

Validates a GANComponent. Returns an array of error messages. An empty array [] means success.

def validate(self):
    """
    Validates a GANComponent. Return an array of error messages. Empty array `[]` means success.
    """
    errors = []
    required = self.required()
    for argument in required:
        if(self.config.__getattr__(argument) == None):
            errors.append("`"+argument+"` required")
    if(self.gan is None):
        errors.append("GANComponent constructed without GAN")
    return errors
def variables(self)

All variables associated with this component.

def variables(self):
    """
    All variables associated with this component.
    """
    return self.ops.variables()
def weights(self)

The weights of the GAN component.

def weights(self):
    """
    The weights of the GAN component.
    """
    return self.ops.weights