diff options
author | Yuefeng Zhou <yuefengz@google.com> | 2018-09-20 20:25:19 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-20 20:29:28 -0700 |
commit | f10b00558de87020554c9c0512537dab96dba918 (patch) | |
tree | 6d10c5e83eb73d5b452b2ce9be1acaf82b5a7e94 /tensorflow/contrib/distribute | |
parent | f283f3ac5d7b6de8cadc9c1cee6886b187319afd (diff) |
Make threading.local not an instance member of collective ops because in python3 threading.local cannot be pickled.
PiperOrigin-RevId: 213928766
Diffstat (limited to 'tensorflow/contrib/distribute')
-rw-r--r-- | tensorflow/contrib/distribute/python/cross_tower_utils.py | 15 |
1 files changed, 8 insertions, 7 deletions
diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils.py b/tensorflow/contrib/distribute/python/cross_tower_utils.py index 24cb08fb48..9fc1b88955 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_utils.py +++ b/tensorflow/contrib/distribute/python/cross_tower_utils.py @@ -221,9 +221,12 @@ def split_grads_by_size(threshold_size, device_grads): return small_grads, large_grads -# threading.Lock() cannot be pickled and therefore cannot be a field of -# CollectiveKeys. +# threading.Lock() and threading.local() cannot be pickled and therefore cannot +# be a field of CollectiveKeys. Right now _thread_local is not necessary to be +# an instance member of CollectiveKeys since we always create a new thread for +# each tower. _lock = threading.Lock() +_thread_local = threading.local() # TODO(yuefengz): use random key starts to avoid reusing keys? @@ -266,14 +269,12 @@ class CollectiveKeys(object): # For instance keys without ids self._instance_key_start = instance_key_start - self._thread_local = threading.local() - def _get_thread_local_object(self): # We make instance key without key ids thread local so that it will work # with MirroredStrategy and distribute coordinator. - if not hasattr(self._thread_local, 'instance_key'): - self._thread_local.instance_key = self._instance_key_start - return self._thread_local + if not hasattr(_thread_local, 'instance_key'): + _thread_local.instance_key = self._instance_key_start + return _thread_local def get_group_key(self, devices): """Returns a group key for the set of devices. |