diff options
author | Igor Saprykin <isaprykin@google.com> | 2017-12-01 19:44:35 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-12-01 19:47:50 -0800 |
commit | 7d82dbb42744a21ff05924e973e57a68465f3347 (patch) | |
tree | 219e05f6b75554657073cbc32c1f76e28bb4bc50 | |
parent | 01f097d789e88c58cfc16d5052e2bb83f6412ef3 (diff) |
Fix a replicate_model_fn_test that is dependent on the number of hardware GPUs.
It has been causing failures when run on a machine with 4 GPUs.
PiperOrigin-RevId: 177670759
-rw-r--r-- | tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py | 2 |
1 files changed, 1 insertions, 1 deletions
diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py index 662021853d..91e4b9ba7d 100644 --- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py +++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py @@ -288,7 +288,7 @@ class ReplicateModelTest(test_util.TensorFlowTestCase): with self.test_session() as session: replicated_model_fn = replicate_model_fn.replicate_model_fn( - self.model_fn, self.optimizer_fn) + self.model_fn, self.optimizer_fn, devices=['/gpu:0']) estimator_spec = replicated_model_fn( features, labels, model_fn_lib.ModeKeys.TRAIN, self.params) session.run(variables.global_variables_initializer()) |