Coverage for hypergan/cli.py : 52%
Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
|
self.save_file = self.args.save_file else:
'static_batch': StaticBatchSampler, 'random_walk': RandomWalkSampler, 'alphagan_random_walk': AlphaganRandomWalkSampler, 'batch': BatchSampler, 'grid': GridSampler, 'began': BeganSampler, 'autoencode': AutoencodeSampler, 'debug': DebugSampler, 'aligned': AlignedSampler } return samplers[name] else:
""" Samples to a file. Useful for visualizing the learning process.
Use with:
ffmpeg -i samples/grid-%06d.png -vcodec libx264 -crf 22 -threads 0 grid1-7.mp4
to create a video of the learning process. """
raise ValidationException("No sampler found by the name '"+self.sampler_name+"'")
self.gan.config['model'] = self.args.config hc.io.sample(self.gan.config, sample_list)
save_file_text = self.args.config+".pbtxt" build_file = os.path.expanduser("builds/"+save_file_text) self.create_path(build_file) tf.train.write_graph(self.gan.session.graph, 'builds', save_file_text) inputs = [x.name.split(":")[0] for x in self.gan.input_nodes()] outputs = [x.name.split(":")[0] for x in self.gan.output_nodes()] print("___") print(inputs, outputs) tf.reset_default_graph() self.gan.session.close() [print("Input: ", x) for x in self.gan.input_nodes()] [print("Output: ", y) for y in self.gan.output_nodes()]
pbtxt_path = "builds/"+self.args.config+'.pbtxt' checkpoint_path = "saves/"+self.args.config+'/model.ckpt' input_saver_def_path = "" input_binary = False output_node_names = ",".join(outputs) restore_op_name = "save/restore_all" filename_tensor_name = "save/Const:0" output_frozen_graph_name = 'builds/frozen_'+self.args.config+'.pb' output_optimized_graph_name = 'builds/optimized_'+self.args.config+'.pb' clear_devices = True
freeze_graph.freeze_graph(pbtxt_path, input_saver_def_path, input_binary, checkpoint_path, output_node_names, restore_op_name, filename_tensor_name, output_frozen_graph_name, clear_devices, "")
input_graph_def = tf.GraphDef() with tf.gfile.Open(output_frozen_graph_name, "rb") as f: data = f.read() input_graph_def.ParseFromString(data)
print("GRAPH INPUTS", inputs, "OUTPUTS", outputs) output_graph_def = optimize_for_inference_lib.optimize_for_inference( input_graph_def, inputs, # an array of the input node(s) outputs, # an array of output nodes tf.float32.as_datatype_enum)
# Save the optimized graph
f = tf.gfile.FastGFile(output_optimized_graph_name, "wb") f.write(output_graph_def.SerializeToString()) f.flush() f.close()
print("Saved generator to ", output_optimized_graph_name)
print("Testing loading ", output_optimized_graph_name) with tf.gfile.FastGFile(output_optimized_graph_name, 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) tf.import_graph_def(graph_def, name='') with tf.Session() as sess: for input in inputs: print("Input: ", input, sess.graph.get_tensor_by_name(input+":0")) for output in outputs: print("Output: ", output, sess.graph.get_tensor_by_name(output+":0"))
return gan_server(self.gan.session, config)
while True: sample_file="samples/%06d.png" % (self.samples) self.create_path(sample_file) self.sample(sample_file) self.samples += 1 print("Sample", self.samples) sleep(0.2)
fd = sys.stdin.fileno() fl = fcntl.fcntl(fd, fcntl.F_GETFL) fcntl.fcntl(fd, fcntl.F_SETFL, fl | os.O_NONBLOCK)
self.args.save_every != -1 and self.args.save_every > 0 and i % self.args.save_every == 0): print(" |= Saving network") self.gan.save(self.save_file) self.check_stdin()
try: input = sys.stdin.read() if input[0]=="y": return from IPython import embed # Misc code embed()
except: return
#EWW else: print("[discriminator] Class loss is off. Unsupervised learning mode activated.")
self.gan.create() self.add_supervised_loss() self.gan.session.run(tf.global_variables_initializer())
if not self.gan.load(self.save_file): print("Initializing new model") else: print("Model loaded") tf.train.start_queue_runners(sess=self.gan.session) self.train() tf.reset_default_graph() self.gan.session.close() self.gan.create() if not self.gan.load(self.save_file): raise "Could not load model: "+ save_file else: print("Model loaded") self.build() self.new() self.gan.create() self.add_supervised_loss() if not self.gan.load(self.save_file): print("Initializing new model") else: print("Model loaded")
tf.train.start_queue_runners(sess=self.gan.session) self.sample_forever() tf.reset_default_graph() self.gan.session.close() |