aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/tools/freeze_graph.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/tools/freeze_graph.py')
-rw-r--r--tensorflow/python/tools/freeze_graph.py36
1 files changed, 19 insertions, 17 deletions
diff --git a/tensorflow/python/tools/freeze_graph.py b/tensorflow/python/tools/freeze_graph.py
index a52f325ddb..e9f1def48c 100644
--- a/tensorflow/python/tools/freeze_graph.py
+++ b/tensorflow/python/tools/freeze_graph.py
@@ -56,8 +56,6 @@ from tensorflow.python.saved_model import tag_constants
from tensorflow.python.tools import saved_model_utils
from tensorflow.python.training import saver as saver_lib
-FLAGS = None
-
def freeze_graph_with_def_protos(input_graph_def,
input_saver_def,
@@ -256,25 +254,24 @@ def freeze_graph(input_graph,
checkpoint_version=checkpoint_version)
-def main(unused_args):
- if FLAGS.checkpoint_version == 1:
+def main(unused_args, flags):
+ if flags.checkpoint_version == 1:
checkpoint_version = saver_pb2.SaverDef.V1
- elif FLAGS.checkpoint_version == 2:
+ elif flags.checkpoint_version == 2:
checkpoint_version = saver_pb2.SaverDef.V2
else:
print("Invalid checkpoint version (must be '1' or '2'): %d" %
- FLAGS.checkpoint_version)
+ flags.checkpoint_version)
return -1
- freeze_graph(FLAGS.input_graph, FLAGS.input_saver, FLAGS.input_binary,
- FLAGS.input_checkpoint, FLAGS.output_node_names,
- FLAGS.restore_op_name, FLAGS.filename_tensor_name,
- FLAGS.output_graph, FLAGS.clear_devices, FLAGS.initializer_nodes,
- FLAGS.variable_names_whitelist, FLAGS.variable_names_blacklist,
- FLAGS.input_meta_graph, FLAGS.input_saved_model_dir,
- FLAGS.saved_model_tags, checkpoint_version)
-
+ freeze_graph(flags.input_graph, flags.input_saver, flags.input_binary,
+ flags.input_checkpoint, flags.output_node_names,
+ flags.restore_op_name, flags.filename_tensor_name,
+ flags.output_graph, flags.clear_devices, flags.initializer_nodes,
+ flags.variable_names_whitelist, flags.variable_names_blacklist,
+ flags.input_meta_graph, flags.input_saved_model_dir,
+ flags.saved_model_tags, checkpoint_version)
-if __name__ == "__main__":
+def run_main():
parser = argparse.ArgumentParser()
parser.register("type", "bool", lambda v: v.lower() == "true")
parser.add_argument(
@@ -376,5 +373,10 @@ if __name__ == "__main__":
separated by \',\'. For tag-set contains multiple tags, all tags \
must be passed in.\
""")
- FLAGS, unparsed = parser.parse_known_args()
- app.run(main=main, argv=[sys.argv[0]] + unparsed)
+ flags, unparsed = parser.parse_known_args()
+
+ my_main = lambda unused_args: main(unused_args, flags)
+ app.run(main=my_main, argv=[sys.argv[0]] + unparsed)
+
+if __name__ == '__main__':
+ run_main()