Coverage for hypergan/trainers/multi_step_trainer.py : 21%
Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
|
# NOTE(review): the class line and this method's `def` header are not present in
# this view (lost in the scrape), so the code below is kept verbatim.
# It reads like a trainer __init__ body:
#   - It calls BaseTrainer.__init__(self, gan, config).
#   - It then stores `losses`, `var_lists` and `metrics` on self.
#   - If `metrics` is falsy, one None placeholder is stored per entry of
#     self.losses, so self.metrics[i] can always be indexed alongside self.losses[i].
BaseTrainer.__init__(self, gan, config) self.losses = losses self.var_lists = var_lists self.metrics = metrics or [None for i in self.losses]
# NOTE(review): this method's `def` header is not visible (lost in the scrape).
# The fragment builds one optimizer per entry of self.losses and, optionally,
# a list of weight-clipping assign ops.
# Each self.losses entry is indexed as a pair:
#   - losses[i][0] is a role string compared against 'generator'.
#   - losses[i][1] is the loss handed to build_optimizer.
# self.var_lists[i] is the variable list paired with loss i.
gan = self.gan config = self.config losses = self.losses g_lr = config.g_learn_rate d_lr = config.d_learn_rate
# Entries whose role is 'generator' use the 'g_' prefix, config.g_trainer and
# self.g_lr. Every other entry is treated as a discriminator and uses 'd_',
# config.d_trainer and self.d_lr.
# NOTE(review): the locals g_lr / d_lr read just above are never used; the
# build_optimizer calls pass self.g_lr / self.d_lr instead. Confirm which
# learning rates are intended. The local `gan` is also unused here.
optimizers = [] for i, _ in enumerate(losses): loss = losses[i][1] var_list = self.var_lists[i] is_generator = losses[i][0] == 'generator'
if is_generator: optimizer = self.build_optimizer(config, 'g_', config.g_trainer, self.g_lr, var_list, loss) else: optimizer = self.build_optimizer(config, 'd_', config.d_trainer, self.d_lr, var_list, loss) optimizers.append(optimizer)
# self.optimizers[i] corresponds to self.losses[i]; the step code indexes them in parallel.
self.optimizers = optimizers
# When config.d_clipped_weights is set, build tf.assign ops that clamp each
# variable d to [-d_clipped_weights, d_clipped_weights]. Otherwise self.clip is empty.
# NOTE(review): `d_vars` is not defined anywhere in this fragment. This
# comprehension looks like it would raise NameError whenever
# config.d_clipped_weights is truthy.
# TODO: define it (presumably from the discriminator entries of self.var_lists)
# before relying on weight clipping.
if config.d_clipped_weights: self.clip = [tf.assign(d,tf.clip_by_value(d, -config.d_clipped_weights, config.d_clipped_weights)) for d in d_vars] else: self.clip = []
return None
# NOTE(review): this method's `def` header is not visible (lost in the scrape).
# It takes a `feed_dict`, and presumably performs one training step.
# For each loss i it runs self.optimizers[i] in gan.session. When metrics[i]
# is set, the metric's output variables are fetched in the same sess.run call;
# [1:] drops the optimizer's own result so only the metric values remain.
gan = self.gan sess = gan.session config = self.config losses = self.losses metrics = self.metrics
for i, _ in enumerate(losses): loss = losses[i] optimizer = self.optimizers[i] metric = metrics[i] if(metric): metric_values = sess.run([optimizer] + self.output_variables(metric), feed_dict)[1:]
# When self.current_step is a multiple of 100, print the metric values.
# self.output_string(metric) is used as a %-format string, with the current step
# as its first argument. loss[0] (the role string) is included in the label.
# NOTE(review): indentation was lost. The trailing `else` presumably pairs with
# `if(metric)`: losses without a metric just run their optimizer. If it pairs
# with the `% 100` check instead, the optimizer would run twice on
# non-logging steps. Confirm against the real file.
# NOTE(review): the ops in self.clip are not run in this fragment. Confirm
# weight clipping is applied elsewhere.
if self.current_step % 100 == 0: print("loss " + str(i) + " "+ loss[0] + " " + self.output_string(metric) % tuple([self.current_step] + metric_values)) else: _ = sess.run(optimizer, feed_dict) |