diff options
author | RJ Ryan <rryan@alum.mit.edu> | 2016-04-25 11:49:30 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-04-25 12:52:01 -0700 |
commit | b3470493f49f1232ed51435c607c09b26f87cc57 (patch) | |
tree | 03bae1c657ff29db1957345d2634e96d876cc842 | |
parent | 249664baf82a955708e4e4e22dc080f67fff2fad (diff) |
Expose DeviceSpec class in the Python framework API.
* Rename Device to DeviceSpec to avoid confusion with tf.device.
* Move tensorflow.python.framework.device.from_string helper function to be a staticmethod of DeviceSpec.
This allows Python code to check the validity of device specs, canonicalize them, merge them, etc.
Change: 120736225
-rw-r--r-- | tensorflow/python/client/graph_util.py | 2 | ||||
-rw-r--r-- | tensorflow/python/framework/device.py | 117 | ||||
-rw-r--r-- | tensorflow/python/framework/device_test.py | 39 | ||||
-rw-r--r-- | tensorflow/python/framework/framework_lib.py | 1 | ||||
-rw-r--r-- | tensorflow/python/framework/ops.py | 2 | ||||
-rw-r--r-- | tensorflow/python/framework/ops_test.py | 10 | ||||
-rw-r--r-- | tensorflow/python/framework/test_util.py | 4 | ||||
-rw-r--r-- | tensorflow/python/training/device_setter.py | 8 |
8 files changed, 114 insertions, 69 deletions
diff --git a/tensorflow/python/client/graph_util.py b/tensorflow/python/client/graph_util.py index 969b936393..1a4d916e18 100644 --- a/tensorflow/python/client/graph_util.py +++ b/tensorflow/python/client/graph_util.py @@ -58,7 +58,7 @@ def set_cpu0(device_string): Returns: A device string. """ - parsed_device = pydev.from_string(device_string) + parsed_device = pydev.DeviceSpec.from_string(device_string) parsed_device.device_type = "CPU" parsed_device.device_index = 0 return parsed_device.to_string() diff --git a/tensorflow/python/framework/device.py b/tensorflow/python/framework/device.py index 37557343aa..456df2e6ea 100644 --- a/tensorflow/python/framework/device.py +++ b/tensorflow/python/framework/device.py @@ -21,15 +21,52 @@ from __future__ import print_function import copy -class Device(object): - """Represents a Device.""" +class DeviceSpec(object): + """Represents a (possibly partial) specification for a TensorFlow device. + + `DeviceSpec`s are used throughout TensorFlow to describe where state is stored + and computations occur. Using `DeviceSpec` allows you to parse device spec + strings to verify their validity, merge them or compose them programmatically. + + Example: + ```python + # Place the operations on device "GPU:0" in the "ps" job. + device_spec = DeviceSpec(job="ps", device_type="GPU", device_index=0) + with tf.device(device_spec): + # Both my_var and squared_var will be placed on /job:ps/device:GPU:0. + my_var = tf.Variable(..., name="my_variable") + squared_var = tf.square(my_var) + ``` + + If a `DeviceSpec` is partially specified, it will be merged with other + `DeviceSpec`s according to the scope in which it is defined. `DeviceSpec` + components defined in inner scopes take precedence over those defined in + outer scopes. + + ```python + with tf.device(DeviceSpec(job="train")): + with tf.device(DeviceSpec(job="ps", device_type="GPU", device_index=0)): + # Nodes created here will be assigned to /job:ps/device:GPU:0. 
+ with tf.device(DeviceSpec(device_type="GPU", device_index=1)): + # Nodes created here will be assigned to /job:train/device:GPU:1. + ``` + + A `DeviceSpec` consists of 5 components -- each of + which is optionally specified: + + * Job: The job name. + * Replica: The replica index. + * Task: The task index. + * Device type: The device type string (e.g. "CPU" or "GPU"). + * Device index: The device index. + """ def __init__(self, job=None, replica=None, task=None, device_type=None, device_index=None): - """Create a new device object. + """Create a new `DeviceSpec` object. Args: - job: string. Optional device job name. + job: string. Optional job name. replica: int. Optional replica index. task: int. Optional task index. device_type: Optional device type string (e.g. "CPU" or "GPU") @@ -88,7 +125,7 @@ class Device(object): self._task = None def parse_from_string(self, spec): - """Parse a Device name into its components. + """Parse a `DeviceSpec` name into its components. Args: spec: a string of the form @@ -99,7 +136,7 @@ class Device(object): All entries are optional. Returns: - The Device, for convenience. + The `DeviceSpec`. Raises: ValueError: if the spec was not valid. @@ -135,10 +172,10 @@ class Device(object): return self def merge_from(self, dev): - """Merge the properties of "dev" into this Device. + """Merge the properties of "dev" into this `DeviceSpec`. Args: - dev: a Device. + dev: a `DeviceSpec`. """ if dev.job is not None: self.job = dev.job @@ -152,11 +189,11 @@ class Device(object): self.device_index = dev.device_index def to_string(self): - """Return a Device specification string. + """Return a string representation of this `DeviceSpec`. Returns: - a string of the form /job:<name>/replica:<id>/task:<id>/device:cpu:<id> - or /job:<name>/replica:<id>/task:<id>/device:cpu:<id>. + a string of the form + /job:<name>/replica:<id>/task:<id>/device:<device_type>:<id>. 
""" dev = "" if self.job is not None: @@ -172,22 +209,28 @@ class Device(object): dev += "/device:%s:%s" % (self.device_type, device_index_string) return dev + @staticmethod + def from_string(spec): + """Construct a `DeviceSpec` from a string. + + Args: + spec: a string of the form + /job:<name>/replica:<id>/task:<id>/device:CPU:<id> + or + /job:<name>/replica:<id>/task:<id>/device:GPU:<id> + as cpu and gpu are mutually exclusive. + All entries are optional. -def from_string(spec): - """Construct a Device from a string. + Returns: + A DeviceSpec. + """ + return DeviceSpec().parse_from_string(spec) - Args: - spec: a string of the form - /job:<name>/replica:<id>/task:<id>/device:CPU:<id> - or - /job:<name>/replica:<id>/task:<id>/device:GPU:<id> - as cpu and gpu are mutually exclusive. - All entries are optional. - Returns: - A Device. - """ - return Device().parse_from_string(spec) +# For backwards compatibility. +# TODO(rjryan): Fix all callers then remove. +Device = DeviceSpec +from_string = DeviceSpec.from_string def check_valid(spec): @@ -199,18 +242,18 @@ def check_valid(spec): Raises: An exception if the spec is invalid. """ - # Construct a device. It will assert a failure if spec is invalid. - from_string(spec) + # Construct a DeviceSpec. It will assert a failure if spec is invalid. + DeviceSpec.from_string(spec) def canonical_name(device): - """Returns a canonical name for the given device or device name.""" + """Returns a canonical name for the given `DeviceSpec` or device name.""" if device is None: return "" - if isinstance(device, Device): + if isinstance(device, DeviceSpec): return device.to_string() else: - device = from_string(device) + device = DeviceSpec.from_string(device) return device.to_string() @@ -220,17 +263,17 @@ def merge_device(spec): This can be used to merge partial specifications of devices. The innermost setting for a device field takes precedence. 
For example: - with tf.Device(MergeDevice("/device:GPU:0")) + with tf.device(merge_device("/device:GPU:0")): # Nodes created here have device "/device:GPU:0" - with tf.Device(MergeDevice("/job:worker")): + with tf.device(merge_device("/job:worker")): # Nodes created here have device "/job:worker/device:GPU:0" - with tf.Device(MergeDevice("/device:CPU:0")): + with tf.device(merge_device("/device:CPU:0")): # Nodes created here have device "/job:worker/device:CPU:0" - with tf.Device(MergeDevice("/job:ps")): + with tf.device(merge_device("/job:ps")): # Nodes created here have device "/job:ps/device:CPU:0" Args: - spec: A device or a device spec string (partially) describing the + spec: A `DeviceSpec` or a device spec string (partially) describing the device that should be used for all nodes created in the scope of the returned device function's with block. @@ -240,10 +283,10 @@ def merge_device(spec): Raises: ValueError: if the spec was not valid. """ - if not isinstance(spec, Device): - spec = from_string(spec or "") + if not isinstance(spec, DeviceSpec): + spec = DeviceSpec.from_string(spec or "") def _device_function(node_def): - current_device = from_string(node_def.device or "") + current_device = DeviceSpec.from_string(node_def.device or "") copy_spec = copy.copy(spec) copy_spec.merge_from(current_device) # current_device takes precedence. 
return copy_spec diff --git a/tensorflow/python/framework/device_test.py b/tensorflow/python/framework/device_test.py index cb72770291..73cf60679f 100644 --- a/tensorflow/python/framework/device_test.py +++ b/tensorflow/python/framework/device_test.py @@ -26,14 +26,14 @@ from tensorflow.python.platform import googletest class DeviceTest(test_util.TensorFlowTestCase): def testEmpty(self): - d = device.Device() + d = device.DeviceSpec() self.assertEquals("", d.to_string()) d.parse_from_string("") self.assertEquals("", d.to_string()) def testConstructor(self): - d = device.Device(job="j", replica=0, task=1, - device_type="CPU", device_index=2) + d = device.DeviceSpec(job="j", replica=0, task=1, + device_type="CPU", device_index=2) self.assertEqual("j", d.job) self.assertEqual(0, d.replica) self.assertEqual(1, d.task) @@ -41,11 +41,11 @@ class DeviceTest(test_util.TensorFlowTestCase): self.assertEqual(2, d.device_index) self.assertEqual("/job:j/replica:0/task:1/device:CPU:2", d.to_string()) - d = device.Device(device_type="GPU", device_index=0) + d = device.DeviceSpec(device_type="GPU", device_index=0) self.assertEquals("/device:GPU:0", d.to_string()) def testto_string(self): - d = device.Device() + d = device.DeviceSpec() d.job = "foo" self.assertEquals("/job:foo", d.to_string()) d.task = 3 @@ -68,11 +68,11 @@ class DeviceTest(test_util.TensorFlowTestCase): self.assertEquals("/job:foo/replica:12", d.to_string()) # Test wildcard - d = device.Device(job="foo", replica=12, task=3, device_type="GPU") + d = device.DeviceSpec(job="foo", replica=12, task=3, device_type="GPU") self.assertEquals("/job:foo/replica:12/task:3/device:GPU:*", d.to_string()) def testParse(self): - d = device.Device() + d = device.DeviceSpec() d.parse_from_string("/job:foo/replica:0") self.assertEquals("/job:foo/replica:0", d.to_string()) d.parse_from_string("/replica:1/task:0/cpu:0") @@ -86,33 +86,34 @@ class DeviceTest(test_util.TensorFlowTestCase): self.assertTrue("Cannot specify multiple device" 
in str(e.exception)) def testFromString(self): - d = device.from_string("/job:foo/replica:0") + d = device.DeviceSpec.from_string("/job:foo/replica:0") self.assertEquals("/job:foo/replica:0", d.to_string()) with self.assertRaises(Exception) as e: - d = device.from_string("/job:muu/gpu:2/cpu:0") + d = device.DeviceSpec.from_string("/job:muu/gpu:2/cpu:0") self.assertTrue("Cannot specify multiple device" in str(e.exception)) - d = device.from_string("/job:foo/replica:0/task:3/cpu:*") + d = device.DeviceSpec.from_string("/job:foo/replica:0/task:3/cpu:*") self.assertEquals(None, d.device_index) - d = device.from_string("/job:foo/replica:0/task:3/gpu:7") + d = device.DeviceSpec.from_string("/job:foo/replica:0/task:3/gpu:7") self.assertEquals(7, d.device_index) - d = device.from_string("/job:foo/replica:0/task:3/device:GPU:7") + d = device.DeviceSpec.from_string("/job:foo/replica:0/task:3/device:GPU:7") self.assertEquals(7, d.device_index) def testMerge(self): - d = device.from_string("/job:foo/replica:0") + d = device.DeviceSpec.from_string("/job:foo/replica:0") self.assertEquals("/job:foo/replica:0", d.to_string()) - d.merge_from(device.from_string("/task:1/gpu:2")) + d.merge_from(device.DeviceSpec.from_string("/task:1/gpu:2")) self.assertEquals("/job:foo/replica:0/task:1/device:GPU:2", d.to_string()) - d = device.Device() - d.merge_from(device.from_string("/task:1/cpu:0")) + d = device.DeviceSpec() + d.merge_from(device.DeviceSpec.from_string("/task:1/cpu:0")) self.assertEquals("/task:1/device:CPU:0", d.to_string()) - d.merge_from(device.from_string("/job:boo/gpu:0")) + d.merge_from(device.DeviceSpec.from_string("/job:boo/gpu:0")) self.assertEquals("/job:boo/task:1/device:GPU:0", d.to_string()) - d.merge_from(device.from_string("/job:muu/cpu:2")) + d.merge_from(device.DeviceSpec.from_string("/job:muu/cpu:2")) self.assertEquals("/job:muu/task:1/device:CPU:2", d.to_string()) - d.merge_from(device.from_string("/job:muu/device:MyFunnyDevice:2")) + 
d.merge_from(device.DeviceSpec.from_string( + "/job:muu/device:MyFunnyDevice:2")) self.assertEquals("/job:muu/task:1/device:MyFunnyDevice:2", d.to_string()) def testCanonicalName(self): diff --git a/tensorflow/python/framework/framework_lib.py b/tensorflow/python/framework/framework_lib.py index f74749ad57..4f83e28fbb 100644 --- a/tensorflow/python/framework/framework_lib.py +++ b/tensorflow/python/framework/framework_lib.py @@ -67,6 +67,7 @@ from __future__ import division from __future__ import print_function # Classes used when building a Graph. +from tensorflow.python.framework.device import DeviceSpec from tensorflow.python.framework.ops import Graph from tensorflow.python.framework.ops import Operation from tensorflow.python.framework.ops import Tensor diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index f3ba9687e6..effa64e77c 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -987,7 +987,7 @@ SparseTensorValue = collections.namedtuple("SparseTensorValue", def _device_string(dev_spec): - if isinstance(dev_spec, pydev.Device): + if isinstance(dev_spec, pydev.DeviceSpec): return dev_spec.to_string() else: return dev_spec diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index 3a51e37e6c..f4e3be800c 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -114,7 +114,7 @@ class NodeDefConstructorTest(test_util.TensorFlowTestCase): nodedef = ops._NodeDef("foo", "bar", device="/device:baz:*") self.assertProtoEquals("op:'foo' name:'bar' device:'/device:baz:*'", nodedef) - nodedef = ops._NodeDef("foo", "bar", device=pydev.Device(job="j")) + nodedef = ops._NodeDef("foo", "bar", device=pydev.DeviceSpec(job="j")) self.assertProtoEquals("op:'foo' name:'bar' device:'/job:j'", nodedef) @@ -223,7 +223,8 @@ class OperationTest(test_util.TensorFlowTestCase): "op:'noop' name:'myop' 
device:'/job:goo/device:GPU:0' ", op.node_def) op = ops.Operation(ops._NodeDef("noop", "op2"), ops.Graph(), [], []) - op._set_device(pydev.Device(job="muu", device_type="CPU", device_index=0)) + op._set_device(pydev.DeviceSpec(job="muu", device_type="CPU", + device_index=0)) self.assertProtoEquals( "op:'noop' name:'op2' device:'/job:muu/device:CPU:0'", op.node_def) @@ -526,9 +527,8 @@ class DeviceTest(test_util.TensorFlowTestCase): def testDeviceFull(self): g = ops.Graph() - with g.device(pydev.Device(job="worker", replica=2, task=0, - device_type="CPU", - device_index=3)): + with g.device(pydev.DeviceSpec(job="worker", replica=2, task=0, + device_type="CPU", device_index=3)): g.create_op("an_op", [], [dtypes.float32]) gd = g.as_graph_def() self.assertProtoEqualsVersion(""" diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 624012f4de..15c342f38e 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -548,8 +548,8 @@ class TensorFlowTestCase(googletest.TestCase): """Asserts that the two given devices are the same. Args: - device1: A string device name or TensorFlow `Device` object. - device2: A string device name or TensorFlow `Device` object. + device1: A string device name or TensorFlow `DeviceSpec` object. + device2: A string device name or TensorFlow `DeviceSpec` object. 
""" device1 = pydev.canonical_name(device1) device2 = pydev.canonical_name(device2) diff --git a/tensorflow/python/training/device_setter.py b/tensorflow/python/training/device_setter.py index ce11d6a90e..0accacf8a4 100644 --- a/tensorflow/python/training/device_setter.py +++ b/tensorflow/python/training/device_setter.py @@ -73,14 +73,14 @@ class _ReplicaDeviceChooser(object): """ if not self._merge_devices and op.device: return op.device - current_device = pydev.from_string(op.device or "") - spec = pydev.Device() + current_device = pydev.DeviceSpec.from_string(op.device or "") + spec = pydev.DeviceSpec() if self._ps_tasks and self._ps_device: node_def = op if isinstance(op, graph_pb2.NodeDef) else op.node_def if node_def.op in self._ps_ops: device_string = "%s/task:%d" % (self._ps_device, self._next_task()) if self._merge_devices: - spec = pydev.from_string(device_string) + spec = pydev.DeviceSpec.from_string(device_string) spec.merge_from(current_device) return spec.to_string() else: @@ -88,7 +88,7 @@ class _ReplicaDeviceChooser(object): if self._worker_device: if not self._merge_devices: return self._worker_device - spec = pydev.from_string(self._worker_device) + spec = pydev.DeviceSpec.from_string(self._worker_device) if not self._merge_devices: return "" |