aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Igor Saprykin <isaprykin@google.com>2018-02-05 17:24:31 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-05 17:28:20 -0800
commit78af5c7e2cf9e7a09acfceb368d6bd3818405353 (patch)
treed06c9aa7220cf10fdd88a7833de77364216a0424
parent238bae4426729a0af555aecfda5e2de123443ccd (diff)
Correctly treat "devices=/gpu:0" argument of replicate_model_fn.
At the moment if "devices=/GPU:0" are specified by the user, then variables are going to be placed on the GPU. However, if "devices=/gpu:0" are given, then they are going to be placed on the CPU. Instead, the latter case should be equivalent to the former case. PiperOrigin-RevId: 184612823
-rw-r--r--tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py2
-rw-r--r--tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py28
2 files changed, 29 insertions, 1 deletions
diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py
index c9153c9352..dfae034afc 100644
--- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py
+++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py
@@ -195,7 +195,7 @@ def _replicate_model_fn_with_mode(
if not devices:
devices = _get_local_devices('GPU') or _get_local_devices('CPU')
- is_a_single_gpu_case = len(devices) == 1 and 'GPU' in devices[0]
+ is_a_single_gpu_case = len(devices) == 1 and 'GPU' in devices[0].upper()
consolidation_device = devices[0] if is_a_single_gpu_case else '/CPU:0'
ps_devices = [consolidation_device]
diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py
index 6936f8a131..ab117e61a7 100644
--- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py
@@ -451,6 +451,34 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
_ = replicate_model_fn.replicate_model_fn(self.model_fn,
losses.Reduction.NONE)
+ def test_places_on_gpu_with_upper_case_spelling(self):
+ features = np.array([[0.01], [0.002]])
+ labels = np.array([[0.01], [0.02]])
+
+ with self.test_session():
+ replicated_model_fn = replicate_model_fn.replicate_model_fn(
+ self.model_fn, devices=['/GPU:0'])
+ _ = replicated_model_fn(
+ features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
+
+ with variable_scope.variable_scope('', reuse=True):
+ c = variable_scope.get_variable('c', dtype=dtypes.float64)
+ self.assertEqual('/device:GPU:0', c.device)
+
+ def test_places_on_gpu_with_lower_case_spelling(self):
+ features = np.array([[0.01], [0.002]])
+ labels = np.array([[0.01], [0.02]])
+
+ with self.test_session():
+ replicated_model_fn = replicate_model_fn.replicate_model_fn(
+ self.model_fn, devices=['/gpu:0'])
+ _ = replicated_model_fn(
+ features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
+
+ with variable_scope.variable_scope('', reuse=True):
+ c = variable_scope.get_variable('c', dtype=dtypes.float64)
+ self.assertEqual('/device:GPU:0', c.device)
+
class ReplicateAcrossASingleDeviceWithoutTowerOptimizer(
test_util.TensorFlowTestCase):