aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Igor Saprykin <isaprykin@google.com>2017-12-01 19:44:35 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-01 19:47:50 -0800
commit7d82dbb42744a21ff05924e973e57a68465f3347 (patch)
tree219e05f6b75554657073cbc32c1f76e28bb4bc50
parent01f097d789e88c58cfc16d5052e2bb83f6412ef3 (diff)
Fix a replicate_model_fn_test that is dependent on the number of hardware GPUs.
It has been causing failures when run on a machine with 4 GPUs. PiperOrigin-RevId: 177670759
-rw-r--r--tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py2
1 files changed, 1 insertions, 1 deletions
diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py
index 662021853d..91e4b9ba7d 100644
--- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py
@@ -288,7 +288,7 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
with self.test_session() as session:
replicated_model_fn = replicate_model_fn.replicate_model_fn(
- self.model_fn, self.optimizer_fn)
+ self.model_fn, self.optimizer_fn, devices=['/gpu:0'])
estimator_spec = replicated_model_fn(
features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
session.run(variables.global_variables_initializer())