aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Mark Daoust <markdaoust@google.com>2018-03-13 09:13:07 -0700
committerGravatar GitHub <noreply@github.com>2018-03-13 09:13:07 -0700
commit95fcee6b4ec9a2e782b70302d2c75e07c326297d (patch)
tree0501041e041571d6b1ec703cc37435deed2c0438
parent1d76d3e7c7eecddee960c20c9896ccc43d7ccd5c (diff)
parent1db29b831dc66a98442ce7a00204e0128239c1dd (diff)
Merge pull request #17655 from MarkDaoust/freeze_graph
Fix the script entry point for freeze_graph.
-rw-r--r--tensorflow/python/tools/freeze_graph.py36
-rw-r--r--tensorflow/tools/pip_package/setup.py2
2 files changed, 20 insertions, 18 deletions
diff --git a/tensorflow/python/tools/freeze_graph.py b/tensorflow/python/tools/freeze_graph.py
index a52f325ddb..e9f1def48c 100644
--- a/tensorflow/python/tools/freeze_graph.py
+++ b/tensorflow/python/tools/freeze_graph.py
@@ -56,8 +56,6 @@ from tensorflow.python.saved_model import tag_constants
from tensorflow.python.tools import saved_model_utils
from tensorflow.python.training import saver as saver_lib
-FLAGS = None
-
def freeze_graph_with_def_protos(input_graph_def,
input_saver_def,
@@ -256,25 +254,24 @@ def freeze_graph(input_graph,
checkpoint_version=checkpoint_version)
-def main(unused_args):
- if FLAGS.checkpoint_version == 1:
+def main(unused_args, flags):
+ if flags.checkpoint_version == 1:
checkpoint_version = saver_pb2.SaverDef.V1
- elif FLAGS.checkpoint_version == 2:
+ elif flags.checkpoint_version == 2:
checkpoint_version = saver_pb2.SaverDef.V2
else:
print("Invalid checkpoint version (must be '1' or '2'): %d" %
- FLAGS.checkpoint_version)
+ flags.checkpoint_version)
return -1
- freeze_graph(FLAGS.input_graph, FLAGS.input_saver, FLAGS.input_binary,
- FLAGS.input_checkpoint, FLAGS.output_node_names,
- FLAGS.restore_op_name, FLAGS.filename_tensor_name,
- FLAGS.output_graph, FLAGS.clear_devices, FLAGS.initializer_nodes,
- FLAGS.variable_names_whitelist, FLAGS.variable_names_blacklist,
- FLAGS.input_meta_graph, FLAGS.input_saved_model_dir,
- FLAGS.saved_model_tags, checkpoint_version)
-
+ freeze_graph(flags.input_graph, flags.input_saver, flags.input_binary,
+ flags.input_checkpoint, flags.output_node_names,
+ flags.restore_op_name, flags.filename_tensor_name,
+ flags.output_graph, flags.clear_devices, flags.initializer_nodes,
+ flags.variable_names_whitelist, flags.variable_names_blacklist,
+ flags.input_meta_graph, flags.input_saved_model_dir,
+ flags.saved_model_tags, checkpoint_version)
-if __name__ == "__main__":
+def run_main():
parser = argparse.ArgumentParser()
parser.register("type", "bool", lambda v: v.lower() == "true")
parser.add_argument(
@@ -376,5 +373,10 @@ if __name__ == "__main__":
separated by \',\'. For tag-set contains multiple tags, all tags \
must be passed in.\
""")
- FLAGS, unparsed = parser.parse_known_args()
- app.run(main=main, argv=[sys.argv[0]] + unparsed)
+ flags, unparsed = parser.parse_known_args()
+
+ my_main = lambda unused_args: main(unused_args, flags)
+ app.run(main=my_main, argv=[sys.argv[0]] + unparsed)
+
+if __name__ == '__main__':
+ run_main()
diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py
index 815ea8157d..7fdf0d8c17 100644
--- a/tensorflow/tools/pip_package/setup.py
+++ b/tensorflow/tools/pip_package/setup.py
@@ -72,7 +72,7 @@ if sys.version_info < (3, 4):
# pylint: disable=line-too-long
CONSOLE_SCRIPTS = [
- 'freeze_graph = tensorflow.python.tools.freeze_graph:main',
+ 'freeze_graph = tensorflow.python.tools.freeze_graph:run_main',
'toco_from_protos = tensorflow.contrib.lite.toco.python.toco_from_protos:main',
'toco = tensorflow.contrib.lite.toco.python.toco_wrapper:main',
'saved_model_cli = tensorflow.python.tools.saved_model_cli:main',