diff options
-rw-r--r-- | tensorflow/python/training/device_setter.py | 2 | ||||
-rw-r--r-- | tensorflow/python/training/device_setter_test.py | 7 |
2 files changed, 8 insertions, 1 deletions
diff --git a/tensorflow/python/training/device_setter.py b/tensorflow/python/training/device_setter.py index 7f403f4927..85ee10379a 100644 --- a/tensorflow/python/training/device_setter.py +++ b/tensorflow/python/training/device_setter.py @@ -198,7 +198,7 @@ def replica_device_setter(ps_tasks=0, ps_device="/job:ps", if ps_ops is None: # TODO(sherrym): Variables in the LOCAL_VARIABLES collection should not be # placed in the parameter server. - ps_ops = ["Variable", "VariableV2"] + ps_ops = ["Variable", "VariableV2", "VarHandleOp"] if not merge_devices: logging.warning( diff --git a/tensorflow/python/training/device_setter_test.py b/tensorflow/python/training/device_setter_test.py index e05f0f6a1c..bc29e0d21c 100644 --- a/tensorflow/python/training/device_setter_test.py +++ b/tensorflow/python/training/device_setter_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.framework import ops +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.training import device_setter @@ -46,6 +47,12 @@ class DeviceSetterTest(test.TestCase): self.assertDeviceEqual("/job:ps/task:1", w.initializer.device) self.assertDeviceEqual("/job:worker/cpu:0", a.device) + def testResource(self): + with ops.device( + device_setter.replica_device_setter(cluster=self._cluster_spec)): + v = resource_variable_ops.ResourceVariable([1, 2]) + self.assertDeviceEqual("/job:ps/task:0", v.device) + def testPS2TasksWithClusterSpecClass(self): with ops.device( device_setter.replica_device_setter(cluster=self._cluster_spec)): |