Top

hypergan.cli module

import os
import hyperchamber as hc
import tensorflow as tf
from hypergan.gan_component import ValidationException
from . import GAN
from .inputs import *
from .viewer import GlobalViewer
from .configuration import Configuration
import hypergan as hg
import time

import fcntl
import os
import shutil
import sys

from hypergan.samplers.static_batch_sampler import StaticBatchSampler
from hypergan.samplers.batch_sampler import BatchSampler
from hypergan.samplers.grid_sampler import GridSampler
from hypergan.samplers.began_sampler import BeganSampler
from hypergan.samplers.aligned_sampler import AlignedSampler
from hypergan.samplers.autoencode_sampler import AutoencodeSampler
from hypergan.samplers.random_walk_sampler import RandomWalkSampler
from hypergan.samplers.alphagan_random_walk_sampler import AlphaganRandomWalkSampler

from hypergan.losses.supervised_loss import SupervisedLoss
from hypergan.multi_component import MultiComponent
from time import sleep

class CLI:
    """Command-line driver for a GAN: train it, sample from it, export its
    graph, or scaffold a new configuration file.

    The action performed by :meth:`run` is selected by ``args.method``.
    """

    def __init__(self, gan, args={}):
        """Store the GAN, read options from ``args`` and pick a sampler.

        Args:
            gan: GAN object to drive.  It is handed to the sampler and used
                by the other methods of this class.
            args: Mapping of options; it is wrapped in ``hc.Config``.
                Options read here are ``config``, ``method``, ``steps``,
                ``sample_every``, ``sampler``, ``save_file`` and ``viewer``.
        """
        self.samples = 0
        self.steps = 0
        self.gan = gan

        args = hc.Config(args)
        self.args = args

        # Falsy option values fall back to these defaults.  A step count of
        # -1 means "no limit" (see train()).
        self.config_name = self.args.config or 'default'
        self.method = args.method or 'test'
        self.total_steps = args.steps or -1
        self.sample_every = self.args.sample_every or 100

        self.sampler = CLI.sampler_for(args.sampler)(self.gan)

        self.validate()
        if self.args.save_file:
            self.save_file = self.args.save_file
        else:
            # No explicit checkpoint path: use saves/<config_name>/model.ckpt
            # and make sure its directory exists.
            default_save_path = os.path.abspath("saves/"+self.config_name)
            self.save_file = default_save_path + "/model.ckpt"
            self.create_path(self.save_file)

        title = "[hypergan] " + self.config_name
        GlobalViewer.title = title
        GlobalViewer.enabled = self.args.viewer

    @staticmethod
    def sampler_for(name):
        """Return the sampler class registered under ``name``.

        Unknown names (including ``None``) print a notice and fall back to
        ``StaticBatchSampler``.  Declared as a static method so it can be
        called on the class or on an instance.
        """
        samplers = {
                'static_batch': StaticBatchSampler,
                'random_walk': RandomWalkSampler,
                'alphagan_random_walk': AlphaganRandomWalkSampler,
                'batch': BatchSampler,
                'grid': GridSampler,
                'began': BeganSampler,
                'autoencode': AutoencodeSampler,
                'aligned': AlignedSampler
        }
        if name in samplers:
            return samplers[name]
        else:
            print("[hypergan] No sampler found for ", name, ".  Defaulting to StaticBatch")
            return StaticBatchSampler

    def sample(self, sample_file):
        """ Samples to a file.  Useful for visualizing the learning process.

        Use with:

             ffmpeg -i samples/grid-%06d.png -vcodec libx264 -crf 22 -threads 0 grid1-7.mp4

        to create a video of the learning process.

        Returns whatever the sampler's ``sample`` call returns.
        """
        return self.sampler.sample(sample_file, self.args.save_samples)

    def validate(self):
        """Raise ``ValidationException`` if no sampler was constructed."""
        if self.sampler is None:
            # The requested name lives on the parsed args; this class never
            # sets a ``sampler_name`` attribute.
            raise ValidationException("No sampler found by the name '"+str(self.args.sampler)+"'")

    def step(self):
        """Run one GAN step and, every ``sample_every`` steps, write a sample.

        Sample files are named ``samples/NNNNNN.png`` using the running
        ``self.samples`` counter.  When ``args.use_hc_io`` is set, the sample
        list is also passed to ``hc.io.sample``.
        """
        self.gan.step()

        if(self.steps % self.sample_every == 0):
            sample_file="samples/%06d.png" % (self.samples)
            self.create_path(sample_file)
            sample_list = self.sample(sample_file)
            if self.args.use_hc_io:
                self.gan.config['model'] = self.args.config
                hc.io.sample(self.gan.config, sample_list)

            self.samples += 1

        self.steps+=1

    def create_path(self, filename):
        """Create the directory that will contain ``filename`` if missing.

        A leading ``~`` is expanded.  A bare filename with no directory
        component needs no directory, so nothing is created for it
        (``os.makedirs('')`` would raise ``FileNotFoundError``).
        """
        directory = os.path.expanduser(os.path.dirname(filename))
        if not directory:
            return None
        return os.makedirs(directory, exist_ok=True)

    def build(self):
        """Write the session graph to ``builds/<config_name>.pbgraph``."""
        # config_name already carries the 'default' fallback, so this also
        # works when no config option was supplied.
        save_file = self.config_name+".pbgraph"
        build_file = os.path.expanduser("builds/"+save_file)
        self.create_path(build_file)
        tf.train.write_graph(self.gan.session.graph, 'builds', save_file)

        print("Saved generator to ", build_file)

    def serve(self, gan):
        """Hand the GAN session and configuration to ``gan_server``.

        The ``gan`` parameter is not used; ``self.gan`` is.
        """
        # NOTE(review): ``gan_server`` is not defined or imported anywhere in
        # the visible source, so this call fails until it is provided.
        return gan_server(self.gan.session, self.gan.config)

    def sample_forever(self):
        """Write numbered samples in an endless loop, pausing 0.2s between them."""
        while True:
            sample_file="samples/%06d.png" % (self.samples)
            self.create_path(sample_file)
            self.sample(sample_file)
            self.samples += 1
            print("Sample", self.samples)
            sleep(0.2)

    def train(self):
        """Call :meth:`step` until ``total_steps`` is reached.

        A ``total_steps`` of -1 loops without limit.  Every
        ``args.save_every`` iterations (when that option is a positive
        number) the GAN is saved to ``self.save_file``.  When
        ``args.ipython`` is set, stdin is switched to non-blocking mode and
        polled after every iteration via :meth:`check_stdin`.
        """
        if self.args.ipython:
            stdin_fd = sys.stdin.fileno()
            flags = fcntl.fcntl(stdin_fd, fcntl.F_GETFL)
            fcntl.fcntl(stdin_fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)

        i = 0
        while i < self.total_steps or self.total_steps == -1:
            i += 1
            self.step()

            # Re-read the option every iteration so a change made from the
            # interactive shell takes effect.
            save_every = self.args.save_every
            if save_every is not None and save_every > 0 and i % save_every == 0:
                print(" |= Saving network")
                self.gan.save(self.save_file)
            if self.args.ipython:
                self.check_stdin()

    def check_stdin(self):
        """Open an IPython shell if text is waiting on stdin.

        Nothing happens when stdin cannot be read, when it is empty, or when
        the text starts with ``"y"``.
        """
        try:
            pending = sys.stdin.read()
        except Exception:
            # Best effort: a stdin with nothing available or that cannot be
            # read simply means there is no request to handle.
            return
        if not pending or pending[0] == "y":
            return
        try:
            from IPython import embed
        except ImportError:
            print("[hypergan] IPython is not installed; cannot open an interactive shell")
            return
        embed()

    def new(self):
        """Copy the ``<config_name>.json`` configuration to ``<args.directory>.json``.

        Raises:
            ValidationException: if the destination file already exists.
        """
        template = self.args.directory + '.json'
        print("[hypergan] Creating new configuration file '"+template+"' based off of '"+self.config_name+".json'")
        if os.path.isfile(template):
            raise ValidationException("File exists: " + template)
        source_configuration = Configuration.find(self.config_name+".json")
        shutil.copyfile(source_configuration, template)

    def add_supervised_loss(self):
        """When ``args.classloss`` is set, combine a ``SupervisedLoss`` with the GAN loss."""
        if self.args.classloss:
            print("[discriminator] Class loss is on.  Semi-supervised learning mode activated.")
            supervised_loss = SupervisedLoss(self.gan, self.gan.config.loss)
            self.gan.loss = MultiComponent(components=[supervised_loss, self.gan.loss], combine='add')
            supervised_loss.create()
        else:
            print("[discriminator] Class loss is off.  Unsupervised learning mode activated.")

    def run(self):
        """Execute the action named by ``self.method``.

        ``'train'``, ``'build'``, ``'new'`` and ``'sample'`` are handled; any
        other value does nothing.

        Raises:
            ValidationException: for ``'build'`` when no saved model could
                be loaded from ``self.save_file``.
        """
        if self.method == 'train':
            self.gan.create()
            self.add_supervised_loss()
            self.gan.session.run(tf.global_variables_initializer())

            if not self.gan.load(self.save_file):
                print("Initializing new model")
            else:
                print("Model loaded")
            tf.train.start_queue_runners(sess=self.gan.session)
            self.train()
            tf.reset_default_graph()
            self.gan.session.close()
        elif self.method == 'build':
            self.gan.create()
            if not self.gan.load(self.save_file):
                # Exceptions must be exception instances, not strings, and
                # the path is an attribute rather than a local.
                raise ValidationException("Could not load model: " + self.save_file)
            else:
                print("Model loaded")
            self.build()
            tf.reset_default_graph()
            self.gan.session.close()
        elif self.method == 'new':
            self.new()
        elif self.method == 'sample':
            self.gan.create()
            self.add_supervised_loss()
            if not self.gan.load(self.save_file):
                print("Initializing new model")
            else:
                print("Model loaded")

            tf.train.start_queue_runners(sess=self.gan.session)
            self.sample_forever()
            tf.reset_default_graph()
            self.gan.session.close()

Classes

class CLI

class CLI:
    """Command-line driver for a GAN: train it, sample from it, export its
    graph, or scaffold a new configuration file.

    The action performed by :meth:`run` is selected by ``args.method``.
    """

    def __init__(self, gan, args={}):
        """Store the GAN, read options from ``args`` and pick a sampler.

        Args:
            gan: GAN object to drive.  It is handed to the sampler and used
                by the other methods of this class.
            args: Mapping of options; it is wrapped in ``hc.Config``.
                Options read here are ``config``, ``method``, ``steps``,
                ``sample_every``, ``sampler``, ``save_file`` and ``viewer``.
        """
        self.samples = 0
        self.steps = 0
        self.gan = gan

        args = hc.Config(args)
        self.args = args

        # Falsy option values fall back to these defaults.  A step count of
        # -1 means "no limit" (see train()).
        self.config_name = self.args.config or 'default'
        self.method = args.method or 'test'
        self.total_steps = args.steps or -1
        self.sample_every = self.args.sample_every or 100

        self.sampler = CLI.sampler_for(args.sampler)(self.gan)

        self.validate()
        if self.args.save_file:
            self.save_file = self.args.save_file
        else:
            # No explicit checkpoint path: use saves/<config_name>/model.ckpt
            # and make sure its directory exists.
            default_save_path = os.path.abspath("saves/"+self.config_name)
            self.save_file = default_save_path + "/model.ckpt"
            self.create_path(self.save_file)

        title = "[hypergan] " + self.config_name
        GlobalViewer.title = title
        GlobalViewer.enabled = self.args.viewer

    @staticmethod
    def sampler_for(name):
        """Return the sampler class registered under ``name``.

        Unknown names (including ``None``) print a notice and fall back to
        ``StaticBatchSampler``.  Declared as a static method so it can be
        called on the class or on an instance.
        """
        samplers = {
                'static_batch': StaticBatchSampler,
                'random_walk': RandomWalkSampler,
                'alphagan_random_walk': AlphaganRandomWalkSampler,
                'batch': BatchSampler,
                'grid': GridSampler,
                'began': BeganSampler,
                'autoencode': AutoencodeSampler,
                'aligned': AlignedSampler
        }
        if name in samplers:
            return samplers[name]
        else:
            print("[hypergan] No sampler found for ", name, ".  Defaulting to StaticBatch")
            return StaticBatchSampler

    def sample(self, sample_file):
        """ Samples to a file.  Useful for visualizing the learning process.

        Use with:

             ffmpeg -i samples/grid-%06d.png -vcodec libx264 -crf 22 -threads 0 grid1-7.mp4

        to create a video of the learning process.

        Returns whatever the sampler's ``sample`` call returns.
        """
        return self.sampler.sample(sample_file, self.args.save_samples)

    def validate(self):
        """Raise ``ValidationException`` if no sampler was constructed."""
        if self.sampler is None:
            # The requested name lives on the parsed args; this class never
            # sets a ``sampler_name`` attribute.
            raise ValidationException("No sampler found by the name '"+str(self.args.sampler)+"'")

    def step(self):
        """Run one GAN step and, every ``sample_every`` steps, write a sample.

        Sample files are named ``samples/NNNNNN.png`` using the running
        ``self.samples`` counter.  When ``args.use_hc_io`` is set, the sample
        list is also passed to ``hc.io.sample``.
        """
        self.gan.step()

        if(self.steps % self.sample_every == 0):
            sample_file="samples/%06d.png" % (self.samples)
            self.create_path(sample_file)
            sample_list = self.sample(sample_file)
            if self.args.use_hc_io:
                self.gan.config['model'] = self.args.config
                hc.io.sample(self.gan.config, sample_list)

            self.samples += 1

        self.steps+=1

    def create_path(self, filename):
        """Create the directory that will contain ``filename`` if missing.

        A leading ``~`` is expanded.  A bare filename with no directory
        component needs no directory, so nothing is created for it
        (``os.makedirs('')`` would raise ``FileNotFoundError``).
        """
        directory = os.path.expanduser(os.path.dirname(filename))
        if not directory:
            return None
        return os.makedirs(directory, exist_ok=True)

    def build(self):
        """Write the session graph to ``builds/<config_name>.pbgraph``."""
        # config_name already carries the 'default' fallback, so this also
        # works when no config option was supplied.
        save_file = self.config_name+".pbgraph"
        build_file = os.path.expanduser("builds/"+save_file)
        self.create_path(build_file)
        tf.train.write_graph(self.gan.session.graph, 'builds', save_file)

        print("Saved generator to ", build_file)

    def serve(self, gan):
        """Hand the GAN session and configuration to ``gan_server``.

        The ``gan`` parameter is not used; ``self.gan`` is.
        """
        # NOTE(review): ``gan_server`` is not defined or imported anywhere in
        # the visible source, so this call fails until it is provided.
        return gan_server(self.gan.session, self.gan.config)

    def sample_forever(self):
        """Write numbered samples in an endless loop, pausing 0.2s between them."""
        while True:
            sample_file="samples/%06d.png" % (self.samples)
            self.create_path(sample_file)
            self.sample(sample_file)
            self.samples += 1
            print("Sample", self.samples)
            sleep(0.2)

    def train(self):
        """Call :meth:`step` until ``total_steps`` is reached.

        A ``total_steps`` of -1 loops without limit.  Every
        ``args.save_every`` iterations (when that option is a positive
        number) the GAN is saved to ``self.save_file``.  When
        ``args.ipython`` is set, stdin is switched to non-blocking mode and
        polled after every iteration via :meth:`check_stdin`.
        """
        if self.args.ipython:
            stdin_fd = sys.stdin.fileno()
            flags = fcntl.fcntl(stdin_fd, fcntl.F_GETFL)
            fcntl.fcntl(stdin_fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)

        i = 0
        while i < self.total_steps or self.total_steps == -1:
            i += 1
            self.step()

            # Re-read the option every iteration so a change made from the
            # interactive shell takes effect.
            save_every = self.args.save_every
            if save_every is not None and save_every > 0 and i % save_every == 0:
                print(" |= Saving network")
                self.gan.save(self.save_file)
            if self.args.ipython:
                self.check_stdin()

    def check_stdin(self):
        """Open an IPython shell if text is waiting on stdin.

        Nothing happens when stdin cannot be read, when it is empty, or when
        the text starts with ``"y"``.
        """
        try:
            pending = sys.stdin.read()
        except Exception:
            # Best effort: a stdin with nothing available or that cannot be
            # read simply means there is no request to handle.
            return
        if not pending or pending[0] == "y":
            return
        try:
            from IPython import embed
        except ImportError:
            print("[hypergan] IPython is not installed; cannot open an interactive shell")
            return
        embed()

    def new(self):
        """Copy the ``<config_name>.json`` configuration to ``<args.directory>.json``.

        Raises:
            ValidationException: if the destination file already exists.
        """
        template = self.args.directory + '.json'
        print("[hypergan] Creating new configuration file '"+template+"' based off of '"+self.config_name+".json'")
        if os.path.isfile(template):
            raise ValidationException("File exists: " + template)
        source_configuration = Configuration.find(self.config_name+".json")
        shutil.copyfile(source_configuration, template)

    def add_supervised_loss(self):
        """When ``args.classloss`` is set, combine a ``SupervisedLoss`` with the GAN loss."""
        if self.args.classloss:
            print("[discriminator] Class loss is on.  Semi-supervised learning mode activated.")
            supervised_loss = SupervisedLoss(self.gan, self.gan.config.loss)
            self.gan.loss = MultiComponent(components=[supervised_loss, self.gan.loss], combine='add')
            supervised_loss.create()
        else:
            print("[discriminator] Class loss is off.  Unsupervised learning mode activated.")

    def run(self):
        """Execute the action named by ``self.method``.

        ``'train'``, ``'build'``, ``'new'`` and ``'sample'`` are handled; any
        other value does nothing.

        Raises:
            ValidationException: for ``'build'`` when no saved model could
                be loaded from ``self.save_file``.
        """
        if self.method == 'train':
            self.gan.create()
            self.add_supervised_loss()
            self.gan.session.run(tf.global_variables_initializer())

            if not self.gan.load(self.save_file):
                print("Initializing new model")
            else:
                print("Model loaded")
            tf.train.start_queue_runners(sess=self.gan.session)
            self.train()
            tf.reset_default_graph()
            self.gan.session.close()
        elif self.method == 'build':
            self.gan.create()
            if not self.gan.load(self.save_file):
                # Exceptions must be exception instances, not strings, and
                # the path is an attribute rather than a local.
                raise ValidationException("Could not load model: " + self.save_file)
            else:
                print("Model loaded")
            self.build()
            tf.reset_default_graph()
            self.gan.session.close()
        elif self.method == 'new':
            self.new()
        elif self.method == 'sample':
            self.gan.create()
            self.add_supervised_loss()
            if not self.gan.load(self.save_file):
                print("Initializing new model")
            else:
                print("Model loaded")

            tf.train.start_queue_runners(sess=self.gan.session)
            self.sample_forever()
            tf.reset_default_graph()
            self.gan.session.close()

Ancestors (in MRO)

  • CLI
  • builtins.object

Methods

def __init__(

self, gan, args={})

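Initialize the CLI: wrap args in hc.Config, resolve the configuration name, method, step limit, sampling interval and sampler, choose the checkpoint path, and configure GlobalViewer.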

def __init__(self, gan, args={}):
    """Store the GAN, read options from ``args`` and pick a sampler.

    Args:
        gan: GAN object to drive; it is handed to the sampler class.
        args: Mapping of options; it is wrapped in ``hc.Config``.  Options
            read here are ``config``, ``method``, ``steps``,
            ``sample_every``, ``sampler``, ``save_file`` and ``viewer``.
    """
    self.samples = 0
    self.steps = 0
    self.gan = gan
    args = hc.Config(args)
    self.args = args
    # Falsy option values fall back to these defaults.
    self.config_name = self.args.config or 'default'
    self.method = args.method or 'test'
    self.total_steps = args.steps or -1
    self.sample_every = self.args.sample_every or 100
    self.sampler = CLI.sampler_for(args.sampler)(self.gan)
    self.validate()
    if self.args.save_file:
        self.save_file = self.args.save_file
    else:
        # No explicit checkpoint path: use saves/<config_name>/model.ckpt
        # and make sure its directory exists.
        default_save_path = os.path.abspath("saves/"+self.config_name)
        self.save_file = default_save_path + "/model.ckpt"
        self.create_path(self.save_file)
    title = "[hypergan] " + self.config_name
    GlobalViewer.title = title
    GlobalViewer.enabled = self.args.viewer

def add_supervised_loss(

self)

def add_supervised_loss(self):
    """Optionally combine a ``SupervisedLoss`` with the GAN's existing loss.

    When ``args.classloss`` is falsy only a notice is printed.  Otherwise a
    ``SupervisedLoss`` is built, ``gan.loss`` is replaced by a
    ``MultiComponent`` that adds it to the previous loss, and the new loss
    component is then created.
    """
    if not self.args.classloss:
        print("[discriminator] Class loss is off.  Unsupervised learning mode activated.")
        return
    print("[discriminator] Class loss is on.  Semi-supervised learning mode activated.")
    class_loss = SupervisedLoss(self.gan, self.gan.config.loss)
    self.gan.loss = MultiComponent(components=[class_loss, self.gan.loss], combine='add')
    # create() is called only after the combined loss has been installed.
    class_loss.create()

def build(

self)

def build(self):
    """Write the session graph to ``builds/<config_name>.pbgraph``."""
    # Use config_name rather than args.config: the latter may be unset, in
    # which case concatenating it with a string would raise TypeError.
    save_file = self.config_name+".pbgraph"
    build_file = os.path.expanduser("builds/"+save_file)
    self.create_path(build_file)
    tf.train.write_graph(self.gan.session.graph, 'builds', save_file)
    print("Saved generator to ", build_file)

def check_stdin(

self)

def check_stdin(self):
    """Open an IPython shell if text is waiting on stdin.

    Nothing happens when stdin cannot be read, when it is empty, or when the
    text starts with ``"y"``.
    """
    try:
        pending = sys.stdin.read()
    except Exception:
        # Best effort: a stdin with nothing available or that cannot be read
        # simply means there is no request to handle.
        return
    if not pending or pending[0] == "y":
        return
    try:
        from IPython import embed
    except ImportError:
        print("[hypergan] IPython is not installed; cannot open an interactive shell")
        return
    embed()

def create_path(

self, filename)

def create_path(self, filename):
    """Create the directory that will contain ``filename`` if missing.

    A leading ``~`` is expanded.  A bare filename with no directory component
    needs no directory, so nothing is created for it (``os.makedirs('')``
    would raise ``FileNotFoundError``).

    Returns:
        None.
    """
    directory = os.path.expanduser(os.path.dirname(filename))
    if not directory:
        return None
    return os.makedirs(directory, exist_ok=True)

def new(

self)

def new(self):
    """Copy the ``<config_name>.json`` configuration to ``<args.directory>.json``.

    The source file is located with ``Configuration.find``.

    Raises:
        ValidationException: if the destination file already exists.
    """
    destination = self.args.directory + '.json'
    print("[hypergan] Creating new configuration file '"+destination+"' based off of '"+self.config_name+".json'")
    if os.path.isfile(destination):
        raise ValidationException("File exists: " + destination)
    shutil.copyfile(Configuration.find(self.config_name+".json"), destination)

def run(

self)

def run(self):
    """Execute the action named by ``self.method``.

    ``'train'``, ``'build'``, ``'new'`` and ``'sample'`` are handled; any
    other value does nothing.

    Raises:
        ValidationException: for ``'build'`` when no saved model could be
            loaded from ``self.save_file``.
    """
    if self.method == 'train':
        self.gan.create()
        self.add_supervised_loss()
        self.gan.session.run(tf.global_variables_initializer())
        if not self.gan.load(self.save_file):
            print("Initializing new model")
        else:
            print("Model loaded")
        tf.train.start_queue_runners(sess=self.gan.session)
        self.train()
        tf.reset_default_graph()
        self.gan.session.close()
    elif self.method == 'build':
        self.gan.create()
        if not self.gan.load(self.save_file):
            # Exceptions must be exception instances, not strings, and the
            # path is an attribute rather than a local.
            raise ValidationException("Could not load model: " + self.save_file)
        else:
            print("Model loaded")
        self.build()
        tf.reset_default_graph()
        self.gan.session.close()
    elif self.method == 'new':
        self.new()
    elif self.method == 'sample':
        self.gan.create()
        self.add_supervised_loss()
        if not self.gan.load(self.save_file):
            print("Initializing new model")
        else:
            print("Model loaded")
        tf.train.start_queue_runners(sess=self.gan.session)
        self.sample_forever()
        tf.reset_default_graph()
        self.gan.session.close()

def sample(

self, sample_file)

Samples to a file. Useful for visualizing the learning process.

Use with:

 ffmpeg -i samples/grid-%06d.png -vcodec libx264 -crf 22 -threads 0 grid1-7.mp4

to create a video of the learning process.

def sample(self, sample_file):
    """Ask the configured sampler to produce a sample for ``sample_file``.

    The ``args.save_samples`` option is forwarded as the sampler's second
    argument, and the sampler's return value is returned unchanged.

    The numbered sample files can be turned into a video of the learning
    process, for example with:

         ffmpeg -i samples/grid-%06d.png -vcodec libx264 -crf 22 -threads 0 grid1-7.mp4
    """
    return self.sampler.sample(sample_file, self.args.save_samples)

def sample_forever(

self)

def sample_forever(self):
    """Write numbered samples in an endless loop.

    Each pass creates the target directory, samples to
    ``samples/NNNNNN.png`` (numbered by ``self.samples``), bumps the counter,
    reports it and pauses for 0.2 seconds.  The loop never exits on its own.
    """
    while True:
        target = "samples/%06d.png" % (self.samples)
        self.create_path(target)
        self.sample(target)
        self.samples = self.samples + 1
        print("Sample", self.samples)
        sleep(0.2)

def sampler_for(

name)

def sampler_for(name):
    """Return the sampler class registered under ``name``.

    Names that are not registered (including ``None``) print a notice and
    yield ``StaticBatchSampler``.
    """
    registry = {
        'static_batch': StaticBatchSampler,
        'random_walk': RandomWalkSampler,
        'alphagan_random_walk': AlphaganRandomWalkSampler,
        'batch': BatchSampler,
        'grid': GridSampler,
        'began': BeganSampler,
        'autoencode': AutoencodeSampler,
        'aligned': AlignedSampler,
    }
    chosen = registry.get(name)
    if chosen is None:
        print("[hypergan] No sampler found for ", name, ".  Defaulting to StaticBatch")
        chosen = StaticBatchSampler
    return chosen

def serve(

self, gan)

def serve(self, gan):
    """Hand the GAN session and configuration to ``gan_server``.

    The ``gan`` parameter is not used; ``self.gan`` is.
    """
    # The original passed a bare ``config`` name that is bound nowhere; the
    # GAN's own configuration is used instead.
    # NOTE(review): ``gan_server`` is not defined or imported anywhere in the
    # visible source, so this call fails until it is provided.
    return gan_server(self.gan.session, self.gan.config)

def step(

self)

def step(self):
    """Advance the GAN by one step, sampling every ``sample_every`` steps.

    On a sampling step the next ``samples/NNNNNN.png`` file is produced and,
    when ``args.use_hc_io`` is set, the sample list is also passed to
    ``hc.io.sample``.  ``self.steps`` is incremented on every call and
    ``self.samples`` on every sampling step.
    """
    self.gan.step()
    if self.steps % self.sample_every != 0:
        # Not a sampling step: just count it.
        self.steps += 1
        return
    target = "samples/%06d.png" % (self.samples)
    self.create_path(target)
    produced = self.sample(target)
    if self.args.use_hc_io:
        self.gan.config['model'] = self.args.config
        hc.io.sample(self.gan.config, produced)
    self.samples += 1
    self.steps += 1

def train(

self)

def train(self):
    """Call ``self.step()`` until ``total_steps`` is reached.

    A ``total_steps`` of -1 loops without limit.  Every ``args.save_every``
    iterations (when that option is a positive number) the GAN is saved to
    ``self.save_file``.  When ``args.ipython`` is set, stdin is switched to
    non-blocking mode and ``self.check_stdin()`` is called after every
    iteration.
    """
    if self.args.ipython:
        stdin_fd = sys.stdin.fileno()
        flags = fcntl.fcntl(stdin_fd, fcntl.F_GETFL)
        fcntl.fcntl(stdin_fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
    i = 0
    while i < self.total_steps or self.total_steps == -1:
        i += 1
        self.step()
        # Re-read the option every iteration so a change made from the
        # interactive shell takes effect.
        save_every = self.args.save_every
        if save_every is not None and save_every > 0 and i % save_every == 0:
            print(" |= Saving network")
            self.gan.save(self.save_file)
        if self.args.ipython:
            self.check_stdin()

def validate(

self)

def validate(self):
    """Raise ``ValidationException`` if no sampler was constructed.

    Raises:
        ValidationException: when ``self.sampler`` is ``None``; the message
            names the sampler requested in ``self.args.sampler``.
    """
    if self.sampler is None:
        # The requested name lives on the parsed args; no ``sampler_name``
        # attribute is ever set, so reading it would raise AttributeError
        # instead of the intended error.
        raise ValidationException("No sampler found by the name '"+str(self.args.sampler)+"'")

Instance variables

var args

var config_name

var gan

var method

var sample_every

var sampler

var samples

var steps

var total_steps