aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/python/training/device_setter.py2
-rw-r--r--tensorflow/python/training/device_setter_test.py7
2 files changed, 8 insertions, 1 deletions
diff --git a/tensorflow/python/training/device_setter.py b/tensorflow/python/training/device_setter.py
index 7f403f4927..85ee10379a 100644
--- a/tensorflow/python/training/device_setter.py
+++ b/tensorflow/python/training/device_setter.py
@@ -198,7 +198,7 @@ def replica_device_setter(ps_tasks=0, ps_device="/job:ps",
if ps_ops is None:
# TODO(sherrym): Variables in the LOCAL_VARIABLES collection should not be
# placed in the parameter server.
- ps_ops = ["Variable", "VariableV2"]
+ ps_ops = ["Variable", "VariableV2", "VarHandleOp"]
if not merge_devices:
logging.warning(
diff --git a/tensorflow/python/training/device_setter_test.py b/tensorflow/python/training/device_setter_test.py
index e05f0f6a1c..bc29e0d21c 100644
--- a/tensorflow/python/training/device_setter_test.py
+++ b/tensorflow/python/training/device_setter_test.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import ops
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import device_setter
@@ -46,6 +47,12 @@ class DeviceSetterTest(test.TestCase):
self.assertDeviceEqual("/job:ps/task:1", w.initializer.device)
self.assertDeviceEqual("/job:worker/cpu:0", a.device)
+ def testResource(self):
+ with ops.device(
+ device_setter.replica_device_setter(cluster=self._cluster_spec)):
+ v = resource_variable_ops.ResourceVariable([1, 2])
+ self.assertDeviceEqual("/job:ps/task:0", v.device)
+
def testPS2TasksWithClusterSpecClass(self):
with ops.device(
device_setter.replica_device_setter(cluster=self._cluster_spec)):