Top

hypergan.generators.resize_conv_generator module

import tensorflow as tf
import numpy as np
import hyperchamber as hc
from hypergan.generators.common import *

from .base_generator import BaseGenerator

class ResizeConvGenerator(BaseGenerator):
    """
    Generator that projects its input to a small spatial grid and then grows
    it stage by stage: each stage either resizes and applies ``block`` or,
    when ``config.block == 'deconv'``, applies ``ops.deconv2d``. The final
    stage targets ``gan.height() x gan.width()`` with ``gan.channels()``
    output channels.
    """

    def required(self):
        """Config keys that must be present for this generator to validate."""
        return "final_depth activation depth_increase".split()

    def depths(self, initial_width=4):
        """
        Return the channel depth of every upsampling stage.

        One entry is produced per doubling needed for ``initial_width`` to
        reach at least ``self.gan.width()``. The last entry equals
        ``config.final_depth`` and each step toward the front adds
        ``config.depth_increase``. Returns ``[]`` when ``initial_width``
        already reaches the target width.

        Raises:
            ValueError: if ``initial_width`` is non-positive but below the
                target width; doubling would then never terminate.
        """
        config = self.config
        base_depth = config.final_depth - config.depth_increase
        target_w = self.gan.width()
        if initial_width <= 0 and initial_width < target_w:
            raise ValueError(
                "initial_width must be positive to reach target width %r, got %r"
                % (target_w, initial_width)
            )

        depths = []
        w = initial_width
        i = 0
        while w < target_w:
            w *= 2
            i += 1
            depths.append(base_depth + i * config.depth_increase)
        # Built from the final stage outward; callers want it the other way round.
        depths.reverse()
        return depths

    def build(self, net):
        """
        Build the generator graph on ``net`` and return the generated sample.

        Steps:
          1. Initial projection. With ``config.skip_linear`` the input is
             passed through ``layer_filter`` (optionally concatenated with a
             linear branch) and ``ops.conv2d`` layers. Otherwise it is
             flattened, projected with ``ops.linear`` and reshaped to
             ``config.initial_dimensions`` (default ``[4, 4]``) with
             ``depths[0]`` channels.
          2. Optional relation layer (``config.relation_layer``).
          3. One stage per entry of ``self.depths(...)[1:]``: resize to twice
             the current size (capped at the GAN's height/width) then
             ``block``, or ``ops.deconv2d(net, 5, 5, 2, 2, depth)`` for
             ``block == 'deconv'``.
          4. A final stage producing ``gan.channels()`` channels, followed by
             ``final_activation`` when configured.

        The result is also stored on ``self.sample``.
        """
        gan = self.gan
        ops = self.ops
        config = self.config

        activation = ops.lookup(config.activation)
        final_activation = ops.lookup(config.final_activation)
        block = config.block or standard_block

        if config.skip_linear:
            net = self.layer_filter(net)
            if config.concat_linear:
                net_shape = ops.shape(net)
                bs, height, width = net_shape[0], net_shape[1], net_shape[2]
                size = height * width * config.concat_linear_filters
                net2 = tf.reshape(net, [bs, -1])
                net2 = tf.slice(net2, [0, 0], [bs, config.concat_linear])
                net2 = ops.linear(net2, size)
                net2 = tf.reshape(net2, [bs, height, width, config.concat_linear_filters])
                net2 = self.layer_regularizer(net2)
                # Use the op resolved by ops.lookup above, like every other layer;
                # the raw config entry is not guaranteed to be callable.
                net2 = activation(net2)
                net = tf.concat([net, net2], axis=3)
            reduction = config.extra_layers_reduction or 1
            net = ops.conv2d(net, 3, 3, 1, 1, ops.shape(net)[3] // reduction)
            for _ in range(config.extra_layers or 0):
                net = self.layer_regularizer(net)
                net = activation(net)
                net = ops.conv2d(net, 3, 3, 1, 1, ops.shape(net)[3] // reduction)
        else:
            net = ops.reshape(net, [ops.shape(net)[0], -1])
            primes = config.initial_dimensions or [4, 4]
            # NOTE(review): IndexError if primes[0] already reaches gan.width(),
            # because depths() then returns an empty list.
            initial_depth = self.depths(primes[0])[0]
            new_shape = [ops.shape(net)[0], primes[0], primes[1], initial_depth]
            net = ops.linear(net, initial_depth * primes[0] * primes[1])
            net = ops.reshape(net, new_shape)

        shape = ops.shape(net)
        depths = self.depths(initial_width=shape[1])
        print("[generator] Initial depth", shape)

        if config.relation_layer:
            net = self.layer_regularizer(net)
            net = activation(net)
            net = self.relation_layer(net)
            print("[generator] relational layer", net)

        net = self.layer_filter(net)
        for depth in depths[1:]:
            s = ops.shape(net)
            # Double the spatial size, but never beyond the output image size.
            resize = [min(s[1] * 2, gan.height()), min(s[2] * 2, gan.width())]
            net = self.layer_regularizer(net)
            net = activation(net)
            if block != 'deconv':
                net = ops.resize_images(net, resize, config.resize_image_type or 1)
                net = block(self, net, depth, filter=3)
            else:
                net = ops.deconv2d(net, 5, 5, 2, 2, depth)

            size = resize[0] * resize[1] * depth
            print("[generator] layer", net, size)

        net = self.layer_regularizer(net)
        net = activation(net)
        resize = [gan.height(), gan.width()]

        if block != 'deconv':
            net = ops.resize_images(net, resize, config.resize_image_type or 1)
            net = self.layer_filter(net)
            net = block(self, net, gan.channels(), filter=config.final_filter or 3)
        else:
            net = ops.deconv2d(net, 5, 5, 2, 2, gan.channels())

        if final_activation:
            net = self.layer_regularizer(net)
            net = final_activation(net)

        self.sample = net
        return self.sample

Classes

class ResizeConvGenerator

GANComponents are pluggable pieces within a GAN.

GAN objects are also GANComponents.

class ResizeConvGenerator(BaseGenerator):
    """
    Generator that projects its input to a small spatial grid and then grows
    it stage by stage: each stage either resizes and applies ``block`` or,
    when ``config.block == 'deconv'``, applies ``ops.deconv2d``. The final
    stage targets ``gan.height() x gan.width()`` with ``gan.channels()``
    output channels.
    """

    def required(self):
        """Config keys that must be present for this generator to validate."""
        return "final_depth activation depth_increase".split()

    def depths(self, initial_width=4):
        """
        Return the channel depth of every upsampling stage.

        One entry is produced per doubling needed for ``initial_width`` to
        reach at least ``self.gan.width()``. The last entry equals
        ``config.final_depth`` and each step toward the front adds
        ``config.depth_increase``. Returns ``[]`` when ``initial_width``
        already reaches the target width.

        Raises:
            ValueError: if ``initial_width`` is non-positive but below the
                target width; doubling would then never terminate.
        """
        config = self.config
        base_depth = config.final_depth - config.depth_increase
        target_w = self.gan.width()
        if initial_width <= 0 and initial_width < target_w:
            raise ValueError(
                "initial_width must be positive to reach target width %r, got %r"
                % (target_w, initial_width)
            )

        depths = []
        w = initial_width
        i = 0
        while w < target_w:
            w *= 2
            i += 1
            depths.append(base_depth + i * config.depth_increase)
        # Built from the final stage outward; callers want it the other way round.
        depths.reverse()
        return depths

    def build(self, net):
        """
        Build the generator graph on ``net`` and return the generated sample.

        Steps:
          1. Initial projection. With ``config.skip_linear`` the input is
             passed through ``layer_filter`` (optionally concatenated with a
             linear branch) and ``ops.conv2d`` layers. Otherwise it is
             flattened, projected with ``ops.linear`` and reshaped to
             ``config.initial_dimensions`` (default ``[4, 4]``) with
             ``depths[0]`` channels.
          2. Optional relation layer (``config.relation_layer``).
          3. One stage per entry of ``self.depths(...)[1:]``: resize to twice
             the current size (capped at the GAN's height/width) then
             ``block``, or ``ops.deconv2d(net, 5, 5, 2, 2, depth)`` for
             ``block == 'deconv'``.
          4. A final stage producing ``gan.channels()`` channels, followed by
             ``final_activation`` when configured.

        The result is also stored on ``self.sample``.
        """
        gan = self.gan
        ops = self.ops
        config = self.config

        activation = ops.lookup(config.activation)
        final_activation = ops.lookup(config.final_activation)
        block = config.block or standard_block

        if config.skip_linear:
            net = self.layer_filter(net)
            if config.concat_linear:
                net_shape = ops.shape(net)
                bs, height, width = net_shape[0], net_shape[1], net_shape[2]
                size = height * width * config.concat_linear_filters
                net2 = tf.reshape(net, [bs, -1])
                net2 = tf.slice(net2, [0, 0], [bs, config.concat_linear])
                net2 = ops.linear(net2, size)
                net2 = tf.reshape(net2, [bs, height, width, config.concat_linear_filters])
                net2 = self.layer_regularizer(net2)
                # Use the op resolved by ops.lookup above, like every other layer;
                # the raw config entry is not guaranteed to be callable.
                net2 = activation(net2)
                net = tf.concat([net, net2], axis=3)
            reduction = config.extra_layers_reduction or 1
            net = ops.conv2d(net, 3, 3, 1, 1, ops.shape(net)[3] // reduction)
            for _ in range(config.extra_layers or 0):
                net = self.layer_regularizer(net)
                net = activation(net)
                net = ops.conv2d(net, 3, 3, 1, 1, ops.shape(net)[3] // reduction)
        else:
            net = ops.reshape(net, [ops.shape(net)[0], -1])
            primes = config.initial_dimensions or [4, 4]
            # NOTE(review): IndexError if primes[0] already reaches gan.width(),
            # because depths() then returns an empty list.
            initial_depth = self.depths(primes[0])[0]
            new_shape = [ops.shape(net)[0], primes[0], primes[1], initial_depth]
            net = ops.linear(net, initial_depth * primes[0] * primes[1])
            net = ops.reshape(net, new_shape)

        shape = ops.shape(net)
        depths = self.depths(initial_width=shape[1])
        print("[generator] Initial depth", shape)

        if config.relation_layer:
            net = self.layer_regularizer(net)
            net = activation(net)
            net = self.relation_layer(net)
            print("[generator] relational layer", net)

        net = self.layer_filter(net)
        for depth in depths[1:]:
            s = ops.shape(net)
            # Double the spatial size, but never beyond the output image size.
            resize = [min(s[1] * 2, gan.height()), min(s[2] * 2, gan.width())]
            net = self.layer_regularizer(net)
            net = activation(net)
            if block != 'deconv':
                net = ops.resize_images(net, resize, config.resize_image_type or 1)
                net = block(self, net, depth, filter=3)
            else:
                net = ops.deconv2d(net, 5, 5, 2, 2, depth)

            size = resize[0] * resize[1] * depth
            print("[generator] layer", net, size)

        net = self.layer_regularizer(net)
        net = activation(net)
        resize = [gan.height(), gan.width()]

        if block != 'deconv':
            net = ops.resize_images(net, resize, config.resize_image_type or 1)
            net = self.layer_filter(net)
            net = block(self, net, gan.channels(), filter=config.final_filter or 3)
        else:
            net = ops.deconv2d(net, 5, 5, 2, 2, gan.channels())

        if final_activation:
            net = self.layer_regularizer(net)
            net = final_activation(net)

        self.sample = net
        return self.sample

Ancestors (in MRO)

  • ResizeConvGenerator
  • hypergan.generators.base_generator.BaseGenerator
  • hypergan.gan_component.GANComponent
  • builtins.object

Methods

def __init__(self, gan, config)

Initializes a GAN component based on a `gan` and a `config` dictionary.

Different components require different config variables.

A `ValidationException` is raised if the GAN component configuration fails to validate.

def __init__(self, gan, config):
    """
    Store ``gan``, wrap ``config`` in an ``hc.Config`` and validate it.

    Different components require different config variables (see
    ``required``). When validation reports any problem, a
    ``ValidationException`` listing all of them is raised before any ops
    are created; otherwise ``create_ops`` is called with the original
    ``config`` argument.
    """
    self.gan = gan
    self.config = hc.Config(config)
    problems = self.validate()
    if problems != []:
        message = self.__class__.__name__ + ": " + "\n".join(problems)
        raise ValidationException(message)
    self.create_ops(config)

def biases(self)

Biases of the GAN component.

def biases(self):
    """
        Biases of the GAN component, as exposed by ``self.ops.biases``.
    """
    backend = self.ops
    return backend.biases

def build(self, net)

def build(self, net):
    """
    Build the generator graph on ``net`` and return the generated sample.

    Steps:
      1. Initial projection. With ``config.skip_linear`` the input is passed
         through ``layer_filter`` (optionally concatenated with a linear
         branch) and ``ops.conv2d`` layers. Otherwise it is flattened,
         projected with ``ops.linear`` and reshaped to
         ``config.initial_dimensions`` (default ``[4, 4]``) with
         ``depths[0]`` channels.
      2. Optional relation layer (``config.relation_layer``).
      3. One stage per entry of ``self.depths(...)[1:]``: resize to twice the
         current size (capped at the GAN's height/width) then ``block``, or
         ``ops.deconv2d(net, 5, 5, 2, 2, depth)`` for ``block == 'deconv'``.
      4. A final stage producing ``gan.channels()`` channels, followed by
         ``final_activation`` when configured.

    The result is also stored on ``self.sample``.
    """
    gan = self.gan
    ops = self.ops
    config = self.config

    activation = ops.lookup(config.activation)
    final_activation = ops.lookup(config.final_activation)
    block = config.block or standard_block

    if config.skip_linear:
        net = self.layer_filter(net)
        if config.concat_linear:
            net_shape = ops.shape(net)
            bs, height, width = net_shape[0], net_shape[1], net_shape[2]
            size = height * width * config.concat_linear_filters
            net2 = tf.reshape(net, [bs, -1])
            net2 = tf.slice(net2, [0, 0], [bs, config.concat_linear])
            net2 = ops.linear(net2, size)
            net2 = tf.reshape(net2, [bs, height, width, config.concat_linear_filters])
            net2 = self.layer_regularizer(net2)
            # Use the op resolved by ops.lookup above, like every other layer;
            # the raw config entry is not guaranteed to be callable.
            net2 = activation(net2)
            net = tf.concat([net, net2], axis=3)
        reduction = config.extra_layers_reduction or 1
        net = ops.conv2d(net, 3, 3, 1, 1, ops.shape(net)[3] // reduction)
        for _ in range(config.extra_layers or 0):
            net = self.layer_regularizer(net)
            net = activation(net)
            net = ops.conv2d(net, 3, 3, 1, 1, ops.shape(net)[3] // reduction)
    else:
        net = ops.reshape(net, [ops.shape(net)[0], -1])
        primes = config.initial_dimensions or [4, 4]
        # NOTE(review): IndexError if primes[0] already reaches gan.width(),
        # because depths() then returns an empty list.
        initial_depth = self.depths(primes[0])[0]
        new_shape = [ops.shape(net)[0], primes[0], primes[1], initial_depth]
        net = ops.linear(net, initial_depth * primes[0] * primes[1])
        net = ops.reshape(net, new_shape)

    shape = ops.shape(net)
    depths = self.depths(initial_width=shape[1])
    print("[generator] Initial depth", shape)

    if config.relation_layer:
        net = self.layer_regularizer(net)
        net = activation(net)
        net = self.relation_layer(net)
        print("[generator] relational layer", net)

    net = self.layer_filter(net)
    for depth in depths[1:]:
        s = ops.shape(net)
        # Double the spatial size, but never beyond the output image size.
        resize = [min(s[1] * 2, gan.height()), min(s[2] * 2, gan.width())]
        net = self.layer_regularizer(net)
        net = activation(net)
        if block != 'deconv':
            net = ops.resize_images(net, resize, config.resize_image_type or 1)
            net = block(self, net, depth, filter=3)
        else:
            net = ops.deconv2d(net, 5, 5, 2, 2, depth)
        size = resize[0] * resize[1] * depth
        print("[generator] layer", net, size)

    net = self.layer_regularizer(net)
    net = activation(net)
    resize = [gan.height(), gan.width()]
    if block != 'deconv':
        net = ops.resize_images(net, resize, config.resize_image_type or 1)
        net = self.layer_filter(net)
        net = block(self, net, gan.channels(), filter=config.final_filter or 3)
    else:
        net = ops.deconv2d(net, 5, 5, 2, 2, gan.channels())

    if final_activation:
        net = self.layer_regularizer(net)
        net = final_activation(net)

    self.sample = net
    return self.sample

def create(self, sample=None)

def create(self, sample=None):
    """
    Build this component on ``sample`` and return the result of ``build``.

    When ``sample`` is ``None``, ``self.gan.encoder.sample`` is used instead.
    """
    if sample is None:
        sample = self.gan.encoder.sample
    return self.build(sample)

def create_ops(self, config)

Creates the ops object as `self.ops`, then replaces `self.config` with the result of `gan.ops.lookup(config)`.

def create_ops(self, config):
    """
    Create the ops object as ``self.ops`` and look up the config.

    Does nothing when there is no GAN or the GAN has no ``ops_backend``.
    Otherwise instantiates ``gan.ops_backend`` with the current
    ``self.config`` and ``gan.device``, then replaces ``self.config`` with
    ``gan.ops.lookup(config)``.
    """
    gan = self.gan
    if gan is None or gan.ops_backend is None:
        return
    self.ops = gan.ops_backend(config=self.config, device=gan.device)
    self.config = gan.ops.lookup(config)

def depths(self, initial_width=4)

def depths(self, initial_width=4):
    """
    Return the channel depth of every upsampling stage.

    One entry is produced per doubling needed for ``initial_width`` to reach
    at least ``self.gan.width()``. The last entry equals
    ``config.final_depth`` and each step toward the front adds
    ``config.depth_increase``. Returns ``[]`` when ``initial_width`` already
    reaches the target width.

    Raises:
        ValueError: if ``initial_width`` is non-positive but below the target
            width; doubling would then never terminate.
    """
    config = self.config
    base_depth = config.final_depth - config.depth_increase
    target_w = self.gan.width()
    if initial_width <= 0 and initial_width < target_w:
        raise ValueError(
            "initial_width must be positive to reach target width %r, got %r"
            % (target_w, initial_width)
        )

    depths = []
    w = initial_width
    i = 0
    while w < target_w:
        w *= 2
        i += 1
        depths.append(base_depth + i * config.depth_increase)
    # Built from the final stage outward; callers want it the other way round.
    depths.reverse()
    return depths

def fully_connected_from_list(self, nets)

def fully_connected_from_list(self, nets):
    """
    Fuse each ``(net, net2)`` pair in ``nets`` with a fully connected layer.

    For every pair: concatenate along axis 3, flatten to ``[batch, -1]``,
    apply ``ops.linear`` with the same number of features and then the
    ``'lrelu'`` op, and reshape back to the concatenated shape. Returns one
    result per pair, in input order.
    """
    ops = self.ops

    def _fuse(pair):
        # One pair -> concat, dense layer, lrelu, original (concatenated) shape.
        left, right = pair
        joined = ops.concat([left, right], axis=3)
        joined_shape = ops.shape(joined)
        flat = ops.reshape(joined, [joined_shape[0], -1])
        flat = ops.linear(flat, ops.shape(flat)[1])
        flat = ops.lookup('lrelu')(flat)
        return ops.reshape(flat, joined_shape)

    return [_fuse(pair) for pair in nets]

def layer_filter(self, net)

def layer_filter(self, net):
    """
    Optionally append extra channels produced by ``config.layer_filter``.

    When ``config.layer_filter`` is set it is called as
    ``layer_filter(gan, self.config, net)``; a non-``None`` result is
    concatenated onto ``net`` along axis 3. Otherwise ``net`` is returned
    unchanged.
    """
    ops, gan, config = self.ops, self.gan, self.config
    if not config.layer_filter:
        return net
    print("[base generator] applying layer filter", config['layer_filter'])
    extra = config.layer_filter(gan, self.config, net)
    if extra is not None:
        net = ops.concat(axis=3, values=[net, extra])
    return net

def layer_regularizer(self, net)

def layer_regularizer(self, net):
    """
    Apply the configured layer regularizer to ``net``, if any.

    ``config.layer_regularizer`` is resolved with ``gan.ops.lookup``. A truthy
    result is called as ``op(self, net)``; otherwise ``net`` is returned as-is.
    """
    symbol = self.config.layer_regularizer
    regularizer = self.gan.ops.lookup(symbol)
    return regularizer(self, net) if regularizer else net

def permute(self, nets, k)

def permute(self, nets, k):
    """
    Return every ordered ``k``-length arrangement of ``nets`` as a list of tuples.

    Thin wrapper over ``itertools.permutations``.
    """
    # Local import: the module's visible imports do not bring in itertools.
    import itertools
    return list(itertools.permutations(nets, k))

def relation_layer(self, net)

def relation_layer(self, net):
    """
    Relational layer over the spatial cells of ``net``.

    Splits ``net`` into per-cell slices (``split_by_width_height``), forms
    all ordered pairs (``permute(..., 2)``), fuses each pair with
    ``fully_connected_from_list``, and concatenates the results along axis 3.
    It then flattens per batch item and projects back with ``ops.linear`` to
    ``shape[1] * shape[2] * shape[3]`` features, reshaped to the input's shape.
    """
    ops = self.ops
    original_shape = ops.shape(net)
    input_size = original_shape[1] * original_shape[2] * original_shape[3]
    pairs = self.permute(self.split_by_width_height(net), 2)
    fused = ops.concat(self.fully_connected_from_list(pairs), axis=3)
    flat = ops.reshape(fused, [ops.shape(fused)[0], -1])
    projected = ops.linear(flat, input_size)
    return ops.reshape(projected, original_shape)

def required(self)

Returns a list of required config keys; a ValidationException is raised if any of them is missing.

Example:

    class MyComponent(GANComponent):
        def required(self):
            "learn rate is required"
            return ["learn_rate"]

def required(self):
    """Config keys that must be set for this generator to validate."""
    return ["final_depth", "activation", "depth_increase"]

def reuse(self, net)

def reuse(self, net):
    """
    Build the component on ``net`` with the ops backend in reuse mode.

    Calls ``self.ops.reuse()``, then ``self.build(net)``, and always calls
    ``self.ops.stop_reuse()`` afterwards — also when ``build`` raises — so
    the backend is never left stuck in reuse mode. Returns ``build``'s result.
    """
    self.ops.reuse()
    try:
        return self.build(net)
    finally:
        self.ops.stop_reuse()

def split_batch(self, net, count=2)

Discriminators return stacked results (on axis 0).

This splits the stacked results into `count` equal parts along axis 0; with the default `count=2` it returns `[d_real, d_fake]`.

def split_batch(self, net, count=2):
    """
    Discriminators return stacked results (on axis 0).

    Flattens ``net`` to ``[batch, -1]`` and slices it into ``count`` chunks of
    ``batch // count`` rows each, in order. With ``count=2`` this yields
    ``[d_real, d_fake]``. Uses ``self.ops`` when set, else ``self.gan.ops``.
    """
    ops = self.ops or self.gan.ops
    batch_size = ops.shape(net)[0]
    flat = ops.reshape(net, [batch_size, -1])
    step = batch_size // count
    offset = [0 for _ in ops.shape(flat)]
    pieces = []
    for _ in range(count):
        extent = [step] + list(ops.shape(flat)[1:])
        pieces.append(ops.slice(flat, offset, extent))
        offset[0] += step
    return pieces

def split_by_width_height(self, net)

def split_by_width_height(self, net):
    """
    Split ``net`` into one ``[bs, 1, 1, -1]`` slice per spatial cell.

    Axis 1 of ``ops.shape(net)`` is treated as height and axis 2 as width.
    Cells are returned row by row: all of axis-2 for axis-1 index 0, then
    index 1, and so on.
    """
    ops = self.gan.ops
    shape = ops.shape(net)
    bs, height, width = shape[0], shape[1], shape[2]
    elems = []
    # Each loop range must match the axis its index slices; otherwise
    # non-square inputs would slice out of bounds or skip cells.
    for row in range(height):
        for col in range(width):
            elems.append(ops.slice(net, [0, row, col, 0], [bs, 1, 1, -1]))
    return elems

def validate(self)

Validates a GANComponent. Returns an array of error messages; an empty array [] means success.

def validate(self):
    """
    Validates a GANComponent.  Return an array of error messages. Empty array `[]` means success.

    Every key from ``self.required()`` whose config value is ``None`` adds a
    "`key` required" message; a missing ``gan`` adds one more.
    """
    errors = []
    for argument in self.required():
        # `is None`, not `== None`: config values may be arrays/tensors whose
        # `==` is elementwise and has no single truth value.
        if self.config.__getattr__(argument) is None:
            errors.append("`" + argument + "` required")
    if self.gan is None:
        errors.append("GANComponent constructed without GAN")
    return errors

def variables(self)

All variables associated with this component.

def variables(self):
    """
        All variables associated with this component, as returned by
        ``self.ops.variables()``.
    """
    collect = self.ops.variables
    return collect()

def weights(self)

The weights of the GAN component.

def weights(self):
    """
        The weights of the GAN component, read from ``self.ops.weights``.
    """
    backend = self.ops
    return backend.weights