diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-05-10 19:33:03 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-05-11 10:48:48 -0700 |
commit | 326ee1883f873695540e7a85f07f52b24e5974a4 (patch) | |
tree | b687070bbe55c4f996a5a96a60f0f70b886eba6a | |
parent | 9d528b8dc22567e239c0eb692eedb1d9e6763fd1 (diff) |
Changes _ReplicaDeviceChooser to not set a task number on ops with a non-PS job.
PiperOrigin-RevId: 155705170
-rw-r--r-- | tensorflow/python/training/device_setter.py | 46 |
-rw-r--r-- | tensorflow/python/training/device_setter_test.py | 44 |
2 files changed, 67 insertions, 23 deletions
diff --git a/tensorflow/python/training/device_setter.py b/tensorflow/python/training/device_setter.py index 85ee10379a..02155a98d7 100644 --- a/tensorflow/python/training/device_setter.py +++ b/tensorflow/python/training/device_setter.py @@ -94,31 +94,31 @@ class _ReplicaDeviceChooser(object): Returns: The device to use for the `Operation`. """ + # If we don't return early here, either merge_devices is True, or op.device + # is empty (in which case merging is a no-op). So we can always merge below. if not self._merge_devices and op.device: return op.device + current_device = pydev.DeviceSpec.from_string(op.device or "") - spec = pydev.DeviceSpec() - if self._ps_tasks and self._ps_device: - node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def - if node_def.op in self._ps_ops: - device_string = "%s/task:%d" % ( - self._ps_device, self._ps_strategy(op)) - if self._merge_devices: - spec = pydev.DeviceSpec.from_string(device_string) - spec.merge_from(current_device) - return spec.to_string() - else: - return device_string - if self._worker_device: - if not self._merge_devices: - return self._worker_device - spec = pydev.DeviceSpec.from_string(self._worker_device) - - if not self._merge_devices: - return "" - - spec.merge_from(current_device) - return spec.to_string() + + # The ps_device will be used for specified ops (ps_ops) whenever it is + # present and ps_tasks is non-zero. However, its task number will only be + # set (using ps_strategy) if there is a job field in ps_device that won't be + # changed by the job field (if present) in current_device. 
+ node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def + if self._ps_tasks and self._ps_device and node_def.op in self._ps_ops: + ps_device = pydev.DeviceSpec.from_string(self._ps_device) + + current_job, ps_job = current_device.job, ps_device.job + if ps_job and (not current_job or current_job == ps_job): + ps_device.task = self._ps_strategy(op) + + ps_device.merge_from(current_device) + return ps_device.to_string() + + worker_device = pydev.DeviceSpec.from_string(self._worker_device or "") + worker_device.merge_from(current_device) + return worker_device.to_string() def replica_device_setter(ps_tasks=0, ps_device="/job:ps", @@ -186,7 +186,7 @@ def replica_device_setter(ps_tasks=0, ps_device="/job:ps", cluster_spec = cluster.as_dict() else: cluster_spec = server_lib.ClusterSpec(cluster).as_dict() - # Get ps_job_name from ps_device by striping "/job:". + # Get ps_job_name from ps_device by stripping "/job:". ps_job_name = pydev.DeviceSpec.from_string(ps_device).job if ps_job_name not in cluster_spec or cluster_spec[ps_job_name] is None: return None diff --git a/tensorflow/python/training/device_setter_test.py b/tensorflow/python/training/device_setter_test.py index bc29e0d21c..85b75502ab 100644 --- a/tensorflow/python/training/device_setter_test.py +++ b/tensorflow/python/training/device_setter_test.py @@ -65,6 +65,50 @@ class DeviceSetterTest(test.TestCase): self.assertDeviceEqual("/job:ps/task:1", w.initializer.device) self.assertDeviceEqual("/job:worker", a.device) + def testPS2TasksPinVariableToJob(self): + with ops.device( + device_setter.replica_device_setter(cluster=self._cluster_spec)): + v = variables.Variable([1, 2]) + with ops.device("/job:moon"): + w = variables.Variable([2, 1]) + with ops.device("/job:ps"): # Explicit PS job will get task set. 
+ x = variables.Variable([0, 1]) + a = v + w + x + self.assertDeviceEqual("/job:ps/task:0", v.device) + self.assertDeviceEqual("/job:ps/task:0", v.initializer.device) + self.assertDeviceEqual("/job:moon", w.device) + self.assertDeviceEqual("/job:moon", w.initializer.device) + self.assertDeviceEqual("/job:ps/task:1", x.device) + self.assertDeviceEqual("/job:ps/task:1", x.initializer.device) + self.assertDeviceEqual("/job:worker", a.device) + + def testPS2TasksUseCpuForPS(self): + with ops.device( + device_setter.replica_device_setter(ps_tasks=1, ps_device="/cpu:0")): + v = variables.Variable([1, 2]) + with ops.device("/job:moon"): + w = variables.Variable([2, 1]) + a = v + w + self.assertDeviceEqual("/cpu:0", v.device) + self.assertDeviceEqual("/cpu:0", v.initializer.device) + self.assertDeviceEqual("/job:moon/cpu:0", w.device) + self.assertDeviceEqual("/job:moon/cpu:0", w.initializer.device) + self.assertDeviceEqual("/job:worker", a.device) + + def testPS2TasksNoMerging(self): + with ops.device( + device_setter.replica_device_setter( + cluster=self._cluster_spec, merge_devices=False)): + v = variables.Variable([1, 2]) + with ops.device("/job:ps"): # Won't assign task when merge_devices=False. + w = variables.Variable([2, 1]) + a = v + w + self.assertDeviceEqual("/job:ps/task:0", v.device) + self.assertDeviceEqual("/job:ps/task:0", v.initializer.device) + self.assertDeviceEqual("/job:ps", w.device) + self.assertDeviceEqual("/job:ps", w.initializer.device) + self.assertDeviceEqual("/job:worker", a.device) + def testPS2TasksWithClusterSpecDict(self): with ops.device( device_setter.replica_device_setter(cluster=self._cluster_spec.as_dict( |