Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

44

45

46

47

48

49

50

51

52

53

54

55

56

57

58

59

60

61

62

63

64

65

66

67

import tensorflow as tf 

import numpy as np 

import hyperchamber as hc 

import inspect 

 

from hypergan.trainers.base_trainer import BaseTrainer 

 

TINY = 1e-12 

 

class AlternatingTrainer(BaseTrainer): 

def _create(self): 

gan = self.gan 

config = self.config 

 

d_vars = self.d_vars or gan.discriminator.variables() 

g_vars = self.g_vars or (gan.encoder.variables() + gan.generator.variables()) 

 

loss = self.loss or gan.loss 

d_loss, g_loss = loss.sample 

 

self.d_log = -tf.log(tf.abs(d_loss+TINY)) 

 

g_optimizer = self.build_optimizer(config, 'g_', config.g_trainer, self.g_lr, g_vars, g_loss) 

d_optimizer = self.build_optimizer(config, 'd_', config.d_trainer, self.d_lr, d_vars, d_loss) 

 

self.g_loss = g_loss 

self.d_loss = d_loss 

self.d_optimizer = d_optimizer 

self.g_optimizer = g_optimizer 

 

if config.d_clipped_weights: 

self.clip = [tf.assign(d,tf.clip_by_value(d, -config.d_clipped_weights, config.d_clipped_weights)) for d in d_vars] 

else: 

self.clip = [] 

 

if config.anneal_learning_rate: 

anneal_rate = config.anneal_rate or 0.5 

anneal_lower_bound = config.anneal_lower_bound or 2e-5 

self.anneal = [ 

tf.assign(self.d_lr, tf.maximum(self.d_lr * anneal_rate, anneal_lower_bound)), 

tf.assign(self.g_lr, tf.maximum(self.g_lr * anneal_rate, anneal_lower_bound)) 

] 

 

return g_optimizer, d_optimizer 

 

def _step(self, feed_dict): 

gan = self.gan 

sess = gan.session 

config = self.config 

loss = self.loss or gan.loss 

metrics = loss.metrics 

 

d_loss, g_loss = loss.sample 

 

for i in range(config.d_update_steps or 1): 

sess.run([self.d_optimizer] + self.clip, feed_dict) 

 

metric_values = sess.run([self.g_optimizer] + self.output_variables(metrics), feed_dict)[1:] 

 

if self.current_step % 100 == 0: 

print(self.output_string(metrics) % tuple([self.current_step] + metric_values)) 

 

anneal_every = config.anneal_every or 100000 

if config.anneal_learning_rate and self.current_step > 0 and self.current_step % anneal_every == 0: 

dlr, glr = sess.run(self.anneal) 

print("Lowering the learning rate to d:" + str(dlr) + ", g:" + str(glr))