aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/graph_editor
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-12-29 13:35:17 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-29 13:44:43 -0800
commitac11a9d52d6e4188f30333e1f6f1e40dfc939436 (patch)
treef821d4adc6108978cf1eb28b90203b8fcf03b1e9 /tensorflow/contrib/graph_editor
parentd8cdf23c5a75efd043ec46ce8dbb66669acaa128 (diff)
Create a copy of the collections dict's items before doing iteration that could add to that dict.
Change: 143204319
Diffstat (limited to 'tensorflow/contrib/graph_editor')
-rw-r--r--tensorflow/contrib/graph_editor/transform.py6
1 files changed, 4 insertions, 2 deletions
diff --git a/tensorflow/contrib/graph_editor/transform.py b/tensorflow/contrib/graph_editor/transform.py
index 11c19ccc22..6fb347c834 100644
--- a/tensorflow/contrib/graph_editor/transform.py
+++ b/tensorflow/contrib/graph_editor/transform.py
@@ -93,8 +93,7 @@ def assign_renamed_collections_handler(info, elem, elem_):
elem_: the transformed element
"""
# TODO(fkp): handle known special cases
- for name, collection in iteritems(
- elem.graph._collections): # pylint: disable=protected-access
+ for name, collection in iteritems(info.collections):
if elem not in collection:
continue
collection_name_ = info.transformer.new_name(name)
@@ -238,6 +237,9 @@ class Transformer(object):
self.scope_ = dst_scope
self.transformed_ops = {}
self.transformed_ts = {}
+ self.collections = dict((key, self.graph.get_collection(key))
+ for key in self.graph.get_all_collection_keys())
+
class ResultInfo(object):
""""Contains information about the result of a transform operation."""