diff options
author | 2016-03-10 14:39:48 -0800 | |
---|---|---|
committer | 2016-03-10 14:42:18 -0800 | |
commit | 025c0d21a6689f081082b1a51f8812e56b07af77 (patch) | |
tree | 03950b3bf2decc3b49c6e0735f12015f2bd42b7c | |
parent | e71090a1190d039b124ecfddd90f901e328647ea (diff) |
Improve the loading of the variables for freeze_graph
Change: 116910206
-rw-r--r-- | tensorflow/python/client/graph_util.py | 8 |
1 files changed, 7 insertions, 1 deletions
diff --git a/tensorflow/python/client/graph_util.py b/tensorflow/python/client/graph_util.py index 2aaa0deb05..6dedcf56eb 100644 --- a/tensorflow/python/client/graph_util.py +++ b/tensorflow/python/client/graph_util.py @@ -250,10 +250,16 @@ def convert_variables_to_constants(sess, input_graph_def, output_node_names): GraphDef containing a simplified version of the original. """ found_variables = {} + variable_names = [] + variable_dict_names = [] for node in input_graph_def.node: if node.op == "Assign": variable_name = node.input[0] - found_variables[variable_name] = sess.run(variable_name + ":0") + variable_dict_names.append(variable_name) + variable_names.append(variable_name + ":0") + returned_variables = sess.run(variable_names) + found_variables = dict(zip(variable_dict_names, returned_variables)) + logging.info("Frozen %d variables." % len(returned_variables)) # This graph only includes the nodes needed to evaluate the output nodes, and # removes unneeded nodes like those involved in saving and assignment. |