diff options
author | Mark Daoust <markdaoust@google.com> | 2018-03-13 09:13:07 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-03-13 09:13:07 -0700 |
commit | 95fcee6b4ec9a2e782b70302d2c75e07c326297d (patch) | |
tree | 0501041e041571d6b1ec703cc37435deed2c0438 | |
parent | 1d76d3e7c7eecddee960c20c9896ccc43d7ccd5c (diff) | |
parent | 1db29b831dc66a98442ce7a00204e0128239c1dd (diff) |
Merge pull request #17655 from MarkDaoust/freeze_graph
Fix the script entry point for freeze_graph.
-rw-r--r-- | tensorflow/python/tools/freeze_graph.py | 36 | ||||
-rw-r--r-- | tensorflow/tools/pip_package/setup.py | 2 |
2 files changed, 20 insertions, 18 deletions
diff --git a/tensorflow/python/tools/freeze_graph.py b/tensorflow/python/tools/freeze_graph.py index a52f325ddb..e9f1def48c 100644 --- a/tensorflow/python/tools/freeze_graph.py +++ b/tensorflow/python/tools/freeze_graph.py @@ -56,8 +56,6 @@ from tensorflow.python.saved_model import tag_constants from tensorflow.python.tools import saved_model_utils from tensorflow.python.training import saver as saver_lib -FLAGS = None - def freeze_graph_with_def_protos(input_graph_def, input_saver_def, @@ -256,25 +254,24 @@ def freeze_graph(input_graph, checkpoint_version=checkpoint_version) -def main(unused_args): - if FLAGS.checkpoint_version == 1: +def main(unused_args, flags): + if flags.checkpoint_version == 1: checkpoint_version = saver_pb2.SaverDef.V1 - elif FLAGS.checkpoint_version == 2: + elif flags.checkpoint_version == 2: checkpoint_version = saver_pb2.SaverDef.V2 else: print("Invalid checkpoint version (must be '1' or '2'): %d" % - FLAGS.checkpoint_version) + flags.checkpoint_version) return -1 - freeze_graph(FLAGS.input_graph, FLAGS.input_saver, FLAGS.input_binary, - FLAGS.input_checkpoint, FLAGS.output_node_names, - FLAGS.restore_op_name, FLAGS.filename_tensor_name, - FLAGS.output_graph, FLAGS.clear_devices, FLAGS.initializer_nodes, - FLAGS.variable_names_whitelist, FLAGS.variable_names_blacklist, - FLAGS.input_meta_graph, FLAGS.input_saved_model_dir, - FLAGS.saved_model_tags, checkpoint_version) - + freeze_graph(flags.input_graph, flags.input_saver, flags.input_binary, + flags.input_checkpoint, flags.output_node_names, + flags.restore_op_name, flags.filename_tensor_name, + flags.output_graph, flags.clear_devices, flags.initializer_nodes, + flags.variable_names_whitelist, flags.variable_names_blacklist, + flags.input_meta_graph, flags.input_saved_model_dir, + flags.saved_model_tags, checkpoint_version) -if __name__ == "__main__": +def run_main(): parser = argparse.ArgumentParser() parser.register("type", "bool", lambda v: v.lower() == "true") parser.add_argument( @@ -376,5 +373,10 @@ if __name__ == "__main__": separated by \',\'. For tag-set contains multiple tags, all tags \ must be passed in.\ """) - FLAGS, unparsed = parser.parse_known_args() - app.run(main=main, argv=[sys.argv[0]] + unparsed) + flags, unparsed = parser.parse_known_args() + + my_main = lambda unused_args: main(unused_args, flags) + app.run(main=my_main, argv=[sys.argv[0]] + unparsed) + +if __name__ == '__main__': + run_main() diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py index 815ea8157d..7fdf0d8c17 100644 --- a/tensorflow/tools/pip_package/setup.py +++ b/tensorflow/tools/pip_package/setup.py @@ -72,7 +72,7 @@ if sys.version_info < (3, 4): # pylint: disable=line-too-long CONSOLE_SCRIPTS = [ - 'freeze_graph = tensorflow.python.tools.freeze_graph:main', + 'freeze_graph = tensorflow.python.tools.freeze_graph:run_main', 'toco_from_protos = tensorflow.contrib.lite.toco.python.toco_from_protos:main', 'toco = tensorflow.contrib.lite.toco.python.toco_wrapper:main', 'saved_model_cli = tensorflow.python.tools.saved_model_cli:main', |