Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

44

45

46

47

48

49

50

51

52

53

54

55

56

57

58

59

60

61

62

63

64

65

66

67

68

69

70

71

72

73

74

75

76

77

78

79

80

81

82

83

84

85

86

87

88

89

90

91

92

93

94

95

96

import importlib 

import json 

import numpy as np 

import os 

import sys 

import time 

import uuid 

import copy 

 

from hypergan.discriminators import * 

from hypergan.encoders import * 

from hypergan.generators import * 

from hypergan.inputs import * 

from hypergan.samplers import * 

from hypergan.trainers import * 

 

import hyperchamber as hc 

from hyperchamber import Config 

from hypergan.ops import TensorflowOps 

import tensorflow as tf 

import hypergan as hg 

from hypergan.skip_connections import SkipConnections 

 

from hypergan.gan_component import ValidationException, GANComponent 

from .base_gan import BaseGAN 

 

class StandardGAN(BaseGAN):
    """
    Standard GANs consist of:

    *required to sample*

    * encoder
    * generator
    * sampler

    *required to train*

    * discriminator
    * loss
    * trainer

    Of these, ``required()`` lists only the generator. Each other component
    is built in ``create()`` only when ``self.config`` has an entry for it.
    No sampler component is constructed by this class; ``create()`` only
    exposes ``generator.sample`` as ``self.uniform_sample``.
    """

    def __init__(self, *args, **kwargs):
        """Initialise the base GAN and leave every component slot empty.

        All positional and keyword arguments are forwarded unchanged to
        ``BaseGAN.__init__``. The component slots are filled by
        ``create()``. A slot assigned before ``create()`` runs is left as-is.
        """
        BaseGAN.__init__(self, *args, **kwargs)
        self.discriminator = None
        self.encoder = None
        self.generator = None
        self.loss = None
        self.trainer = None
        self.session = None
        self.skip_connections = SkipConnections()

    def required(self):
        """Return the list of required component names: just the generator.

        NOTE(review): presumably consumed by BaseGAN's config validation;
        confirm.
        """
        return ["generator"]

    def create(self):
        """Build the TF session and every configured component.

        Components are built in a fixed order: encoder, generator,
        discriminator, loss, trainer. NOTE(review): the order presumably
        matters because later components read graph pieces from earlier
        ones; confirm against the component implementations.

        A slot is filled only when it is still ``None`` and ``self.config``
        has a truthy entry for it. The new component's own ``create()`` is
        then called if it defines one. Finally, all global variables are
        initialised in ``self.session``.
        """
        BaseGAN.create(self)
        config = self.config

        def create_if(obj):
            # create() is optional on components, so only call it if present.
            if hasattr(obj, 'create'):
                obj.create()

        with tf.device(self.device):
            if self.session is None:
                self.session = self.ops.new_session(self.ops_config)

            # This is in a specific order: do not reorder the blocks below.
            if self.encoder is None and config.encoder:
                self.encoder = self.create_component(config.encoder)
                create_if(self.encoder)

            if self.generator is None and config.generator:
                self.generator = self.create_component(config.generator)
                create_if(self.generator)
                # Only set when the generator is built here. A generator
                # assigned before create() does not get this alias.
                self.uniform_sample = self.generator.sample

            if self.discriminator is None and config.discriminator:
                self.discriminator = self.create_component(config.discriminator)
                self.discriminator.ops.describe("discriminator")
                create_if(self.discriminator)

            if self.loss is None and config.loss:
                self.loss = self.create_component(config.loss)
                create_if(self.loss)

            if self.trainer is None and config.trainer:
                self.trainer = self.create_component(config.trainer)
                create_if(self.trainer)

            # Run after every component exists so all their variables are
            # covered by the initializer.
            self.session.run(tf.global_variables_initializer())

    def step(self, feed_dict=None):
        """Run one training step, calling ``create()`` first if needed.

        Args:
            feed_dict: optional mapping forwarded to ``self.trainer.step``.
                When omitted (or ``None``), a fresh empty dict is used.

        Returns:
            Whatever ``self.trainer.step`` returns.

        Raises:
            ValidationException: if no trainer exists after creation.
        """
        # NOTE(review): ``self.created`` is not set in this class; presumably
        # BaseGAN.create sets it. Confirm.
        if not self.created:
            self.create()
        if self.trainer is None:
            raise ValidationException("gan.trainer is missing. Cannot train.")
        # Use a new dict per call. The old ``feed_dict={}`` default was one
        # dict shared by every call, so any mutation by the trainer leaked
        # into later steps.
        if feed_dict is None:
            feed_dict = {}
        return self.trainer.step(feed_dict)