about summary refs log tree commit diff homepage
path: root/tensorflow/contrib/opt
diff options
context:
space:
mode:
authorGravatar JxKing <jinxin900924@gmail.com>2018-05-31 19:24:19 +0800
committerGravatar GitHub <noreply@github.com>2018-05-31 19:24:19 +0800
commit16c42f0d4826b12a5359281997ee3f8e27fd5a87 (patch)
tree94e1f620ac9c93406a59c8860503e7e7f47976c3 /tensorflow/contrib/opt
parent6c279ad4055a2d568977a02a2eb3b1303117ac15 (diff)
fix "workers share local variables" error
Diffstat (limited to 'tensorflow/contrib/opt')
-rw-r--r--tensorflow/contrib/opt/python/training/elastic_average_optimizer.py7
1 file changed, 6 insertions, 1 deletion
diff --git a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py
index 545c3477bf..209c4611f3 100644
--- a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py
+++ b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py
@@ -102,7 +102,12 @@ class ElasticAverageCustomGetter(object):
else:
kwargs['trainable'] = trainable
kwargs['collections'] = collections
- return getter(name, *args, **kwargs)
+ if ops.GraphKeys.LOCAL_VARIABLES in collections:
+ with ops.device(self._worker_device):
+ return getter(name, *args, **kwargs)
+ else:
+ return getter(name, *args, **kwargs)
+
class ElasticAverageOptimizer(optimizer.Optimizer):