diff options
Diffstat (limited to 'tensorflow/python/tools/freeze_graph.py')
-rw-r--r-- | tensorflow/python/tools/freeze_graph.py | 36 |
1 files changed, 19 insertions, 17 deletions
diff --git a/tensorflow/python/tools/freeze_graph.py b/tensorflow/python/tools/freeze_graph.py index a52f325ddb..e9f1def48c 100644 --- a/tensorflow/python/tools/freeze_graph.py +++ b/tensorflow/python/tools/freeze_graph.py @@ -56,8 +56,6 @@ from tensorflow.python.saved_model import tag_constants from tensorflow.python.tools import saved_model_utils from tensorflow.python.training import saver as saver_lib -FLAGS = None - def freeze_graph_with_def_protos(input_graph_def, input_saver_def, @@ -256,25 +254,24 @@ def freeze_graph(input_graph, checkpoint_version=checkpoint_version) -def main(unused_args): - if FLAGS.checkpoint_version == 1: +def main(unused_args, flags): + if flags.checkpoint_version == 1: checkpoint_version = saver_pb2.SaverDef.V1 - elif FLAGS.checkpoint_version == 2: + elif flags.checkpoint_version == 2: checkpoint_version = saver_pb2.SaverDef.V2 else: print("Invalid checkpoint version (must be '1' or '2'): %d" % - FLAGS.checkpoint_version) + flags.checkpoint_version) return -1 - freeze_graph(FLAGS.input_graph, FLAGS.input_saver, FLAGS.input_binary, - FLAGS.input_checkpoint, FLAGS.output_node_names, - FLAGS.restore_op_name, FLAGS.filename_tensor_name, - FLAGS.output_graph, FLAGS.clear_devices, FLAGS.initializer_nodes, - FLAGS.variable_names_whitelist, FLAGS.variable_names_blacklist, - FLAGS.input_meta_graph, FLAGS.input_saved_model_dir, - FLAGS.saved_model_tags, checkpoint_version) - + freeze_graph(flags.input_graph, flags.input_saver, flags.input_binary, + flags.input_checkpoint, flags.output_node_names, + flags.restore_op_name, flags.filename_tensor_name, + flags.output_graph, flags.clear_devices, flags.initializer_nodes, + flags.variable_names_whitelist, flags.variable_names_blacklist, + flags.input_meta_graph, flags.input_saved_model_dir, + flags.saved_model_tags, checkpoint_version) -if __name__ == "__main__": +def run_main(): parser = argparse.ArgumentParser() parser.register("type", "bool", lambda v: v.lower() == "true") parser.add_argument( @@ -376,5 +373,10 @@ if __name__ == "__main__": separated by \',\'. For tag-set contains multiple tags, all tags \ must be passed in.\ """) - FLAGS, unparsed = parser.parse_known_args() - app.run(main=main, argv=[sys.argv[0]] + unparsed) + flags, unparsed = parser.parse_known_args() + + my_main = lambda unused_args: main(unused_args, flags) + app.run(main=my_main, argv=[sys.argv[0]] + unparsed) + +if __name__ == '__main__': + run_main() |