diff options
author | JxKing <jinxin900924@gmail.com> | 2018-05-31 19:24:19 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-05-31 19:24:19 +0800 |
commit | 16c42f0d4826b12a5359281997ee3f8e27fd5a87 (patch) | |
tree | 94e1f620ac9c93406a59c8860503e7e7f47976c3 /tensorflow/contrib/opt | |
parent | 6c279ad4055a2d568977a02a2eb3b1303117ac15 (diff) |
fix "workers share local variables" error
Diffstat (limited to 'tensorflow/contrib/opt')
-rw-r--r-- | tensorflow/contrib/opt/python/training/elastic_average_optimizer.py | 7 |
1 files changed, 6 insertions, 1 deletions
diff --git a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py index 545c3477bf..209c4611f3 100644 --- a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py +++ b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py @@ -102,7 +102,12 @@ class ElasticAverageCustomGetter(object): else: kwargs['trainable'] = trainable kwargs['collections'] = collections - return getter(name, *args, **kwargs) + if ops.GraphKeys.LOCAL_VARIABLES in collections: + with ops.device(self._worker_device): + return getter(name, *args, **kwargs) + else: + return getter(name, *args, **kwargs) + class ElasticAverageOptimizer(optimizer.Optimizer): |