aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2017-05-10 19:33:03 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2017-05-11 10:48:48 -0700
commit: 326ee1883f873695540e7a85f07f52b24e5974a4 (patch)
tree: b687070bbe55c4f996a5a96a60f0f70b886eba6a
parent: 9d528b8dc22567e239c0eb692eedb1d9e6763fd1 (diff)
Changes _ReplicaDeviceChooser to not set a task number on ops with a non-PS job.
PiperOrigin-RevId: 155705170
-rw-r--r-- tensorflow/python/training/device_setter.py 46
-rw-r--r-- tensorflow/python/training/device_setter_test.py 44
2 files changed, 67 insertions, 23 deletions
diff --git a/tensorflow/python/training/device_setter.py b/tensorflow/python/training/device_setter.py
index 85ee10379a..02155a98d7 100644
--- a/tensorflow/python/training/device_setter.py
+++ b/tensorflow/python/training/device_setter.py
@@ -94,31 +94,31 @@ class _ReplicaDeviceChooser(object):
Returns:
The device to use for the `Operation`.
"""
+ # If we don't return early here, either merge_devices is True, or op.device
+ # is empty (in which case merging is a no-op). So we can always merge below.
if not self._merge_devices and op.device:
return op.device
+
current_device = pydev.DeviceSpec.from_string(op.device or "")
- spec = pydev.DeviceSpec()
- if self._ps_tasks and self._ps_device:
- node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def
- if node_def.op in self._ps_ops:
- device_string = "%s/task:%d" % (
- self._ps_device, self._ps_strategy(op))
- if self._merge_devices:
- spec = pydev.DeviceSpec.from_string(device_string)
- spec.merge_from(current_device)
- return spec.to_string()
- else:
- return device_string
- if self._worker_device:
- if not self._merge_devices:
- return self._worker_device
- spec = pydev.DeviceSpec.from_string(self._worker_device)
-
- if not self._merge_devices:
- return ""
-
- spec.merge_from(current_device)
- return spec.to_string()
+
+ # The ps_device will be used for specified ops (ps_ops) whenever it is
+ # present and ps_tasks is non-zero. However, its task number will only be
+ # set (using ps_strategy) if there is a job field in ps_device that won't be
+ # changed by the job field (if present) in current_device.
+ node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def
+ if self._ps_tasks and self._ps_device and node_def.op in self._ps_ops:
+ ps_device = pydev.DeviceSpec.from_string(self._ps_device)
+
+ current_job, ps_job = current_device.job, ps_device.job
+ if ps_job and (not current_job or current_job == ps_job):
+ ps_device.task = self._ps_strategy(op)
+
+ ps_device.merge_from(current_device)
+ return ps_device.to_string()
+
+ worker_device = pydev.DeviceSpec.from_string(self._worker_device or "")
+ worker_device.merge_from(current_device)
+ return worker_device.to_string()
def replica_device_setter(ps_tasks=0, ps_device="/job:ps",
@@ -186,7 +186,7 @@ def replica_device_setter(ps_tasks=0, ps_device="/job:ps",
cluster_spec = cluster.as_dict()
else:
cluster_spec = server_lib.ClusterSpec(cluster).as_dict()
- # Get ps_job_name from ps_device by striping "/job:".
+ # Get ps_job_name from ps_device by stripping "/job:".
ps_job_name = pydev.DeviceSpec.from_string(ps_device).job
if ps_job_name not in cluster_spec or cluster_spec[ps_job_name] is None:
return None
diff --git a/tensorflow/python/training/device_setter_test.py b/tensorflow/python/training/device_setter_test.py
index bc29e0d21c..85b75502ab 100644
--- a/tensorflow/python/training/device_setter_test.py
+++ b/tensorflow/python/training/device_setter_test.py
@@ -65,6 +65,50 @@ class DeviceSetterTest(test.TestCase):
self.assertDeviceEqual("/job:ps/task:1", w.initializer.device)
self.assertDeviceEqual("/job:worker", a.device)
+ def testPS2TasksPinVariableToJob(self):
+ with ops.device(
+ device_setter.replica_device_setter(cluster=self._cluster_spec)):
+ v = variables.Variable([1, 2])
+ with ops.device("/job:moon"):
+ w = variables.Variable([2, 1])
+ with ops.device("/job:ps"): # Explicit PS job will get task set.
+ x = variables.Variable([0, 1])
+ a = v + w + x
+ self.assertDeviceEqual("/job:ps/task:0", v.device)
+ self.assertDeviceEqual("/job:ps/task:0", v.initializer.device)
+ self.assertDeviceEqual("/job:moon", w.device)
+ self.assertDeviceEqual("/job:moon", w.initializer.device)
+ self.assertDeviceEqual("/job:ps/task:1", x.device)
+ self.assertDeviceEqual("/job:ps/task:1", x.initializer.device)
+ self.assertDeviceEqual("/job:worker", a.device)
+
+ def testPS2TasksUseCpuForPS(self):
+ with ops.device(
+ device_setter.replica_device_setter(ps_tasks=1, ps_device="/cpu:0")):
+ v = variables.Variable([1, 2])
+ with ops.device("/job:moon"):
+ w = variables.Variable([2, 1])
+ a = v + w
+ self.assertDeviceEqual("/cpu:0", v.device)
+ self.assertDeviceEqual("/cpu:0", v.initializer.device)
+ self.assertDeviceEqual("/job:moon/cpu:0", w.device)
+ self.assertDeviceEqual("/job:moon/cpu:0", w.initializer.device)
+ self.assertDeviceEqual("/job:worker", a.device)
+
+ def testPS2TasksNoMerging(self):
+ with ops.device(
+ device_setter.replica_device_setter(
+ cluster=self._cluster_spec, merge_devices=False)):
+ v = variables.Variable([1, 2])
+ with ops.device("/job:ps"): # Won't assign task when merge_devices=False.
+ w = variables.Variable([2, 1])
+ a = v + w
+ self.assertDeviceEqual("/job:ps/task:0", v.device)
+ self.assertDeviceEqual("/job:ps/task:0", v.initializer.device)
+ self.assertDeviceEqual("/job:ps", w.device)
+ self.assertDeviceEqual("/job:ps", w.initializer.device)
+ self.assertDeviceEqual("/job:worker", a.device)
+
def testPS2TasksWithClusterSpecDict(self):
with ops.device(
device_setter.replica_device_setter(cluster=self._cluster_spec.as_dict(