aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/client/graph_util.py
diff options
context:
space:
mode:
authorGravatar Jianmin Chen <goog.jmchen@gmail.com>2016-03-10 14:39:48 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-03-10 14:42:18 -0800
commit025c0d21a6689f081082b1a51f8812e56b07af77 (patch)
tree03950b3bf2decc3b49c6e0735f12015f2bd42b7c /tensorflow/python/client/graph_util.py
parente71090a1190d039b124ecfddd90f901e328647ea (diff)
Improve the loading of the variables for freeze_graph
Change: 116910206
Diffstat (limited to 'tensorflow/python/client/graph_util.py')
-rw-r--r--tensorflow/python/client/graph_util.py8
1 files changed, 7 insertions, 1 deletions
diff --git a/tensorflow/python/client/graph_util.py b/tensorflow/python/client/graph_util.py
index 2aaa0deb05..6dedcf56eb 100644
--- a/tensorflow/python/client/graph_util.py
+++ b/tensorflow/python/client/graph_util.py
@@ -250,10 +250,16 @@ def convert_variables_to_constants(sess, input_graph_def, output_node_names):
GraphDef containing a simplified version of the original.
"""
found_variables = {}
+ variable_names = []
+ variable_dict_names = []
for node in input_graph_def.node:
if node.op == "Assign":
variable_name = node.input[0]
- found_variables[variable_name] = sess.run(variable_name + ":0")
+ variable_dict_names.append(variable_name)
+ variable_names.append(variable_name + ":0")
+ returned_variables = sess.run(variable_names)
+ found_variables = dict(zip(variable_dict_names, returned_variables))
+ logging.info("Frozen %d variables." % len(returned_variables))
# This graph only includes the nodes needed to evaluate the output nodes, and
# removes unneeded nodes like those involved in saving and assignment.