aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute
diff options
context:
space:
mode:
authorGravatar Yuefeng Zhou <yuefengz@google.com>2018-09-20 20:25:19 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-20 20:29:28 -0700
commitf10b00558de87020554c9c0512537dab96dba918 (patch)
tree6d10c5e83eb73d5b452b2ce9be1acaf82b5a7e94 /tensorflow/contrib/distribute
parentf283f3ac5d7b6de8cadc9c1cee6886b187319afd (diff)
Make threading.local not an instance member of collective ops because in python3 threading.local cannot be pickled.
PiperOrigin-RevId: 213928766
Diffstat (limited to 'tensorflow/contrib/distribute')
-rw-r--r--tensorflow/contrib/distribute/python/cross_tower_utils.py15
1 files changed, 8 insertions, 7 deletions
diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils.py b/tensorflow/contrib/distribute/python/cross_tower_utils.py
index 24cb08fb48..9fc1b88955 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_utils.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_utils.py
@@ -221,9 +221,12 @@ def split_grads_by_size(threshold_size, device_grads):
return small_grads, large_grads
-# threading.Lock() cannot be pickled and therefore cannot be a field of
-# CollectiveKeys.
+# threading.Lock() and threading.local() cannot be pickled and therefore cannot
+# be a field of CollectiveKeys. Right now _thread_local is not necessary to be
+# an instance member of CollectiveKeys since we always create a new thread for
+# each tower.
_lock = threading.Lock()
+_thread_local = threading.local()
# TODO(yuefengz): use random key starts to avoid reusing keys?
@@ -266,14 +269,12 @@ class CollectiveKeys(object):
# For instance keys without ids
self._instance_key_start = instance_key_start
- self._thread_local = threading.local()
-
def _get_thread_local_object(self):
# We make instance key without key ids thread local so that it will work
# with MirroredStrategy and distribute coordinator.
- if not hasattr(self._thread_local, 'instance_key'):
- self._thread_local.instance_key = self._instance_key_start
- return self._thread_local
+ if not hasattr(_thread_local, 'instance_key'):
+ _thread_local.instance_key = self._instance_key_start
+ return _thread_local
def get_group_key(self, devices):
"""Returns a group key for the set of devices.