aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar RJ Ryan <rryan@alum.mit.edu>2016-04-25 11:49:30 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-04-25 12:52:01 -0700
commitb3470493f49f1232ed51435c607c09b26f87cc57 (patch)
tree03bae1c657ff29db1957345d2634e96d876cc842
parent249664baf82a955708e4e4e22dc080f67fff2fad (diff)
Expose DeviceSpec class in the Python framework API.
* Rename Device to DeviceSpec to avoid confusion with tf.device. * Move tensorflow.python.framework.device.from_string helper function to be a staticmethod of DeviceSpec. This allows Python code to check the validity of device specs, canonicalize them, merge them, etc. Change: 120736225
-rw-r--r--tensorflow/python/client/graph_util.py2
-rw-r--r--tensorflow/python/framework/device.py117
-rw-r--r--tensorflow/python/framework/device_test.py39
-rw-r--r--tensorflow/python/framework/framework_lib.py1
-rw-r--r--tensorflow/python/framework/ops.py2
-rw-r--r--tensorflow/python/framework/ops_test.py10
-rw-r--r--tensorflow/python/framework/test_util.py4
-rw-r--r--tensorflow/python/training/device_setter.py8
8 files changed, 114 insertions, 69 deletions
diff --git a/tensorflow/python/client/graph_util.py b/tensorflow/python/client/graph_util.py
index 969b936393..1a4d916e18 100644
--- a/tensorflow/python/client/graph_util.py
+++ b/tensorflow/python/client/graph_util.py
@@ -58,7 +58,7 @@ def set_cpu0(device_string):
Returns:
A device string.
"""
- parsed_device = pydev.from_string(device_string)
+ parsed_device = pydev.DeviceSpec.from_string(device_string)
parsed_device.device_type = "CPU"
parsed_device.device_index = 0
return parsed_device.to_string()
diff --git a/tensorflow/python/framework/device.py b/tensorflow/python/framework/device.py
index 37557343aa..456df2e6ea 100644
--- a/tensorflow/python/framework/device.py
+++ b/tensorflow/python/framework/device.py
@@ -21,15 +21,52 @@ from __future__ import print_function
import copy
-class Device(object):
- """Represents a Device."""
+class DeviceSpec(object):
+ """Represents a (possibly partial) specification for a TensorFlow device.
+
+ `DeviceSpec`s are used throughout TensorFlow to describe where state is stored
+ and computations occur. Using `DeviceSpec` allows you to parse device spec
+ strings to verify their validity, merge them or compose them programmatically.
+
+ Example:
+ ```python
+ # Place the operations on device "GPU:0" in the "ps" job.
+ device_spec = DeviceSpec(job="ps", device_type="GPU", device_index=0)
+ with tf.device(device_spec):
+ # Both my_var and squared_var will be placed on /job:ps/device:GPU:0.
+ my_var = tf.Variable(..., name="my_variable")
+ squared_var = tf.square(my_var)
+ ```
+
+ If a `DeviceSpec` is partially specified, it will be merged with other
+ `DeviceSpec`s according to the scope in which it is defined. `DeviceSpec`
+ components defined in inner scopes take precedence over those defined in
+ outer scopes.
+
+ ```python
+ with tf.device(DeviceSpec(job="train", )):
+ with tf.device(DeviceSpec(job="ps", device_type="GPU", device_index=0):
+ # Nodes created here will be assigned to /job:ps/device:GPU:0.
+ with tf.device(DeviceSpec(device_type="GPU", device_index=1):
+ # Nodes created here will be assigned to /job:train/device:GPU:1.
+ ```
+
+ A `DeviceSpec` consists of 5 components -- each of
+ which is optionally specified:
+
+ * Job: The job name.
+ * Replica: The replica index.
+ * Task: The task index.
+ * Device type: The device type string (e.g. "CPU" or "GPU").
+ * Device index: The device index.
+ """
def __init__(self, job=None, replica=None, task=None, device_type=None,
device_index=None):
- """Create a new device object.
+ """Create a new `DeviceSpec` object.
Args:
- job: string. Optional device job name.
+ job: string. Optional job name.
replica: int. Optional replica index.
task: int. Optional task index.
device_type: Optional device type string (e.g. "CPU" or "GPU")
@@ -88,7 +125,7 @@ class Device(object):
self._task = None
def parse_from_string(self, spec):
- """Parse a Device name into its components.
+ """Parse a `DeviceSpec` name into its components.
Args:
spec: a string of the form
@@ -99,7 +136,7 @@ class Device(object):
All entries are optional.
Returns:
- The Device, for convenience.
+ The `DeviceSpec`.
Raises:
ValueError: if the spec was not valid.
@@ -135,10 +172,10 @@ class Device(object):
return self
def merge_from(self, dev):
- """Merge the properties of "dev" into this Device.
+ """Merge the properties of "dev" into this `DeviceSpec`.
Args:
- dev: a Device.
+ dev: a `DeviceSpec`.
"""
if dev.job is not None:
self.job = dev.job
@@ -152,11 +189,11 @@ class Device(object):
self.device_index = dev.device_index
def to_string(self):
- """Return a Device specification string.
+ """Return a string representation of this `DeviceSpec`.
Returns:
- a string of the form /job:<name>/replica:<id>/task:<id>/device:cpu:<id>
- or /job:<name>/replica:<id>/task:<id>/device:cpu:<id>.
+ a string of the form
+ /job:<name>/replica:<id>/task:<id>/device:<device_type>:<id>.
"""
dev = ""
if self.job is not None:
@@ -172,22 +209,28 @@ class Device(object):
dev += "/device:%s:%s" % (self.device_type, device_index_string)
return dev
+ @staticmethod
+ def from_string(spec):
+ """Construct a `DeviceSpec` from a string.
+
+ Args:
+ spec: a string of the form
+ /job:<name>/replica:<id>/task:<id>/device:CPU:<id>
+ or
+ /job:<name>/replica:<id>/task:<id>/device:GPU:<id>
+ as cpu and gpu are mutually exclusive.
+ All entries are optional.
-def from_string(spec):
- """Construct a Device from a string.
+ Returns:
+ A DeviceSpec.
+ """
+ return DeviceSpec().parse_from_string(spec)
- Args:
- spec: a string of the form
- /job:<name>/replica:<id>/task:<id>/device:CPU:<id>
- or
- /job:<name>/replica:<id>/task:<id>/device:GPU:<id>
- as cpu and gpu are mutually exclusive.
- All entries are optional.
- Returns:
- A Device.
- """
- return Device().parse_from_string(spec)
+# For backwards compatibility.
+# TODO(rjryan): Fix all callers then remove.
+Device = DeviceSpec
+from_string = DeviceSpec.from_string
def check_valid(spec):
@@ -199,18 +242,18 @@ def check_valid(spec):
Raises:
An exception if the spec is invalid.
"""
- # Construct a device. It will assert a failure if spec is invalid.
- from_string(spec)
+ # Construct a DeviceSpec. It will assert a failure if spec is invalid.
+ DeviceSpec.from_string(spec)
def canonical_name(device):
- """Returns a canonical name for the given device or device name."""
+ """Returns a canonical name for the given `DeviceSpec` or device name."""
if device is None:
return ""
- if isinstance(device, Device):
+ if isinstance(device, DeviceSpec):
return device.to_string()
else:
- device = from_string(device)
+ device = DeviceSpec.from_string(device)
return device.to_string()
@@ -220,17 +263,17 @@ def merge_device(spec):
This can be used to merge partial specifications of devices. The
innermost setting for a device field takes precedence. For example:
- with tf.Device(MergeDevice("/device:GPU:0"))
+ with tf.device(merge_device("/device:GPU:0"))
# Nodes created here have device "/device:GPU:0"
- with tf.Device(MergeDevice("/job:worker")):
+ with tf.device(merge_device("/job:worker")):
# Nodes created here have device "/job:worker/device:GPU:0"
- with tf.Device(MergeDevice("/device:CPU:0")):
+ with tf.device(merge_device("/device:CPU:0")):
# Nodes created here have device "/job:worker/device:CPU:0"
- with tf.Device(MergeDevice("/job:ps")):
+ with tf.device(merge_device("/job:ps")):
# Nodes created here have device "/job:ps/device:CPU:0"
Args:
- spec: A device or a device spec string (partially) describing the
+ spec: A `DeviceSpec` or a device spec string (partially) describing the
device that should be used for all nodes created in the scope of
the returned device function's with block.
@@ -240,10 +283,10 @@ def merge_device(spec):
Raises:
ValueError: if the spec was not valid.
"""
- if not isinstance(spec, Device):
- spec = from_string(spec or "")
+ if not isinstance(spec, DeviceSpec):
+ spec = DeviceSpec.from_string(spec or "")
def _device_function(node_def):
- current_device = from_string(node_def.device or "")
+ current_device = DeviceSpec.from_string(node_def.device or "")
copy_spec = copy.copy(spec)
copy_spec.merge_from(current_device) # current_device takes precedence.
return copy_spec
diff --git a/tensorflow/python/framework/device_test.py b/tensorflow/python/framework/device_test.py
index cb72770291..73cf60679f 100644
--- a/tensorflow/python/framework/device_test.py
+++ b/tensorflow/python/framework/device_test.py
@@ -26,14 +26,14 @@ from tensorflow.python.platform import googletest
class DeviceTest(test_util.TensorFlowTestCase):
def testEmpty(self):
- d = device.Device()
+ d = device.DeviceSpec()
self.assertEquals("", d.to_string())
d.parse_from_string("")
self.assertEquals("", d.to_string())
def testConstructor(self):
- d = device.Device(job="j", replica=0, task=1,
- device_type="CPU", device_index=2)
+ d = device.DeviceSpec(job="j", replica=0, task=1,
+ device_type="CPU", device_index=2)
self.assertEqual("j", d.job)
self.assertEqual(0, d.replica)
self.assertEqual(1, d.task)
@@ -41,11 +41,11 @@ class DeviceTest(test_util.TensorFlowTestCase):
self.assertEqual(2, d.device_index)
self.assertEqual("/job:j/replica:0/task:1/device:CPU:2", d.to_string())
- d = device.Device(device_type="GPU", device_index=0)
+ d = device.DeviceSpec(device_type="GPU", device_index=0)
self.assertEquals("/device:GPU:0", d.to_string())
def testto_string(self):
- d = device.Device()
+ d = device.DeviceSpec()
d.job = "foo"
self.assertEquals("/job:foo", d.to_string())
d.task = 3
@@ -68,11 +68,11 @@ class DeviceTest(test_util.TensorFlowTestCase):
self.assertEquals("/job:foo/replica:12", d.to_string())
# Test wildcard
- d = device.Device(job="foo", replica=12, task=3, device_type="GPU")
+ d = device.DeviceSpec(job="foo", replica=12, task=3, device_type="GPU")
self.assertEquals("/job:foo/replica:12/task:3/device:GPU:*", d.to_string())
def testParse(self):
- d = device.Device()
+ d = device.DeviceSpec()
d.parse_from_string("/job:foo/replica:0")
self.assertEquals("/job:foo/replica:0", d.to_string())
d.parse_from_string("/replica:1/task:0/cpu:0")
@@ -86,33 +86,34 @@ class DeviceTest(test_util.TensorFlowTestCase):
self.assertTrue("Cannot specify multiple device" in str(e.exception))
def testFromString(self):
- d = device.from_string("/job:foo/replica:0")
+ d = device.DeviceSpec.from_string("/job:foo/replica:0")
self.assertEquals("/job:foo/replica:0", d.to_string())
with self.assertRaises(Exception) as e:
- d = device.from_string("/job:muu/gpu:2/cpu:0")
+ d = device.DeviceSpec.from_string("/job:muu/gpu:2/cpu:0")
self.assertTrue("Cannot specify multiple device" in str(e.exception))
- d = device.from_string("/job:foo/replica:0/task:3/cpu:*")
+ d = device.DeviceSpec.from_string("/job:foo/replica:0/task:3/cpu:*")
self.assertEquals(None, d.device_index)
- d = device.from_string("/job:foo/replica:0/task:3/gpu:7")
+ d = device.DeviceSpec.from_string("/job:foo/replica:0/task:3/gpu:7")
self.assertEquals(7, d.device_index)
- d = device.from_string("/job:foo/replica:0/task:3/device:GPU:7")
+ d = device.DeviceSpec.from_string("/job:foo/replica:0/task:3/device:GPU:7")
self.assertEquals(7, d.device_index)
def testMerge(self):
- d = device.from_string("/job:foo/replica:0")
+ d = device.DeviceSpec.from_string("/job:foo/replica:0")
self.assertEquals("/job:foo/replica:0", d.to_string())
- d.merge_from(device.from_string("/task:1/gpu:2"))
+ d.merge_from(device.DeviceSpec.from_string("/task:1/gpu:2"))
self.assertEquals("/job:foo/replica:0/task:1/device:GPU:2", d.to_string())
- d = device.Device()
- d.merge_from(device.from_string("/task:1/cpu:0"))
+ d = device.DeviceSpec()
+ d.merge_from(device.DeviceSpec.from_string("/task:1/cpu:0"))
self.assertEquals("/task:1/device:CPU:0", d.to_string())
- d.merge_from(device.from_string("/job:boo/gpu:0"))
+ d.merge_from(device.DeviceSpec.from_string("/job:boo/gpu:0"))
self.assertEquals("/job:boo/task:1/device:GPU:0", d.to_string())
- d.merge_from(device.from_string("/job:muu/cpu:2"))
+ d.merge_from(device.DeviceSpec.from_string("/job:muu/cpu:2"))
self.assertEquals("/job:muu/task:1/device:CPU:2", d.to_string())
- d.merge_from(device.from_string("/job:muu/device:MyFunnyDevice:2"))
+ d.merge_from(device.DeviceSpec.from_string(
+ "/job:muu/device:MyFunnyDevice:2"))
self.assertEquals("/job:muu/task:1/device:MyFunnyDevice:2", d.to_string())
def testCanonicalName(self):
diff --git a/tensorflow/python/framework/framework_lib.py b/tensorflow/python/framework/framework_lib.py
index f74749ad57..4f83e28fbb 100644
--- a/tensorflow/python/framework/framework_lib.py
+++ b/tensorflow/python/framework/framework_lib.py
@@ -67,6 +67,7 @@ from __future__ import division
from __future__ import print_function
# Classes used when building a Graph.
+from tensorflow.python.framework.device import DeviceSpec
from tensorflow.python.framework.ops import Graph
from tensorflow.python.framework.ops import Operation
from tensorflow.python.framework.ops import Tensor
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index f3ba9687e6..effa64e77c 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -987,7 +987,7 @@ SparseTensorValue = collections.namedtuple("SparseTensorValue",
def _device_string(dev_spec):
- if isinstance(dev_spec, pydev.Device):
+ if isinstance(dev_spec, pydev.DeviceSpec):
return dev_spec.to_string()
else:
return dev_spec
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index 3a51e37e6c..f4e3be800c 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -114,7 +114,7 @@ class NodeDefConstructorTest(test_util.TensorFlowTestCase):
nodedef = ops._NodeDef("foo", "bar", device="/device:baz:*")
self.assertProtoEquals("op:'foo' name:'bar' device:'/device:baz:*'",
nodedef)
- nodedef = ops._NodeDef("foo", "bar", device=pydev.Device(job="j"))
+ nodedef = ops._NodeDef("foo", "bar", device=pydev.DeviceSpec(job="j"))
self.assertProtoEquals("op:'foo' name:'bar' device:'/job:j'", nodedef)
@@ -223,7 +223,8 @@ class OperationTest(test_util.TensorFlowTestCase):
"op:'noop' name:'myop' device:'/job:goo/device:GPU:0' ",
op.node_def)
op = ops.Operation(ops._NodeDef("noop", "op2"), ops.Graph(), [], [])
- op._set_device(pydev.Device(job="muu", device_type="CPU", device_index=0))
+ op._set_device(pydev.DeviceSpec(job="muu", device_type="CPU",
+ device_index=0))
self.assertProtoEquals(
"op:'noop' name:'op2' device:'/job:muu/device:CPU:0'",
op.node_def)
@@ -526,9 +527,8 @@ class DeviceTest(test_util.TensorFlowTestCase):
def testDeviceFull(self):
g = ops.Graph()
- with g.device(pydev.Device(job="worker", replica=2, task=0,
- device_type="CPU",
- device_index=3)):
+ with g.device(pydev.DeviceSpec(job="worker", replica=2, task=0,
+ device_type="CPU", device_index=3)):
g.create_op("an_op", [], [dtypes.float32])
gd = g.as_graph_def()
self.assertProtoEqualsVersion("""
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 624012f4de..15c342f38e 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -548,8 +548,8 @@ class TensorFlowTestCase(googletest.TestCase):
"""Asserts that the two given devices are the same.
Args:
- device1: A string device name or TensorFlow `Device` object.
- device2: A string device name or TensorFlow `Device` object.
+ device1: A string device name or TensorFlow `DeviceSpec` object.
+ device2: A string device name or TensorFlow `DeviceSpec` object.
"""
device1 = pydev.canonical_name(device1)
device2 = pydev.canonical_name(device2)
diff --git a/tensorflow/python/training/device_setter.py b/tensorflow/python/training/device_setter.py
index ce11d6a90e..0accacf8a4 100644
--- a/tensorflow/python/training/device_setter.py
+++ b/tensorflow/python/training/device_setter.py
@@ -73,14 +73,14 @@ class _ReplicaDeviceChooser(object):
"""
if not self._merge_devices and op.device:
return op.device
- current_device = pydev.from_string(op.device or "")
- spec = pydev.Device()
+ current_device = pydev.DeviceSpec.from_string(op.device or "")
+ spec = pydev.DeviceSpec()
if self._ps_tasks and self._ps_device:
node_def = op if isinstance(op, graph_pb2.NodeDef) else op.node_def
if node_def.op in self._ps_ops:
device_string = "%s/task:%d" % (self._ps_device, self._next_task())
if self._merge_devices:
- spec = pydev.from_string(device_string)
+ spec = pydev.DeviceSpec.from_string(device_string)
spec.merge_from(current_device)
return spec.to_string()
else:
@@ -88,7 +88,7 @@ class _ReplicaDeviceChooser(object):
if self._worker_device:
if not self._merge_devices:
return self._worker_device
- spec = pydev.from_string(self._worker_device)
+ spec = pydev.DeviceSpec.from_string(self._worker_device)
if not self._merge_devices:
return ""