Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

44

45

46

47

48

49

50

51

52

53

54

55

56

57

58

59

60

61

62

63

64

65

66

67

68

69

70

71

72

73

74

75

76

77

78

79

80

81

82

83

84

85

86

87

88

89

90

91

92

93

94

95

96

97

98

99

100

101

102

103

104

105

106

107

108

109

110

111

112

113

114

115

116

117

118

119

120

121

122

123

124

125

126

127

128

129

130

131

132

133

134

135

136

137

138

139

140

141

142

143

144

145

146

147

148

149

150

151

152

153

154

155

156

157

158

159

160

161

162

163

164

165

166

167

168

169

170

171

172

173

174

175

176

177

178

179

import hyperchamber as hc 

import inspect 

import itertools 

import types 

 

 

class ValidationException(Exception):
    """
    GAN components validate their configurations before creation.

    `ValidationException` occurs if they fail.
    """

 

class GANComponent:
    """
    GANComponents are pluggable pieces within a GAN.

    GAN objects are also GANComponents.
    """
    def __init__(self, gan, config):
        """
        Initializes a gan component based on a `gan` and a `config` dictionary.

        Different components require different config variables.

        A `ValidationException` is raised if the GAN component configuration fails to validate.
        """
        self.gan = gan
        self.config = hc.Config(config)
        errors = self.validate()
        # Truthiness rather than `!= []`: an empty tuple of errors must not raise.
        if errors:
            raise ValidationException(self.__class__.__name__+": " +"\n".join(errors))
        self.create_ops(config)

    def create_ops(self, config):
        """
        Create the ops object as `self.ops`. Also looks up config.

        Returns early, leaving `self.ops` unset, when there is no GAN or the
        GAN has no `ops_backend`.

        :param config: the raw config given to `__init__`; it is passed to
            `self.gan.ops.lookup` and the result replaces `self.config`.
        """
        if self.gan is None:
            return
        if self.gan.ops_backend is None:
            return
        self.ops = self.gan.ops_backend(config=self.config, device=self.gan.device)
        # NOTE(review): the lookup goes through the GAN's ops (not the freshly
        # built `self.ops`) and uses the raw `config` argument (not
        # `self.config`); kept as-is - confirm this is intentional.
        self.config = self.gan.ops.lookup(config)

    def required(self):
        """
        Return a list of required config strings and a `ValidationException` will be thrown if any are missing.

        Example:
        ```python
        class MyComponent(GANComponent):
            def required(self):
                "learn rate is required"
                return ["learn_rate"]
        ```
        """
        return []

    def validate(self):
        """
        Validates a GANComponent. Return an array of error messages. Empty array `[]` means success.

        A required key whose config lookup yields None is reported as missing,
        and a component without a GAN is always an error.
        """
        errors = []
        for argument in self.required():
            if self.config.__getattr__(argument) is None:
                errors.append("`"+argument+"` required")

        if self.gan is None:
            errors.append("GANComponent constructed without GAN")
        return errors

    def weights(self):
        """
        The weights of the GAN component (read from `self.ops.weights`).
        """
        return self.ops.weights

    def biases(self):
        """
        Biases of the GAN component (read from `self.ops.biases`).
        """
        return self.ops.biases

    def variables(self):
        """
        All variables associated with this component (from `self.ops.variables()`).
        """
        return self.ops.variables()

    def split_batch(self, net, count=2):
        """
        Discriminators return stacked results (on axis 0).

        This splits the results along axis 0 into `count` equal parts, each
        reshaped to `net`'s shape with the batch dimension divided by `count`.
        With the default `count=2` this returns [d_real, d_fake].

        If the batch size is not divisible by `count`, the trailing
        `batch % count` rows are not included in any part.
        """
        # Fall back to the GAN's ops when this component never created its own
        # (create_ops returns early without setting `self.ops`).
        ops = getattr(self, "ops", None) or self.gan.ops
        shape = ops.shape(net)
        bs = shape[0]
        part = bs // count
        # Build a fresh list instead of mutating whatever ops.shape returned
        # (this also works when the shape comes back as a tuple).
        part_shape = [part] + list(shape[1:])

        flat = ops.reshape(net, [bs, -1])
        size = [part] + list(ops.shape(flat)[1:])
        nets = []
        for i in range(count):
            start = [i * part] + [0] * (len(size) - 1)
            nets.append(ops.slice(flat, start, size))
        return [ops.reshape(piece, part_shape) for piece in nets]

    def reuse(self, net):
        """
        Call `self.build(net)` between `self.ops.reuse()` and
        `self.ops.stop_reuse()` and return its result.

        `stop_reuse` runs even if `build` raises, so the ops object is not left
        in reuse mode. (`build` is not defined here; presumably provided by
        subclasses.)
        """
        self.ops.reuse()
        try:
            return self.build(net)
        finally:
            self.ops.stop_reuse()

    def layer_regularizer(self, net):
        """
        Apply the op named by `config.layer_regularizer` to `net`, if any.

        The symbol is resolved with `self.gan.ops.lookup`; only when the result
        is a plain Python function is it applied, as `op(self, net)`.
        Otherwise `net` is returned unchanged.
        """
        symbol = self.config.layer_regularizer
        op = self.gan.ops.lookup(symbol)
        if op and isinstance(op, types.FunctionType):
            net = op(self, net)
        return net

    def split_by_width_height(self, net):
        """
        Split `net` into one slice per (axis 1, axis 2) position.

        Axis 1 is treated as height and axis 2 as width. Each slice has size
        [batch, 1, 1, -1] (all of the last axis). Slices are ordered row by
        row: every width position of row 0, then row 1, and so on.
        """
        ops = self.gan.ops
        shape = ops.shape(net)
        bs = shape[0]
        height = shape[1]
        width = shape[2]
        # `i` indexes axis 1, so it must range over height; `j` indexes axis 2
        # and ranges over width (non-square inputs depend on this).
        return [
            ops.slice(net, [0, i, j, 0], [bs, 1, 1, -1])
            for i in range(height)
            for j in range(width)
        ]

    def permute(self, nets, k):
        """
        Return every ordered `k`-length permutation of `nets` as a list of tuples.
        """
        return list(itertools.permutations(nets, k))

    # NOTE(review): the original author marked this method "this is broken"
    # without recording why. One unconfirmed guess: `ops.linear` is called once
    # per pair, so weights are presumably not shared across pairs.
    def fully_connected_from_list(self, nets):
        """
        Process each (net, net2) pair from `nets`.

        Each pair is concatenated on axis 3, flattened per example, passed
        through `ops.linear` (same width) and lrelu, then reshaped back to the
        concatenated shape.

        :param nets: iterable of 2-tuples, e.g. the output of `permute(..., 2)`.
        :return: list with one result per pair.
        """
        results = []
        ops = self.ops
        # Loop-invariant: resolve the activation once.
        lrelu = ops.lookup('lrelu')
        for net, net2 in nets:
            net = ops.concat([net, net2], axis=3)
            shape = ops.shape(net)
            bs = shape[0]
            net = ops.reshape(net, [bs, -1])
            features = ops.shape(net)[1]
            net = ops.linear(net, features)
            #net = self.layer_regularizer(net)
            net = lrelu(net)
            #net = ops.linear(net, features)
            net = ops.reshape(net, shape)
            results.append(net)
        return results

    def relation_layer(self, net):
        """
        Pairwise layer over all ordered pairs of spatial positions of `net`.

        Splits `net` per position (`split_by_width_height`), forms every
        ordered pair (`permute`), processes each pair
        (`fully_connected_from_list`), concatenates the results on axis 3, and
        projects back to `net`'s original shape with `ops.linear`.

        Assumes `net` has at least 4 dimensions (reads shape[1..3]).
        """
        ops = self.ops

        #hack: the final projection's width equals one example of `net`
        shape = ops.shape(net)
        input_size = shape[1]*shape[2]*shape[3]

        netlist = self.split_by_width_height(net)
        pairs = self.permute(netlist, 2)
        processed = self.fully_connected_from_list(pairs)
        net = ops.concat(processed, axis=3)

        #hack
        bs = ops.shape(net)[0]
        net = ops.reshape(net, [bs, -1])
        net = ops.linear(net, input_size)
        net = ops.reshape(net, shape)

        return net