aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/tools/freeze_graph.py
diff options
context:
space:
mode:
authorGravatar Michael Case <mikecase@google.com>2018-02-07 14:36:00 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-07 14:39:49 -0800
commitd90054e7c0f41f4bab81df0548577a73b939a87a (patch)
treea15aea686a9d3f305e316d2a6ada0859ad8170d1 /tensorflow/python/tools/freeze_graph.py
parent8461760f9f6cde8ed97507484d2a879140141032 (diff)
Merge changes from github.
PiperOrigin-RevId: 184897758
Diffstat (limited to 'tensorflow/python/tools/freeze_graph.py')
-rw-r--r--tensorflow/python/tools/freeze_graph.py38
1 files changed, 29 insertions, 9 deletions
diff --git a/tensorflow/python/tools/freeze_graph.py b/tensorflow/python/tools/freeze_graph.py
index 0ddf09260b..affa97062a 100644
--- a/tensorflow/python/tools/freeze_graph.py
+++ b/tensorflow/python/tools/freeze_graph.py
@@ -72,7 +72,8 @@ def freeze_graph_with_def_protos(input_graph_def,
variable_names_blacklist="",
input_meta_graph_def=None,
input_saved_model_dir=None,
- saved_model_tags=None):
+ saved_model_tags=None,
+ checkpoint_version=saver_pb2.SaverDef.V2):
"""Converts all variables in a graph and checkpoint into constants."""
del restore_op_name, filename_tensor_name # Unused by updated loading code.
@@ -100,7 +101,8 @@ def freeze_graph_with_def_protos(input_graph_def,
_ = importer.import_graph_def(input_graph_def, name="")
with session.Session() as sess:
if input_saver_def:
- saver = saver_lib.Saver(saver_def=input_saver_def)
+ saver = saver_lib.Saver(
+ saver_def=input_saver_def, write_version=checkpoint_version)
saver.restore(sess, input_checkpoint)
elif input_meta_graph_def:
restorer = saver_lib.import_meta_graph(
@@ -124,7 +126,8 @@ def freeze_graph_with_def_protos(input_graph_def,
# 'global_step' or a similar housekeeping element) so skip it.
continue
var_list[key] = tensor
- saver = saver_lib.Saver(var_list=var_list)
+ saver = saver_lib.Saver(
+ var_list=var_list, write_version=checkpoint_version)
saver.restore(sess, input_checkpoint)
if initializer_nodes:
sess.run(initializer_nodes.split(","))
@@ -217,7 +220,8 @@ def freeze_graph(input_graph,
variable_names_blacklist="",
input_meta_graph=None,
input_saved_model_dir=None,
- saved_model_tags=tag_constants.SERVING):
+ saved_model_tags=tag_constants.SERVING,
+ checkpoint_version=saver_pb2.SaverDef.V2):
"""Converts all variables in a graph and checkpoint into constants."""
input_graph_def = None
if input_saved_model_dir:
@@ -233,10 +237,21 @@ def freeze_graph(input_graph,
if input_saver:
input_saver_def = _parse_input_saver_proto(input_saver, input_binary)
freeze_graph_with_def_protos(
- input_graph_def, input_saver_def, input_checkpoint, output_node_names,
- restore_op_name, filename_tensor_name, output_graph, clear_devices,
- initializer_nodes, variable_names_whitelist, variable_names_blacklist,
- input_meta_graph_def, input_saved_model_dir, saved_model_tags.split(","))
+ input_graph_def,
+ input_saver_def,
+ input_checkpoint,
+ output_node_names,
+ restore_op_name,
+ filename_tensor_name,
+ output_graph,
+ clear_devices,
+ initializer_nodes,
+ variable_names_whitelist,
+ variable_names_blacklist,
+ input_meta_graph_def,
+ input_saved_model_dir,
+ saved_model_tags.split(","),
+ checkpoint_version=checkpoint_version)
def main(unused_args):
@@ -246,7 +261,7 @@ def main(unused_args):
FLAGS.output_graph, FLAGS.clear_devices, FLAGS.initializer_nodes,
FLAGS.variable_names_whitelist, FLAGS.variable_names_blacklist,
FLAGS.input_meta_graph, FLAGS.input_saved_model_dir,
- FLAGS.saved_model_tags)
+ FLAGS.saved_model_tags, FLAGS.checkpoint_version)
if __name__ == "__main__":
@@ -268,6 +283,11 @@ if __name__ == "__main__":
default="",
help="TensorFlow variables file to load.")
parser.add_argument(
+ "--checkpoint_version",
+ type=int,
+ default=saver_pb2.SaverDef.V2,
+ help="Tensorflow variable file format")
+ parser.add_argument(
"--output_graph",
type=str,
default="",