Top

hypergan.gans.alpha_gan module

import importlib
import json
import numpy as np
import os
import sys
import time
import uuid
import copy

from hypergan.discriminators import *
from hypergan.encoders import *
from hypergan.generators import *
from hypergan.inputs import *
from hypergan.samplers import *
from hypergan.trainers import *

import hyperchamber as hc
from hyperchamber import Config
from hypergan.ops import TensorflowOps
import tensorflow as tf
import hypergan as hg

from hypergan.gan_component import ValidationException, GANComponent
from .base_gan import BaseGAN

from hypergan.discriminators.fully_connected_discriminator import FullyConnectedDiscriminator
from hypergan.encoders.uniform_encoder import UniformEncoder
from hypergan.trainers.multi_step_trainer import MultiStepTrainer

class AlphaGAN(BaseGAN):
    """
    GAN combining an encoder/generator autoencoder with two adversarial games.

    * ``g_encoder`` encodes ``inputs.x``; its output is projected into ``z_hat``.
    * ``generator`` decodes both a uniform prior sample ``z`` (giving ``g``) and
      ``z_hat`` (giving the reconstruction ``x_hat``).
    * ``z_discriminator`` compares prior codes ``z`` against encoded codes ``z_hat``.
    * ``discriminator`` scores the stacked batch ``[x, x_hat, g]``.

    A weighted cycle (reconstruction) loss between ``x`` and ``x_hat`` is added to
    both generator-side losses, and all four losses are handed to a MultiStepTrainer.
    """
    def __init__(self, *args, **kwargs):
        BaseGAN.__init__(self, *args, **kwargs)
        self.discriminator = None
        self.encoder = None
        self.generator = None
        self.loss = None
        self.trainer = None
        self.session = None

    def required(self):
        """Config keys that validate() reports as errors when missing."""
        return "generator discriminator z_discriminator g_encoder".split()

    def create(self):
        """
        Build the graph and trainer, then initialize all variables in ``self.session``.

        Creates a session if none exists; builds the g_encoder on ``inputs.x``, the
        latent and image discriminators, a UniformEncoder prior sized to the encoder
        output, and the generator. It then builds the latent and image losses plus a
        weighted cycle loss, wraps the four losses in a MultiStepTrainer, and runs
        ``tf.global_variables_initializer()``.
        """
        BaseGAN.create(self)
        if self.session is None:
            self.session = self.ops.new_session(self.ops_config)
        with tf.device(self.device):
            config = self.config
            ops = self.ops

            # Image encoder; falls back to the discriminator config when g_encoder is unset.
            g_encoder = dict(config.g_encoder or config.discriminator)
            encoder = self.create_component(g_encoder)
            encoder.ops.describe("g_encoder")
            encoder.create(self.inputs.x)
            encoder.z = tf.zeros(0)
            # Promote a rank-2 encoding to rank 4 so projections can be concatenated on axis 3.
            if len(encoder.sample.get_shape()) == 2:
                s = ops.shape(encoder.sample)
                encoder.sample = tf.reshape(encoder.sample, [s[0],s[1], 1, 1])

            # Latent-space discriminator config, with layer_filter disabled.
            z_discriminator = dict(config.z_discriminator or config.discriminator)
            z_discriminator['layer_filter']=None

            encoder_discriminator = self.create_component(z_discriminator)
            encoder_discriminator.ops.describe("z_discriminator")
            standard_discriminator = self.create_component(config.discriminator)
            standard_discriminator.ops.describe("discriminator")

            # Size the uniform prior to the encoder output's per-example element count.
            # NOTE(review): this sets .z on config.encoder in place (shared config object).
            uniform_encoder_config = config.encoder
            z_size = 1
            for size in ops.shape(encoder.sample)[1:]:
                z_size *= size
            uniform_encoder_config.z = z_size
            uniform_encoder = UniformEncoder(self, uniform_encoder_config)
            uniform_encoder.create()

            self.generator = self.create_component(config.generator)

            z = uniform_encoder.sample
            x = self.inputs.x

            # Project the flattened encoder output through each configured projection,
            # reshape each back to the encoder-output shape and stack them on axis 3.
            projection_input = ops.reshape(encoder.sample, [ops.shape(encoder.sample)[0],-1])
            projections = []
            for projection in uniform_encoder.config.projections:
                projection = uniform_encoder.lookup(projection)(uniform_encoder.config, self.gan, projection_input)
                projection = ops.reshape(projection, ops.shape(encoder.sample))
                projections.append(projection)
            z_hat = tf.concat(axis=3, values=projections)

            # Give the prior sample the same shape as the encoded code.
            z = ops.reshape(z, ops.shape(z_hat))

            # g: decoded prior sample; x_hat: reconstruction from the encoded code
            # (generator weights shared via reuse).
            g = self.generator.create(z)
            self.uniform_sample = self.generator.sample
            x_hat = self.generator.reuse(z_hat)

            # Latent game: prior codes z are the "real" side, encoded codes z_hat the "fake" side.
            encoder_discriminator.create(x=z, g=z_hat)

            eloss = dict(config.loss)
            eloss['gradient_penalty'] = False
            encoder_loss = self.create_component(eloss, discriminator = encoder_discriminator)
            encoder_loss.create()

            # Image game on the stacked batch [real, reconstruction, prior sample].
            stacked_xg = ops.concat([x, x_hat, g], axis=0)
            standard_discriminator.create(stacked_xg)

            standard_loss = self.create_component(config.loss, discriminator = standard_discriminator)
            standard_loss.create(split=3)

            # NOTE(review): this trainer is overwritten by the MultiStepTrainer below,
            # but create_component still registers it in self.components.
            self.trainer = self.create_component(config.trainer)

            # Cycle (reconstruction) loss; distance defaults to l1, weight defaults to 10.
            distance = config.distance or ops.lookup('l1_distance')
            cycloss = tf.reduce_mean(distance(self.inputs.x,x_hat))
            cycloss_lambda = config.cycloss_lambda
            if cycloss_lambda is None:
                cycloss_lambda = 10
            cycloss *= cycloss_lambda
            loss1=('generator', cycloss + encoder_loss.g_loss)
            loss2=('generator', cycloss + standard_loss.g_loss)
            loss3=('discriminator', standard_loss.d_loss)
            loss4=('discriminator', encoder_loss.d_loss)

            # Index-aligned with the loss list above (presumably loss i updates var_lists[i]).
            var_lists = []
            var_lists.append(encoder.variables())
            var_lists.append(self.generator.variables())
            var_lists.append(standard_discriminator.variables())
            var_lists.append(encoder_discriminator.variables())

            metrics = []
            metrics.append(encoder_loss.metrics)
            metrics.append(standard_loss.metrics)
            metrics.append(None)
            metrics.append(None)

            self.trainer = MultiStepTrainer(self, self.config.trainer, [loss1,loss2,loss3,loss4], var_lists=var_lists, metrics=metrics)
            self.trainer.create()

            self.session.run(tf.global_variables_initializer())

            self.encoder = encoder
            self.uniform_encoder = uniform_encoder


    def step(self, feed_dict=None):
        """
        Run one training step.

        :param feed_dict: extra feed values forwarded to ``self.trainer.step``.
            When omitted, a fresh empty dict is used, so no dict is shared between calls.
        :returns: whatever ``self.trainer.step`` returns.
        """
        if feed_dict is None:
            feed_dict = {}
        return self.trainer.step(feed_dict)

Classes

class AlphaGAN

class AlphaGAN(BaseGAN):
    """
    GAN combining an encoder/generator autoencoder with two adversarial games.

    * ``g_encoder`` encodes ``inputs.x``; its output is projected into ``z_hat``.
    * ``generator`` decodes both a uniform prior sample ``z`` (giving ``g``) and
      ``z_hat`` (giving the reconstruction ``x_hat``).
    * ``z_discriminator`` compares prior codes ``z`` against encoded codes ``z_hat``.
    * ``discriminator`` scores the stacked batch ``[x, x_hat, g]``.

    A weighted cycle (reconstruction) loss between ``x`` and ``x_hat`` is added to
    both generator-side losses, and all four losses are handed to a MultiStepTrainer.
    """
    def __init__(self, *args, **kwargs):
        BaseGAN.__init__(self, *args, **kwargs)
        self.discriminator = None
        self.encoder = None
        self.generator = None
        self.loss = None
        self.trainer = None
        self.session = None

    def required(self):
        """Config keys that validate() reports as errors when missing."""
        return "generator discriminator z_discriminator g_encoder".split()

    def create(self):
        """
        Build the graph and trainer, then initialize all variables in ``self.session``.

        Creates a session if none exists; builds the g_encoder on ``inputs.x``, the
        latent and image discriminators, a UniformEncoder prior sized to the encoder
        output, and the generator. It then builds the latent and image losses plus a
        weighted cycle loss, wraps the four losses in a MultiStepTrainer, and runs
        ``tf.global_variables_initializer()``.
        """
        BaseGAN.create(self)
        if self.session is None:
            self.session = self.ops.new_session(self.ops_config)
        with tf.device(self.device):
            config = self.config
            ops = self.ops

            # Image encoder; falls back to the discriminator config when g_encoder is unset.
            g_encoder = dict(config.g_encoder or config.discriminator)
            encoder = self.create_component(g_encoder)
            encoder.ops.describe("g_encoder")
            encoder.create(self.inputs.x)
            encoder.z = tf.zeros(0)
            # Promote a rank-2 encoding to rank 4 so projections can be concatenated on axis 3.
            if len(encoder.sample.get_shape()) == 2:
                s = ops.shape(encoder.sample)
                encoder.sample = tf.reshape(encoder.sample, [s[0],s[1], 1, 1])

            # Latent-space discriminator config, with layer_filter disabled.
            z_discriminator = dict(config.z_discriminator or config.discriminator)
            z_discriminator['layer_filter']=None

            encoder_discriminator = self.create_component(z_discriminator)
            encoder_discriminator.ops.describe("z_discriminator")
            standard_discriminator = self.create_component(config.discriminator)
            standard_discriminator.ops.describe("discriminator")

            # Size the uniform prior to the encoder output's per-example element count.
            # NOTE(review): this sets .z on config.encoder in place (shared config object).
            uniform_encoder_config = config.encoder
            z_size = 1
            for size in ops.shape(encoder.sample)[1:]:
                z_size *= size
            uniform_encoder_config.z = z_size
            uniform_encoder = UniformEncoder(self, uniform_encoder_config)
            uniform_encoder.create()

            self.generator = self.create_component(config.generator)

            z = uniform_encoder.sample
            x = self.inputs.x

            # Project the flattened encoder output through each configured projection,
            # reshape each back to the encoder-output shape and stack them on axis 3.
            projection_input = ops.reshape(encoder.sample, [ops.shape(encoder.sample)[0],-1])
            projections = []
            for projection in uniform_encoder.config.projections:
                projection = uniform_encoder.lookup(projection)(uniform_encoder.config, self.gan, projection_input)
                projection = ops.reshape(projection, ops.shape(encoder.sample))
                projections.append(projection)
            z_hat = tf.concat(axis=3, values=projections)

            # Give the prior sample the same shape as the encoded code.
            z = ops.reshape(z, ops.shape(z_hat))

            # g: decoded prior sample; x_hat: reconstruction from the encoded code
            # (generator weights shared via reuse).
            g = self.generator.create(z)
            self.uniform_sample = self.generator.sample
            x_hat = self.generator.reuse(z_hat)

            # Latent game: prior codes z are the "real" side, encoded codes z_hat the "fake" side.
            encoder_discriminator.create(x=z, g=z_hat)

            eloss = dict(config.loss)
            eloss['gradient_penalty'] = False
            encoder_loss = self.create_component(eloss, discriminator = encoder_discriminator)
            encoder_loss.create()

            # Image game on the stacked batch [real, reconstruction, prior sample].
            stacked_xg = ops.concat([x, x_hat, g], axis=0)
            standard_discriminator.create(stacked_xg)

            standard_loss = self.create_component(config.loss, discriminator = standard_discriminator)
            standard_loss.create(split=3)

            # NOTE(review): this trainer is overwritten by the MultiStepTrainer below,
            # but create_component still registers it in self.components.
            self.trainer = self.create_component(config.trainer)

            # Cycle (reconstruction) loss; distance defaults to l1, weight defaults to 10.
            distance = config.distance or ops.lookup('l1_distance')
            cycloss = tf.reduce_mean(distance(self.inputs.x,x_hat))
            cycloss_lambda = config.cycloss_lambda
            if cycloss_lambda is None:
                cycloss_lambda = 10
            cycloss *= cycloss_lambda
            loss1=('generator', cycloss + encoder_loss.g_loss)
            loss2=('generator', cycloss + standard_loss.g_loss)
            loss3=('discriminator', standard_loss.d_loss)
            loss4=('discriminator', encoder_loss.d_loss)

            # Index-aligned with the loss list above (presumably loss i updates var_lists[i]).
            var_lists = []
            var_lists.append(encoder.variables())
            var_lists.append(self.generator.variables())
            var_lists.append(standard_discriminator.variables())
            var_lists.append(encoder_discriminator.variables())

            metrics = []
            metrics.append(encoder_loss.metrics)
            metrics.append(standard_loss.metrics)
            metrics.append(None)
            metrics.append(None)

            self.trainer = MultiStepTrainer(self, self.config.trainer, [loss1,loss2,loss3,loss4], var_lists=var_lists, metrics=metrics)
            self.trainer.create()

            self.session.run(tf.global_variables_initializer())

            self.encoder = encoder
            self.uniform_encoder = uniform_encoder


    def step(self, feed_dict=None):
        """
        Run one training step.

        :param feed_dict: extra feed values forwarded to ``self.trainer.step``.
            When omitted, a fresh empty dict is used, so no dict is shared between calls.
        :returns: whatever ``self.trainer.step`` returns.
        """
        if feed_dict is None:
            feed_dict = {}
        return self.trainer.step(feed_dict)

Ancestors (in MRO)

  • AlphaGAN
  • hypergan.gans.base_gan.BaseGAN
  • hypergan.gan_component.GANComponent
  • builtins.object

Methods

def __init__(self, *args, **kwargs)

Initializes a new GAN.

def __init__(self, *args, **kwargs):
    """
    Initialize through BaseGAN, then reset the attributes AlphaGAN manages to None.
    """
    BaseGAN.__init__(self, *args, **kwargs)
    # Same attributes, same order as before; each starts out unset.
    for attribute in ("discriminator", "encoder", "generator", "loss", "trainer", "session"):
        setattr(self, attribute, None)

def batch_size(self)

def batch_size(self):
    """
    Batch size: the explicit ``_batch_size`` when truthy, otherwise dimension 0
    of ``inputs.x``.

    :raises ValidationException: when no explicit size is set and there are no inputs.
    """
    explicit = self._batch_size
    if not explicit:
        if self.inputs == None:
            raise ValidationException("gan.batch_size() requested but no inputs provided")
        return self.ops.shape(self.inputs.x)[0]
    return explicit

def biases(self)

Biases of the GAN component.

def biases(self):
    """
        The ``biases`` collection exposed by this component's ops object.
    """
    component_ops = self.ops
    return component_ops.biases

def channels(self)

def channels(self):
    """
    Channel count: the explicit ``_channels`` when truthy, otherwise the last
    dimension of ``inputs.x``.

    :raises ValidationException: when no explicit value is set and there are no inputs.
    """
    explicit = self._channels
    if not explicit:
        if self.inputs == None:
            raise ValidationException("gan.channels() requested but no inputs provided")
        return self.ops.shape(self.inputs.x)[-1]
    return explicit

def create(self)

def create(self):
    """
    Build the AlphaGAN graph and trainer, then initialize all variables.

    In order, this method:
    * creates a session if none exists;
    * builds the g_encoder on inputs.x;
    * builds the latent (z) and image discriminators;
    * sizes and builds a UniformEncoder prior;
    * projects the encoder output into z_hat;
    * runs the generator on both z and z_hat;
    * builds the latent and image losses plus a weighted cycle loss;
    * wraps the four losses in a MultiStepTrainer;
    * runs tf.global_variables_initializer() in self.session.
    """
    BaseGAN.create(self)
    if self.session is None: 
        self.session = self.ops.new_session(self.ops_config)
    with tf.device(self.device):
        config = self.config
        ops = self.ops
        # Image encoder; falls back to the discriminator config when g_encoder is unset.
        g_encoder = dict(config.g_encoder or config.discriminator)
        encoder = self.create_component(g_encoder)
        encoder.ops.describe("g_encoder")
        encoder.create(self.inputs.x)
        encoder.z = tf.zeros(0)
        # Promote a rank-2 encoding to rank 4 so projections can be concatenated on axis 3.
        if(len(encoder.sample.get_shape()) == 2):
            s = ops.shape(encoder.sample)
            encoder.sample = tf.reshape(encoder.sample, [s[0],s[1], 1, 1])
        # Latent-space discriminator config, with layer_filter disabled.
        z_discriminator = dict(config.z_discriminator or config.discriminator)
        z_discriminator['layer_filter']=None
        encoder_discriminator = self.create_component(z_discriminator)
        encoder_discriminator.ops.describe("z_discriminator")
        standard_discriminator = self.create_component(config.discriminator)
        standard_discriminator.ops.describe("discriminator")
        #encoder.sample = ops.reshape(encoder.sample, [ops.shape(encoder.sample)[0], -1])
        # Size the uniform prior to the encoder output's per-example element count.
        # NOTE(review): this sets .z on config.encoder in place (shared config object).
        uniform_encoder_config = config.encoder
        z_size = 1
        for size in ops.shape(encoder.sample)[1:]:
            z_size *= size
        uniform_encoder_config.z = z_size
        uniform_encoder = UniformEncoder(self, uniform_encoder_config)
        uniform_encoder.create()
        self.generator = self.create_component(config.generator)
        z = uniform_encoder.sample
        x = self.inputs.x
        # project the output of the autoencoder
        # Each configured projection maps the flattened encoding back to the encoder-output
        # shape; the results are stacked on axis 3 to form z_hat.
        projection_input = ops.reshape(encoder.sample, [ops.shape(encoder.sample)[0],-1])
        projections = []
        for projection in uniform_encoder.config.projections:
            projection = uniform_encoder.lookup(projection)(uniform_encoder.config, self.gan, projection_input)
            projection = ops.reshape(projection, ops.shape(encoder.sample))
            projections.append(projection)
        z_hat = tf.concat(axis=3, values=projections)
        # Give the prior sample the same shape as the encoded code.
        z = ops.reshape(z, ops.shape(z_hat))
        # end encoding
        # g: decoded prior sample; x_hat: reconstruction from z_hat (generator weights reused).
        g = self.generator.create(z)
        sample = self.generator.sample  # NOTE(review): unused local
        self.uniform_sample = self.generator.sample
        x_hat = self.generator.reuse(z_hat)
        # Latent game: prior codes z on the "real" side, encoded codes z_hat on the "fake" side.
        encoder_discriminator.create(x=z, g=z_hat)
        eloss = dict(config.loss)
        eloss['gradient_penalty'] = False
        encoder_loss = self.create_component(eloss, discriminator = encoder_discriminator)
        encoder_loss.create()
        # Image game on the stacked batch [real, reconstruction, prior sample].
        stacked_xg = ops.concat([x, x_hat, g], axis=0)
        standard_discriminator.create(stacked_xg)
        standard_loss = self.create_component(config.loss, discriminator = standard_discriminator)
        standard_loss.create(split=3)
        # NOTE(review): overwritten by the MultiStepTrainer below, but still registered
        # in self.components by create_component.
        self.trainer = self.create_component(config.trainer)
        #loss terms
        # Cycle (reconstruction) loss; distance defaults to l1, weight defaults to 10.
        distance = config.distance or ops.lookup('l1_distance')
        cycloss = tf.reduce_mean(distance(self.inputs.x,x_hat))
        cycloss_lambda = config.cycloss_lambda
        if cycloss_lambda is None:
            cycloss_lambda = 10
        cycloss *= cycloss_lambda
        loss1=('generator', cycloss + encoder_loss.g_loss)
        loss2=('generator', cycloss + standard_loss.g_loss)
        loss3=('discriminator', standard_loss.d_loss)
        loss4=('discriminator', encoder_loss.d_loss)
        # Index-aligned with [loss1..loss4] (presumably loss i updates var_lists[i]).
        var_lists = []
        var_lists.append(encoder.variables())
        var_lists.append(self.generator.variables())
        var_lists.append(standard_discriminator.variables())
        var_lists.append(encoder_discriminator.variables())
        metrics = []
        metrics.append(encoder_loss.metrics)
        metrics.append(standard_loss.metrics)
        metrics.append(None)
        metrics.append(None)
        # trainer
        self.trainer = MultiStepTrainer(self, self.config.trainer, [loss1,loss2,loss3,loss4], var_lists=var_lists, metrics=metrics)
        self.trainer.create()
        self.session.run(tf.global_variables_initializer())
        self.encoder = encoder
        self.uniform_encoder = uniform_encoder

def create_component(self, defn, *args, **kw_args)

def create_component(self, defn, *args, **kw_args):
    """
    Instantiate a GAN component from its definition and register it.

    :param defn: component definition whose ``'class'`` entry is called as
        ``cls(self, defn, *args, **kw_args)``.
    :returns: the new component (also appended to ``self.components``), or
        None when ``defn`` is None.
    :raises ValidationException: when the definition has no usable ``'class'`` entry.
    """
    if defn is None:
        return None
    # Treat an absent key the same as an explicit None so callers always get
    # a ValidationException rather than a KeyError.
    try:
        component_class = defn['class']
    except KeyError:
        component_class = None
    if component_class is None:
        raise ValidationException("Component definition is missing 'class'")
    gan_component = component_class(self, defn, *args, **kw_args)
    self.components.append(gan_component)
    return gan_component

def create_ops(self, config)

Creates the ops object as self.ops, then resolves config through the GAN's ops lookup and stores the result in self.config.

def create_ops(self, config):
    """
    Build ``self.ops`` from the GAN's ops backend, then replace ``self.config``
    with ``config`` resolved through the GAN's ops lookup.

    Does nothing when there is no GAN or the GAN has no ops backend.
    """
    owner_gan = self.gan
    if owner_gan is None or owner_gan.ops_backend is None:
        return
    self.ops = owner_gan.ops_backend(config=self.config, device=owner_gan.device)
    self.config = owner_gan.ops.lookup(config)

def fully_connected_from_list(self, nets)

def fully_connected_from_list(self, nets):
    """
    Apply a dense + lrelu layer to each pair in ``nets``.

    For each ``(left, right)`` pair, the two tensors are concatenated on axis 3,
    flattened per example, and passed through ``ops.linear`` with the same width
    and then ``lrelu``. The result is reshaped back to the concatenated shape.

    :returns: list with one output tensor per input pair.
    """
    ops = self.ops

    def _dense_pair(left, right):
        # One pair -> concat, flatten, linear, lrelu, restore shape.
        joined = ops.concat([left, right], axis=3)
        joined_shape = ops.shape(joined)
        flat = ops.reshape(joined, [joined_shape[0], -1])
        dense = ops.linear(flat, ops.shape(flat)[1])
        dense = ops.lookup('lrelu')(dense)
        return ops.reshape(dense, joined_shape)

    return [_dense_pair(left, right) for left, right in nets]

def get_config_value(self, symbol)

def get_config_value(self, symbol):
    """
    Return ``self.config[symbol]`` wrapped as an ``hc.Config`` with functions
    looked up, or None when ``symbol`` is not in the config.
    """
    if symbol not in self.config:
        return None
    return hc.Config(hc.lookup_functions(self.config[symbol]))

def height(self)

def height(self):
    """
    Input height: the explicit ``_height`` when truthy, otherwise dimension 1
    of ``inputs.x``.

    :raises ValidationException: when no explicit value is set and there are no inputs.
    """
    explicit = self._height
    if not explicit:
        if self.inputs == None:
            raise ValidationException("gan.height() requested but no inputs provided")
        return self.ops.shape(self.inputs.x)[1]
    return explicit

def layer_regularizer(self, net)

def layer_regularizer(self, net):
    """
    Apply the op named by ``config.layer_regularizer`` to ``net``.

    The op is resolved via ``gan.ops.lookup`` and called as ``op(self, net)``;
    when the lookup yields a falsy value, ``net`` is returned unchanged.
    """
    regularizer = self.gan.ops.lookup(self.config.layer_regularizer)
    if not regularizer:
        return net
    return regularizer(self, net)

def load(self, save_file)

def load(self, save_file):
    """
    Restore the session's variables from ``save_file`` when a checkpoint is present.

    :param save_file: checkpoint path; ``~`` is expanded. The file is considered
        present when either the path itself or ``<path>.index`` exists.
    :returns: True when variables were restored. False when the file is missing
        or its directory has no checkpoint state.
    """
    save_file = os.path.expanduser(save_file)
    if not (os.path.isfile(save_file) or os.path.isfile(save_file + ".index" )):
        return False
    print("[hypergan] |= Loading network from "+ save_file)
    # save_file is already expanded, so its directory needs no further expansion.
    checkpoint_dir = os.path.dirname(save_file)
    print("[hypergan] |= Loading checkpoint from "+ checkpoint_dir)
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if not (ckpt and ckpt.model_checkpoint_path):
        return False
    saver = tf.train.Saver()
    saver.restore(self.session, save_file)
    return True

def permute(self, nets, k)

def permute(self, nets, k):
    """
    Return every ordered ``k``-length arrangement of ``nets``.

    :returns: list of tuples in ``itertools.permutations`` order.
    """
    # Local import: this module does not import itertools at top level, so the
    # bare name would otherwise raise NameError.
    import itertools
    return list(itertools.permutations(nets, k))

def relation_layer(self, net)

def relation_layer(self, net):
    """
    Combine every ordered pair of spatial positions of ``net``, then project back.

    ``net`` is split into per-position slices. Each ordered pair of slices goes
    through ``fully_connected_from_list``, and the results are concatenated on
    axis 3, flattened per example, linearly projected to the original per-example
    size, and reshaped to ``net``'s shape.

    Assumes ``net`` is rank 4 (indexes shape[1..3]).
    """
    ops = self.ops
    original_shape = ops.shape(net)
    # HACK: per-example element count, used as the output width of the projection.
    flat_size = original_shape[1] * original_shape[2] * original_shape[3]
    position_pairs = self.permute(self.split_by_width_height(net), 2)
    related = ops.concat(self.fully_connected_from_list(position_pairs), axis=3)
    flat = ops.reshape(related, [ops.shape(related)[0], -1])
    projected = ops.linear(flat, flat_size)
    return ops.reshape(projected, original_shape)

def required(self)

Returns a list of required config keys. validate() reports an error message for each of these keys that is missing from the config.

Example:

    class MyComponent(GANComponent):
        def required(self):
            "learn rate is required"
            return ["learn_rate"]

def required(self):
    """Config keys that validate() reports as errors when missing."""
    return ["generator", "discriminator", "z_discriminator", "g_encoder"]

def reuse(self, net)

def reuse(self, net):
    """
    Build this component on ``net`` again with the ops in reuse mode.

    :returns: the result of ``self.build(net)``.

    Reuse mode is always switched off afterwards, even if ``build`` raises.
    Otherwise a failed build would leave the ops in reuse mode.
    """
    self.ops.reuse()
    try:
        return self.build(net)
    finally:
        self.ops.stop_reuse()

def save(self, save_file)

def save(self, save_file):
    """
    Save the session's variables to ``save_file``, creating its directory if needed.

    :param save_file: checkpoint path; ``~`` is expanded before use, so the
        directory that is created and the path handed to the saver agree.
    """
    print("[hypergan] Saving network to ", save_file)
    save_file = os.path.expanduser(save_file)
    save_dir = os.path.dirname(save_file)
    # A bare filename has no directory component, and os.makedirs('') would raise.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    saver = tf.train.Saver()
    saver.save(self.session, save_file)

def split_batch(self, net, count=2)

Discriminators return results stacked on axis 0. This flattens each example and splits the batch into count equal chunks; with the default count=2 it returns [d_real, d_fake].

def split_batch(self, net, count=2):
    """ 
    Split a discriminator's axis-0-stacked output into ``count`` chunks.

    ``net`` is first flattened to ``[batch, -1]``. Chunk ``i`` covers rows
    ``i*(batch//count)`` to ``(i+1)*(batch//count)``. With the default
    ``count=2`` this yields ``[d_real, d_fake]``.
    """
    ops = self.ops or self.gan.ops
    batch = ops.shape(net)[0]
    flat = ops.reshape(net, [batch, -1])
    rank = len(ops.shape(flat))
    chunk = batch // count
    pieces = []
    for index in range(count):
        offset = [index * chunk] + [0] * (rank - 1)
        extent = [chunk] + list(ops.shape(flat)[1:])
        pieces.append(ops.slice(flat, offset, extent))
    return pieces

def split_by_width_height(self, net)

def split_by_width_height(self, net):
    """
    Split ``net`` into one ``[batch, 1, 1, all-channels]`` slice per spatial position.

    Axis 1 is treated as height and axis 2 as width, matching ``height()`` and
    ``width()``. Slices are returned row by row: every column of row 0, then row 1, and so on.

    :param net: rank-4 tensor, assumed ``[batch, height, width, channels]``.
    :returns: list of ``height * width`` slices.
    """
    ops = self.gan.ops
    shape = ops.shape(net)
    bs = shape[0]
    height = shape[1]
    width = shape[2]
    # The row index must range over height (axis 1) and the column index over
    # width (axis 2). Swapping the ranges mis-slices non-square inputs.
    return [
        ops.slice(net, [0, row, col, 0], [bs, 1, 1, -1])
        for row in range(height)
        for col in range(width)
    ]

def step(

self, feed_dict={})

def step(self, feed_dict=None):
    """
    Run one training step via ``self.trainer.step``.

    :param feed_dict: extra feed values to forward. When omitted, a fresh empty
        dict is passed, so no dict object is shared across calls.
    :returns: whatever ``self.trainer.step`` returns.
    """
    if feed_dict is None:
        feed_dict = {}
    return self.trainer.step(feed_dict)

def validate(self)

Validates a GANComponent. Returns a list of error messages; an empty list [] means success.

def validate(self):
    """
    Validates a GANComponent.  Return an array of error messages. Empty array `[]` means success.

    A message is added for every key from ``required()`` whose config value is
    None, and one more when the component has no GAN.
    """
    errors = []
    for argument in self.required():
        # Compare by identity: `== None` misbehaves for values that overload
        # equality (e.g. arrays). The config is presumably hc.Config-like,
        # returning None for missing keys via __getattr__; confirm.
        if self.config.__getattr__(argument) is None:
            errors.append("`"+argument+"` required")
    if self.gan is None:
        errors.append("GANComponent constructed without GAN")
    return errors

def variables(self)

All variables associated with this component.

def variables(self):
    """
        Variables reported by this component's ops object (``ops.variables()``).
    """
    component_ops = self.ops
    return component_ops.variables()

def weights(self)

The weights of the GAN component.

def weights(self):
    """
        The ``weights`` collection exposed by this component's ops object.
    """
    component_ops = self.ops
    return component_ops.weights

def width(self)

def width(self):
    """
    Input width: the explicit ``_width`` when truthy, otherwise dimension 2
    of ``inputs.x``.

    :raises ValidationException: when no explicit value is set and there are no inputs.
    """
    explicit = self._width
    if not explicit:
        if self.inputs == None:
            raise ValidationException("gan.width() requested but no inputs provided")
        return self.ops.shape(self.inputs.x)[2]
    return explicit

Instance variables

var discriminator

var encoder

var generator

var loss

var session

var trainer