author    Derek Murray <mrry@google.com>          2016-02-05 18:19:07 -0800
committer Manjunath Kudlur <keveman@gmail.com>    2016-02-06 08:46:41 -0800
commit    3fa9676bbc41826689e9b0e11a45e3fbdceae258 (patch)
tree      dcb787db2d252ccced947e0d5cf2f7733c931668 /tensorflow
parent    241698b6ba6cd9b13d606a9e4603baa4f33891f2 (diff)
Consolidate the device function and device string handling in `tf.device()`.
The effect of this CL is to treat `with tf.device(device_name):` as supplying
a device function that *merges* `device_name` into the device of ops created
in that scope. (Merging is defined by
`tensorflow.python.framework.device.merge_device()`: essentially, for each
field defined in `device_name`, the merge function sets an op's device to
that if it has not already been set.) This makes it possible to compose
device blocks that set different parts of a device, and use device strings in
composition with device functions.

A secondary effect of this CL is that it causes `with tf.device(None):` to
interoperate properly with device functions. As with other `tf.Graph`
contexts, entering a `with tf.device(None):` now has the effect of ignoring
all currently set device functions in the outer context.

This CL makes some breaking changes to corner cases in the `tf.device()`,
`tf.Graph`, `tf.Operation`, and `tf.Tensor` APIs:

* Within a `with tf.device(device_string):` scope, the given device string
  will now be *merged* into the device for ops created in that scope. See the
  implementation of `tensorflow.python.framework.device.merge_device()` for
  details. Previously, device strings were maintained in a single "default
  device" field, rather than a stack, so device strings from outer contexts
  would be completely ignored. To obtain the previous behavior, use
  `with tf.device(None), tf.device(device_string):` instead.

* Within a `with tf.Graph.device(None):` scope, no device functions from the
  outer context will be executed. Previously, the `None` applied only to the
  device string, and all device functions would be applied unconditionally.

* The `tf.Graph.get_default_device()` method is removed, because it no longer
  has a well-defined meaning. To create a no-op device scope, you can simply
  use `with tf.device(""):`.

* The `tf.Operation.device` and `tf.Tensor.device` properties now return an
  empty string when no device has been set for an op. This makes it easier to
  write code like `with tf.device(op.device):`, which is robust to `op`
  having or not having a device (in which case the scope should be a no-op).

Change: 114003979
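For illustration, a minimal sketch of the merge semantics described above,
written against the Python API of this era (the output comments follow from
the tests added in this CL; the script itself is not part of the change):

```python
import tensorflow as tf

with tf.Graph().as_default():
  with tf.device("/job:worker"):
    with tf.device("/device:GPU:0"):
      # The inner spec is *merged* into the outer one: fields the inner
      # spec leaves unset (here, the job) are inherited from the scope.
      a = tf.constant(1.0)
    with tf.device(None):
      # None now ignores every device scope from the enclosing context.
      b = tf.constant(2.0)

  print(a.op.device)  # "/job:worker/device:GPU:0"
  print(b.op.device)  # "" (an empty string, not None, after this CL)
```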
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/python/BUILD                                     |  12
-rw-r--r--  tensorflow/python/client/graph_util_test.py                 |  84
-rw-r--r--  tensorflow/python/framework/device.py                       |  11
-rw-r--r--  tensorflow/python/framework/device_test.py                  |  37
-rw-r--r--  tensorflow/python/framework/ops.py                          |  90
-rw-r--r--  tensorflow/python/framework/ops_test.py                     | 122
-rw-r--r--  tensorflow/python/framework/test_util.py                    |  13
-rw-r--r--  tensorflow/python/kernel_tests/control_flow_ops_py_test.py  |   6
-rw-r--r--  tensorflow/python/kernel_tests/init_ops_test.py             |   8
-rw-r--r--  tensorflow/python/ops/control_flow_ops.py                   |   3
-rw-r--r--  tensorflow/python/ops/gradients.py                          |   2
-rw-r--r--  tensorflow/python/ops/gradients_test.py                     |   6
-rw-r--r--  tensorflow/python/ops/numerics.py                           |   2
-rw-r--r--  tensorflow/python/ops/op_def_library_test.py                |   4
-rw-r--r--  tensorflow/python/ops/state_ops.py                          |   2
-rw-r--r--  tensorflow/python/training/moving_averages_test.py          |  12
-rw-r--r--  tensorflow/python/training/saver.py                         |   4
17 files changed, 274 insertions, 144 deletions
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 114aba2434..8bb9f892ef 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -343,6 +343,18 @@ py_test(
)
py_test(
+ name = "framework_device_test",
+ srcs = ["framework/device_test.py"],
+ main = "framework/device_test.py",
+ srcs_version = "PY2AND3",
+ deps = [
+ ":framework_test_lib",
+ ":platform_test",
+ "//tensorflow/core:protos_all_py",
+ ],
+)
+
+py_test(
name = "framework_tensor_shape_div_test",
srcs = ["framework/tensor_shape_div_test.py"],
main = "framework/tensor_shape_div_test.py",
diff --git a/tensorflow/python/client/graph_util_test.py b/tensorflow/python/client/graph_util_test.py
index 73265361cd..49379ffab1 100644
--- a/tensorflow/python/client/graph_util_test.py
+++ b/tensorflow/python/client/graph_util_test.py
@@ -31,10 +31,9 @@ from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops
# pylint: enable=unused-import
from tensorflow.python.ops import state_ops
-from tensorflow.python.platform import googletest
-class DeviceFunctionsTest(googletest.TestCase):
+class DeviceFunctionsTest(tf.test.TestCase):
def testPinToCpu(self):
with ops.Graph().as_default() as g, g.device(graph_util.pin_to_cpu):
@@ -48,14 +47,14 @@ class DeviceFunctionsTest(googletest.TestCase):
[[0, 1, 2], [2, 3]], [[12, 23, 34], [1, 2]])
dynamic_stitch_float_result = data_flow_ops.dynamic_stitch(
[[0, 1, 2], [2, 3]], [[12.0, 23.0, 34.0], [1.0, 2.0]])
- self.assertEqual(const_a.device, "/device:CPU:0")
- self.assertEqual(const_b.device, "/device:CPU:0")
- self.assertEqual(add_c.device, "/device:CPU:0")
- self.assertEqual(var_v.device, "/device:CPU:0")
- self.assertEqual(assign_c_to_v.device, "/device:CPU:0")
- self.assertEqual(const_string.device, "/device:CPU:0")
- self.assertEqual(dynamic_stitch_int_result.device, "/device:CPU:0")
- self.assertEqual(dynamic_stitch_float_result.device, "/device:CPU:0")
+ self.assertDeviceEqual(const_a.device, "/device:CPU:0")
+ self.assertDeviceEqual(const_b.device, "/device:CPU:0")
+ self.assertDeviceEqual(add_c.device, "/device:CPU:0")
+ self.assertDeviceEqual(var_v.device, "/device:CPU:0")
+ self.assertDeviceEqual(assign_c_to_v.device, "/device:CPU:0")
+ self.assertDeviceEqual(const_string.device, "/device:CPU:0")
+ self.assertDeviceEqual(dynamic_stitch_int_result.device, "/device:CPU:0")
+ self.assertDeviceEqual(dynamic_stitch_float_result.device, "/device:CPU:0")
def testPinRequiredOpsOnCPU(self):
with ops.Graph().as_default() as g, g.device(
@@ -70,12 +69,12 @@ class DeviceFunctionsTest(googletest.TestCase):
dynamic_stitch_float_result = data_flow_ops.dynamic_stitch(
[[0, 1, 2], [2, 3]], [[12.0, 23.0, 34.0], [1.0, 2.0]])
# Non-variable ops should not specify a device
- self.assertEqual(const_a.device, None)
- self.assertEqual(const_b.device, None)
- self.assertEqual(add_c.device, None)
+ self.assertDeviceEqual(const_a.device, None)
+ self.assertDeviceEqual(const_b.device, None)
+ self.assertDeviceEqual(add_c.device, None)
# Variable ops specify a device
- self.assertEqual(var_v.device, "/device:CPU:0")
- self.assertEqual(assign_c_to_v.device, "/device:CPU:0")
+ self.assertDeviceEqual(var_v.device, "/device:CPU:0")
+ self.assertDeviceEqual(assign_c_to_v.device, "/device:CPU:0")
def testTwoDeviceFunctions(self):
with ops.Graph().as_default() as g:
@@ -90,13 +89,28 @@ class DeviceFunctionsTest(googletest.TestCase):
var_5 = state_ops.variable_op([1], dtype=dtypes.float32)
var_6 = state_ops.variable_op([1], dtype=dtypes.float32)
- self.assertEqual(var_0.device, None)
- self.assertEqual(var_1.device, "/device:CPU:0")
- self.assertEqual(var_2.device, None)
- self.assertEqual(var_3.device, None)
- self.assertEqual(var_4.device, "/device:CPU:0")
- self.assertEqual(var_5.device, "/device:GPU:0")
- self.assertEqual(var_6.device, "/device:CPU:0")
+ self.assertDeviceEqual(var_0.device, None)
+ self.assertDeviceEqual(var_1.device, "/device:CPU:0")
+ self.assertDeviceEqual(var_2.device, None)
+ self.assertDeviceEqual(var_3.device, None)
+ self.assertDeviceEqual(var_4.device, "/device:CPU:0")
+ self.assertDeviceEqual(var_5.device, "/device:GPU:0")
+ self.assertDeviceEqual(var_6.device, "/device:CPU:0")
+
+ def testNestedDeviceFunctions(self):
+ with tf.Graph().as_default():
+ var_0 = tf.Variable(0)
+ with tf.device(graph_util.pin_variables_on_cpu):
+ var_1 = tf.Variable(1)
+ with tf.device(lambda op: "/gpu:0"):
+ var_2 = tf.Variable(2)
+ with tf.device("/gpu:0"): # Implicit merging device function.
+ var_3 = tf.Variable(3)
+
+ self.assertDeviceEqual(var_0.device, None)
+ self.assertDeviceEqual(var_1.device, "/device:CPU:0")
+ self.assertDeviceEqual(var_2.device, "/device:GPU:0")
+ self.assertDeviceEqual(var_3.device, "/device:GPU:0")
def testExplicitDevice(self):
with ops.Graph().as_default() as g:
@@ -112,12 +126,12 @@ class DeviceFunctionsTest(googletest.TestCase):
with g.device("/job:ps"):
const_5 = constant_op.constant(5.0)
- self.assertEqual(const_0.device, None)
- self.assertEqual(const_1.device, "/device:GPU:0")
- self.assertEqual(const_2.device, "/device:GPU:1")
- self.assertEqual(const_3.device, "/device:CPU:0")
- self.assertEqual(const_4.device, "/device:CPU:1")
- self.assertEqual(const_5.device, "/job:ps")
+ self.assertDeviceEqual(const_0.device, None)
+ self.assertDeviceEqual(const_1.device, "/device:GPU:0")
+ self.assertDeviceEqual(const_2.device, "/device:GPU:1")
+ self.assertDeviceEqual(const_3.device, "/device:CPU:0")
+ self.assertDeviceEqual(const_4.device, "/device:CPU:1")
+ self.assertDeviceEqual(const_5.device, "/job:ps")
def testDefaultDevice(self):
with ops.Graph().as_default() as g, g.device(
@@ -135,12 +149,12 @@ class DeviceFunctionsTest(googletest.TestCase):
with g.device("/replica:0"):
const_5 = constant_op.constant(5.0)
- self.assertEqual(const_0.device, "/job:ps")
- self.assertEqual(const_1.device, "/device:GPU:0")
- self.assertEqual(const_2.device, "/device:GPU:1")
- self.assertEqual(const_3.device, "/device:CPU:0")
- self.assertEqual(const_4.device, "/device:CPU:1")
- self.assertEqual(const_5.device, "/replica:0")
+ self.assertDeviceEqual(const_0.device, "/job:ps")
+ self.assertDeviceEqual(const_1.device, "/device:GPU:0")
+ self.assertDeviceEqual(const_2.device, "/device:GPU:1")
+ self.assertDeviceEqual(const_3.device, "/device:CPU:0")
+ self.assertDeviceEqual(const_4.device, "/device:CPU:1")
+ self.assertDeviceEqual(const_5.device, "/replica:0")
def testExtractSubGraph(self):
graph_def = tf.GraphDef()
@@ -172,4 +186,4 @@ class DeviceFunctionsTest(googletest.TestCase):
if __name__ == "__main__":
- googletest.main()
+ tf.test.main()
diff --git a/tensorflow/python/framework/device.py b/tensorflow/python/framework/device.py
index 557b9544fb..37557343aa 100644
--- a/tensorflow/python/framework/device.py
+++ b/tensorflow/python/framework/device.py
@@ -203,6 +203,17 @@ def check_valid(spec):
from_string(spec)
+def canonical_name(device):
+ """Returns a canonical name for the given device or device name."""
+ if device is None:
+ return ""
+ if isinstance(device, Device):
+ return device.to_string()
+ else:
+ device = from_string(device)
+ return device.to_string()
+
+
def merge_device(spec):
"""Returns a device function that merges devices specifications.
diff --git a/tensorflow/python/framework/device_test.py b/tensorflow/python/framework/device_test.py
index ded5bec1d6..f3c22acd7a 100644
--- a/tensorflow/python/framework/device_test.py
+++ b/tensorflow/python/framework/device_test.py
@@ -29,9 +29,9 @@ class DeviceTest(test_util.TensorFlowTestCase):
def testEmpty(self):
d = device.Device()
- self.assertEquals("", d.ToString())
+ self.assertEquals("", d.to_string())
d.parse_from_string("")
- self.assertEquals("", d.ToString())
+ self.assertEquals("", d.to_string())
def testConstructor(self):
d = device.Device(job="j", replica=0, task=1,
@@ -117,23 +117,46 @@ class DeviceTest(test_util.TensorFlowTestCase):
d.merge_from(device.from_string("/job:muu/device:MyFunnyDevice:2"))
self.assertEquals("/job:muu/task:1/device:MyFunnyDevice:2", d.to_string())
+ def testCanonicalName(self):
+ self.assertEqual("/job:foo/replica:0",
+ device.canonical_name("/job:foo/replica:0"))
+ self.assertEqual("/job:foo/replica:0",
+ device.canonical_name("/replica:0/job:foo"))
+
+ self.assertEqual("/job:foo/replica:0/task:0",
+ device.canonical_name("/job:foo/replica:0/task:0"))
+ self.assertEqual("/job:foo/replica:0/task:0",
+ device.canonical_name("/job:foo/task:0/replica:0"))
+
+ self.assertEqual("/device:CPU:0",
+ device.canonical_name("/device:CPU:0"))
+ self.assertEqual("/device:GPU:2",
+ device.canonical_name("/device:GPU:2"))
+
+ self.assertEqual("/job:foo/replica:0/task:0/device:GPU:0",
+ device.canonical_name(
+ "/job:foo/replica:0/task:0/gpu:0"))
+ self.assertEqual("/job:foo/replica:0/task:0/device:GPU:0",
+ device.canonical_name(
+ "/gpu:0/task:0/replica:0/job:foo"))
+
def testCheckValid(self):
- device.CheckValid("/job:foo/replica:0")
+ device.check_valid("/job:foo/replica:0")
with self.assertRaises(Exception) as e:
- device.CheckValid("/job:j/replica:foo")
+ device.check_valid("/job:j/replica:foo")
self.assertTrue("invalid literal for int" in e.exception.message)
with self.assertRaises(Exception) as e:
- device.CheckValid("/job:j/task:bar")
+ device.check_valid("/job:j/task:bar")
self.assertTrue("invalid literal for int" in e.exception.message)
with self.assertRaises(Exception) as e:
- device.CheckValid("/bar:muu/baz:2")
+ device.check_valid("/bar:muu/baz:2")
self.assertTrue("Unknown attribute: 'bar'" in e.exception.message)
with self.assertRaises(Exception) as e:
- device.CheckValid("/cpu:0/gpu:2")
+ device.check_valid("/cpu:0/gpu:2")
self.assertTrue("Cannot specify multiple device" in e.exception.message)
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 981a70f40d..17a6d19a0f 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -1135,10 +1135,10 @@ class Operation(object):
Returns:
The string name of the device to which this op has been
- assigned, or None if it has not been assigned to a device.
+ assigned, or an empty string if it has not been assigned to a
+ device.
"""
- dev = self._node_def.device
- return None if not dev else dev
+ return self._node_def.device
def _set_device(self, device):
"""Set the device of this operation.
@@ -1737,7 +1737,6 @@ class Graph(object):
@@get_tensor_by_name
@@get_operations
- @@get_default_device
@@seed
@@unique_name
@@version
@@ -1757,8 +1756,6 @@ class Graph(object):
self._name_stack = ("", "")
# Maps a name used in the graph to the next id to use for that name.
self._names_in_use = {}
- # Default device applied to new ops.
- self._default_device = None
# Functions that will be applied to choose a device if none is specified.
self._device_function_stack = []
# Default original_op applied to new ops.
@@ -2013,8 +2010,7 @@ class Graph(object):
else:
name = self.unique_name(name)
- node_def = _NodeDef(
- op_type, name, device=self._default_device or None, attrs=attrs)
+ node_def = _NodeDef(op_type, name, device=None, attrs=attrs)
# Apply a kernel label if one has been specified for this op_type.
try:
@@ -2046,6 +2042,8 @@ class Graph(object):
# We apply here because the result can depend on the Operation's
# signature, which is computed in the Operation constructor.
for device_function in reversed(self._device_function_stack):
+ if device_function is None:
+ break
ret._set_device(device_function(ret))
return ret
@@ -2489,54 +2487,6 @@ class Graph(object):
else:
return name
- def _set_default_device(self, dev):
- """Set the default device properties.
-
- Args:
- dev: string or Device.
- """
- self._default_device = _device_string(dev)
-
- def get_default_device(self):
- """Returns the default device.
-
- Returns:
- A string.
- """
- return self._default_device
-
- def _push_default_device_function(self, device_function):
- """Pushes the given function onto the stack of device functions.
-
- See `Graph.device` for more details.
-
- Args:
- device_function: The function to be pushed onto the stack of device
- functions.
- """
- self._device_function_stack.append(device_function)
-
- def _pop_default_device_function(self, device_function):
- """Pops the given function from the stack of device functions.
-
- See `Graph.device` for more details.
-
- Args:
- device_function: The function to be popped from the stack of device
- functions.
-
- Raises:
- ValueError: if the device_function to be popped is not top of the stack,
- or if the stack is empty.
- """
- if not self._device_function_stack:
- raise ValueError("Tried to pop, but the device function stack is empty")
- if self._device_function_stack[-1] is not device_function:
- raise ValueError("Tried to pop device function, but it was not on top "
- "of the stack")
-
- self._device_function_stack.pop()
-
@contextlib.contextmanager
def device(self, device_name_or_function):
"""Returns a context manager that specifies the default device to use.
@@ -2545,12 +2495,14 @@ class Graph(object):
string, a device function, or None:
* If it is a device name string, all operations constructed in
- this context will be assigned to the device with that name.
+ this context will be assigned to the device with that name, unless
+ overridden by a nested `device()` context.
* If it is a function, it will be treated as function from
Operation objects to device name strings, and invoked each time
a new Operation is created. The Operation will be assigned to
the device with the returned name.
- * If it is None, the default device will be cleared.
+ * If it is None, all `device()` invocations from the enclosing context
+ will be ignored.
For example:
@@ -2583,19 +2535,17 @@ class Graph(object):
A context manager that specifies the default device to use for newly
created ops.
"""
- if callable(device_name_or_function):
- try:
- self._push_default_device_function(device_name_or_function)
- yield
- finally:
- self._pop_default_device_function(device_name_or_function)
+ if (device_name_or_function is not None
+ and not callable(device_name_or_function)):
+ device_function = pydev.merge_device(device_name_or_function)
else:
- try:
- old_dev = self.get_default_device()
- self._set_default_device(_device_string(device_name_or_function))
- yield
- finally:
- self._set_default_device(old_dev)
+ device_function = device_name_or_function
+
+ try:
+ self._device_function_stack.append(device_function)
+ yield
+ finally:
+ self._device_function_stack.pop()
class _ControlDependenciesController(object):
"""Context manager for `control_dependencies()`."""
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index 7676b0c239..e5fa618675 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -256,7 +256,7 @@ class CreateOpTest(test_util.TensorFlowTestCase):
def testNodeDefArgs(self):
g = ops.Graph()
op1 = g.create_op("const", [], [dtypes.float32], None, name="myop1")
- with g.device("/device:GPU"):
+ with g.device("/device:GPU:0"):
op2 = g.create_op("add",
[],
[dtypes.float32, dtypes.string], None,
@@ -267,11 +267,11 @@ class CreateOpTest(test_util.TensorFlowTestCase):
[dtypes.float32, dtypes.int32],
None,
name="myop3")
- self.assertEqual(None, op1.device)
- self.assertEqual("/device:GPU", op2.device)
- self.assertEqual(None, op3.device)
+ self.assertDeviceEqual(None, op1.device)
+ self.assertDeviceEqual("/device:GPU:0", op2.device)
+ self.assertDeviceEqual(None, op3.device)
self.assertProtoEquals("name:'myop1' op:'const'", op1.node_def)
- self.assertProtoEquals("name:'myop2' op:'add' device:'/device:GPU'",
+ self.assertProtoEquals("name:'myop2' op:'add' device:'/device:GPU:0'",
op2.node_def)
self.assertProtoEquals(
"name:'myop3' input:'myop1' input:'myop2:1' input:'myop2' op:'foo'",
@@ -305,7 +305,7 @@ class ApplyOpTest(test_util.TensorFlowTestCase):
def testNodeDefArgs(self):
g = ops.Graph()
t1 = _apply_op(g, "const", [], [dtypes.float32], name="myop1")
- with g.device("/device:GPU"):
+ with g.device("/device:GPU:0"):
t2 = _apply_op(g, "add",
[],
[dtypes.float32, dtypes.string],
@@ -322,7 +322,7 @@ class ApplyOpTest(test_util.TensorFlowTestCase):
self.assertEqual("myop3", t3[0]._as_node_def_input())
# Validate that we got the right ops as well
self.assertProtoEquals("name:'myop1' op:'const'", t1.op.node_def)
- self.assertProtoEquals("name:'myop2' op:'add' device:'/device:GPU'",
+ self.assertProtoEquals("name:'myop2' op:'add' device:'/device:GPU:0'",
t2[0].op.node_def)
self.assertProtoEquals(
"name:'myop3' input:'myop1' input:'myop2:1' input:'myop2' op:'foo'",
@@ -451,7 +451,7 @@ class DeviceTest(test_util.TensorFlowTestCase):
def testNoDevice(self):
g = ops.Graph()
op = g.create_op("an_op", [], [dtypes.float32])
- self.assertEqual(None, op.device)
+ self.assertDeviceEqual(None, op.device)
gd = g.as_graph_def()
self.assertProtoEqualsVersion("""
node { name: "an_op" op: "an_op" }
@@ -557,6 +557,59 @@ class DeviceTest(test_util.TensorFlowTestCase):
device: "/job:ps/device:CPU:0" }
""", gd)
+ def testNestingWithDeviceStrings(self):
+ g = ops.Graph()
+
+ with g.device("/device:GPU:0"):
+ g.create_op("an_op", [], [dtypes.float32])
+ with g.device("/job:worker"):
+ g.create_op("an_op", [], [dtypes.float32])
+ with g.device("/device:CPU:0"):
+ g.create_op("an_op", [], [dtypes.float32])
+ with g.device("/job:ps"):
+ g.create_op("an_op", [], [dtypes.float32])
+ with g.device(""):
+ g.create_op("an_op", [], [dtypes.float32])
+
+ gd = g.as_graph_def()
+ self.assertProtoEqualsVersion("""
+ node { name: "an_op" op: "an_op"
+ device: "/device:GPU:0" }
+ node { name: "an_op_1" op: "an_op"
+ device: "/job:worker/device:GPU:0" }
+ node { name: "an_op_2" op: "an_op"
+ device: "/job:worker/device:CPU:0" }
+ node { name: "an_op_3" op: "an_op"
+ device: "/job:ps/device:CPU:0" }
+ node { name: "an_op_4" op: "an_op"
+ device: "/job:ps/device:CPU:0" }
+ """, gd)
+
+ def testNestingWithDeviceStringWildcard(self):
+ g = ops.Graph()
+
+ with g.device("/device:GPU:7"):
+ g.create_op("an_op", [], [dtypes.float32])
+ with g.device("/device:GPU:*"):
+ g.create_op("an_op", [], [dtypes.float32])
+
+ with g.device("/device:CPU:*"):
+ g.create_op("an_op", [], [dtypes.float32])
+ with g.device("/device:CPU:5"):
+ g.create_op("an_op", [], [dtypes.float32])
+
+ gd = g.as_graph_def()
+ self.assertProtoEqualsVersion("""
+ node { name: "an_op" op: "an_op"
+ device: "/device:GPU:7" }
+ node { name: "an_op_1" op: "an_op"
+ device: "/device:GPU:7" }
+ node { name: "an_op_2" op: "an_op"
+ device: "/device:CPU:*" }
+ node { name: "an_op_3" op: "an_op"
+ device: "/device:CPU:5" }
+ """, gd)
+
def testNoneClearsDefault(self):
g = ops.Graph()
with g.device("/job:worker/replica:2/device:CPU:1"):
@@ -573,6 +626,59 @@ class DeviceTest(test_util.TensorFlowTestCase):
device: "/job:worker/replica:2/device:CPU:1" }
""", gd)
+ def testNoneIgnoresOuterDeviceFunction(self):
+ g = ops.Graph()
+ with g.device(lambda op: "/job:worker/replica:2/device:CPU:1"):
+ g.create_op("an_op", [], [dtypes.float32])
+ with g.device(None):
+ g.create_op("an_op", [], [dtypes.float32])
+ g.create_op("an_op", [], [dtypes.float32])
+ gd = g.as_graph_def()
+ self.assertProtoEqualsVersion("""
+ node { name: "an_op" op: "an_op"
+ device: "/job:worker/replica:2/device:CPU:1" }
+ node { name: "an_op_1" op: "an_op" }
+ node { name: "an_op_2" op: "an_op"
+ device: "/job:worker/replica:2/device:CPU:1" }
+ """, gd)
+
+ def _overwritingDeviceFunction(self, unused_op):
+ # This device function unconditionally overwrites the device of ops.
+ #
+ # NOTE(mrry): Writing device functions like this is not
+ # recommended. Instead, in most cases you should use
+ # `pydev.merge_device("/job:ps")` or simply `"/job:ps"` as the
+ # argument to `tf.device()` and the device component will be merged in.
+ return "/job:overwrite"
+
+ def testOverwritingBehavior(self):
+ g = ops.Graph()
+ with g.device(self._overwritingDeviceFunction):
+ g.create_op("an_op", [], [dtypes.float32])
+ with g.device("/job:ps"): # Will be overwritten.
+ g.create_op("an_op", [], [dtypes.float32])
+ with g.device(pydev.merge_device("/job:ps")): # Will be overwritten.
+ g.create_op("an_op", [], [dtypes.float32])
+ with g.device(None): # Disables overwriting device function
+ with g.device("/job:ps"):
+ g.create_op("an_op", [], [dtypes.float32])
+ with g.device(None): # Disables overwriting device function
+ with g.device(pydev.merge_device("/job:ps")):
+ g.create_op("an_op", [], [dtypes.float32])
+ gd = g.as_graph_def()
+ self.assertProtoEqualsVersion("""
+ node { name: "an_op" op: "an_op"
+ device: "/job:overwrite" }
+ node { name: "an_op_1" op: "an_op"
+ device: "/job:overwrite" }
+ node { name: "an_op_2" op: "an_op"
+ device: "/job:overwrite" }
+ node { name: "an_op_3" op: "an_op"
+ device: "/job:ps" }
+ node { name: "an_op_4" op: "an_op"
+ device: "/job:ps" }
+ """, gd)
+
class ObjectWithName(object):
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 9eea140d8b..c004e8eb2e 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -37,6 +37,7 @@ from tensorflow.core.framework import graph_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import graph_util
from tensorflow.python.client import session
+from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import versions
@@ -524,6 +525,18 @@ class TensorFlowTestCase(googletest.TestCase):
raise TypeError("tf_tensor must be a Tensor")
self.assertAllEqual(np_array.shape, tf_tensor.get_shape().as_list())
+ def assertDeviceEqual(self, device1, device2):
+ """Asserts that the two given devices are the same.
+
+ Args:
+ device1: A string device name or TensorFlow `Device` object.
+ device2: A string device name or TensorFlow `Device` object.
+ """
+ device1 = pydev.canonical_name(device1)
+ device2 = pydev.canonical_name(device2)
+ self.assertEqual(device1, device2,
+ "Devices %s and %s are not equal" % (device1, device2))
+
# Fix Python 3 compatibility issues
if six.PY3:
# Silence a deprecation warning
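A minimal sketch of the new assertion in use (a hypothetical test, assuming
the `tf.test.TestCase` of this era):

```python
import tensorflow as tf

class DeviceAssertionTest(tf.test.TestCase):

  def testAliasAndFieldOrderIgnored(self):
    with tf.Graph().as_default():
      with tf.device("/gpu:0"):  # alias form of the device type
        c = tf.constant(1.0)
      # Both sides are canonicalized before comparison, so the alias
      # matches the full form, and None compares equal to "".
      self.assertDeviceEqual(c.device, "/device:GPU:0")

if __name__ == "__main__":
  tf.test.main()
```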
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index c315959572..da97d8f436 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -416,7 +416,7 @@ class ControlFlowTest(tf.test.TestCase):
for op in x.graph.get_operations():
if op.name == "cond/Add/Switch":
- self.assertEqual(op.device, "/cpu:0")
+ self.assertDeviceEqual(op.device, "/cpu:0")
def _testCond_1(self, use_gpu):
with self.test_session(use_gpu=use_gpu):
@@ -1390,14 +1390,14 @@ class ControlFlowTest(tf.test.TestCase):
vnod = tf.Variable([0.0])
with_vnod_dep = control_flow_ops.with_dependencies([vnod.initializer],
vnod)
- self.assertEquals(None, with_vnod_dep.device)
+ self.assertDeviceEqual(None, with_vnod_dep.device)
# device set on tensor, default device on graph => default device on dep.
vdef = tf.Variable([0.0])
with tf.device("/job:worker/gpu:1"):
with_vdef_dep = control_flow_ops.with_dependencies([vdef.initializer],
vdef)
- self.assertEquals("/job:worker/gpu:1", with_vdef_dep.device)
+ self.assertDeviceEqual("/job:worker/gpu:1", with_vdef_dep.device)
def testGroup(self):
with self.test_session() as sess:
diff --git a/tensorflow/python/kernel_tests/init_ops_test.py b/tensorflow/python/kernel_tests/init_ops_test.py
index 145058579e..28aa3eccc3 100644
--- a/tensorflow/python/kernel_tests/init_ops_test.py
+++ b/tensorflow/python/kernel_tests/init_ops_test.py
@@ -288,15 +288,15 @@ class DeviceTest(tf.test.TestCase):
def testNoDevice(self):
with tf.Graph().as_default():
var = tf.Variable([[1.0, 1.0]])
- self.assertEqual(None, var.device)
- self.assertEqual(None, var.initializer.device)
+ self.assertDeviceEqual(None, var.device)
+ self.assertDeviceEqual(None, var.initializer.device)
def testDevice(self):
with tf.Graph().as_default():
with tf.device("/job:ps"):
var = tf.Variable([[1.0, 1.0]])
- self.assertEqual("/job:ps", var.device)
- self.assertEqual("/job:ps", var.initializer.device)
+ self.assertDeviceEqual("/job:ps", var.device)
+ self.assertDeviceEqual("/job:ps", var.initializer.device)
if __name__ == "__main__":
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 6d3df749dc..e932a6e626 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -1612,8 +1612,7 @@ def with_dependencies(dependencies, output_tensor, name=None):
"""
with ops.op_scope(dependencies + [output_tensor], name,
"control_dependency") as name:
- with ops.device(output_tensor.device
- or ops.get_default_graph().get_default_device()):
+ with ops.device(output_tensor.device):
with ops.control_dependencies(dependencies):
output_tensor = ops.convert_to_tensor_or_indexed_slices(output_tensor)
if isinstance(output_tensor, ops.Tensor):
diff --git a/tensorflow/python/ops/gradients.py b/tensorflow/python/ops/gradients.py
index 78ecc98859..a495baebc7 100644
--- a/tensorflow/python/ops/gradients.py
+++ b/tensorflow/python/ops/gradients.py
@@ -158,7 +158,7 @@ def _GetGradsDevice(op, colocate_gradients_with_ops):
if colocate_gradients_with_ops and op.device:
return op.device
else:
- return op.graph.get_default_device()
+ return ""
def _PendingCount(graph, to_ops, from_ops):
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py
index e5a828e7bb..e194dabee4 100644
--- a/tensorflow/python/ops/gradients_test.py
+++ b/tensorflow/python/ops/gradients_test.py
@@ -167,7 +167,7 @@ class GradientsTest(test_util.TensorFlowTestCase):
with g.device("/gpu:0"):
wx = math_ops.matmul(w, x)
gw = gradients.gradients(wx, [w], colocate_gradients_with_ops=True)[0]
- self.assertEquals("/gpu:0", gw.device)
+ self.assertDeviceEqual("/gpu:0", gw.device)
def testColocateGradientsWithAggregation(self):
with ops.Graph().as_default() as g:
@@ -180,9 +180,9 @@ class GradientsTest(test_util.TensorFlowTestCase):
with g.device("/gpu:0"):
z = wx + wy
gw1 = gradients.gradients(z, [w], colocate_gradients_with_ops=True)[0]
- self.assertEquals("/gpu:1", gw1.device)
+ self.assertDeviceEqual("/gpu:1", gw1.device)
gw2 = gradients.gradients(z, [w], colocate_gradients_with_ops=False)[0]
- self.assertEquals(None, gw2.device)
+ self.assertDeviceEqual(None, gw2.device)
def testBoundaryStop(self):
# Test that we don't differentiate 'x'. The gradient function for 'x' is
diff --git a/tensorflow/python/ops/numerics.py b/tensorflow/python/ops/numerics.py
index 8525da35f5..ebe7af3678 100644
--- a/tensorflow/python/ops/numerics.py
+++ b/tensorflow/python/ops/numerics.py
@@ -38,7 +38,7 @@ def verify_tensor_all_finite(t, msg, name=None):
"""
with ops.op_scope([t], name, "VerifyFinite") as name:
t = ops.convert_to_tensor(t, name="t")
- with ops.device(t.device or t.graph.get_default_device()):
+ with ops.device(t.device):
verify_input = array_ops.check_numerics(t, message=msg)
out = control_flow_ops.with_dependencies([verify_input], t)
return out
diff --git a/tensorflow/python/ops/op_def_library_test.py b/tensorflow/python/ops/op_def_library_test.py
index 0f733b786a..8dcef974e8 100644
--- a/tensorflow/python/ops/op_def_library_test.py
+++ b/tensorflow/python/ops/op_def_library_test.py
@@ -1387,14 +1387,14 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
"Input 'a' of 'RefIn' Op requires l-value input")
def testSpecifyDevice(self):
- with self._g.device("ADevice"):
+ with self._g.device("/job:ADevice"):
self._lib.apply_op("Simple", a=3)
# We look at the whole graph here to make sure the Const op is also given
# the specified device.
graph_def = self._g.as_graph_def()
self.assertEqual(len(graph_def.node), 2)
for node in graph_def.node:
- self.assertEqual(node.device, "ADevice")
+ self.assertDeviceEqual(node.device, "/job:ADevice")
def testStructuredOutputSingleList(self):
self._add_op("name: 'SimpleStruct' "
diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py
index 5d0f89037d..b775adcfe4 100644
--- a/tensorflow/python/ops/state_ops.py
+++ b/tensorflow/python/ops/state_ops.py
@@ -161,7 +161,7 @@ def init_variable(v, init, name="init"):
"""
with ops.op_scope([v, init], None, v.op.name + "/"):
with ops.name_scope(name) as scope:
- with ops.device(v.device or ops.get_default_graph().get_default_device()):
+ with ops.device(v.device):
if callable(init):
assert v.get_shape().is_fully_defined(), "Variable shape unknown."
# TODO(mrry): Convert to v.shape when the property and
diff --git a/tensorflow/python/training/moving_averages_test.py b/tensorflow/python/training/moving_averages_test.py
index b5650135a6..5b739b15d4 100644
--- a/tensorflow/python/training/moving_averages_test.py
+++ b/tensorflow/python/training/moving_averages_test.py
@@ -179,17 +179,17 @@ class ExponentialMovingAverageTest(tf.test.TestCase):
self.assertEqual(ema.average_name(tensor2), ema.average(tensor2).op.name)
def testAverageVariablesDeviceAssignment(self):
- with tf.device("dev_v0"):
+ with tf.device("/job:dev_v0"):
v0 = tf.Variable(10.0, name="v0")
- with tf.device("dev_v1"):
+ with tf.device("/job:dev_v1"):
v1 = state_ops.variable_op(shape=[1], dtype=tf.float32, name="v1")
tensor2 = v0 + v1
ema = tf.train.ExponentialMovingAverage(0.25, name="foo_avg")
- with tf.device("default"):
+ with tf.device("/job:default"):
ema.apply([v0, v1, tensor2])
- self.assertEqual("dev_v0", ema.average(v0).device)
- self.assertEqual("dev_v1", ema.average(v1).device)
- self.assertEqual("default", ema.average(tensor2).device)
+ self.assertDeviceEqual("/job:dev_v0", ema.average(v0).device)
+ self.assertDeviceEqual("/job:dev_v1", ema.average(v1).device)
+ self.assertDeviceEqual("/job:default", ema.average(tensor2).device)
if __name__ == "__main__":
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index 2ce65aefdf..34ad21bbfb 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -30,6 +30,7 @@ from google.protobuf import text_format
from tensorflow.python.client import graph_util
from tensorflow.python.client import session
+from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import constant_op
@@ -252,7 +253,8 @@ class BaseSaverBuilder(object):
"""
per_device = collections.defaultdict(lambda: [])
for var_to_save in vars_to_save:
- per_device[var_to_save.var.device].append(var_to_save)
+ canonical_device = pydev.canonical_name(var_to_save.var.device)
+ per_device[canonical_device].append(var_to_save)
return sorted(per_device.items(), key=lambda t: t[0])
def _VarListToDict(self, var_list):
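The saver change above, restated as a standalone sketch (the function name
and argument shape here are illustrative, not the saver's internals):

```python
import collections

from tensorflow.python.framework import device as pydev

def group_by_canonical_device(variables):
  """Groups variables whose device strings differ only in spelling.

  After this CL, "/replica:0/job:foo" and "/job:foo/replica:0" land in
  the same bucket, because the keys are canonicalized first.
  """
  per_device = collections.defaultdict(list)
  for v in variables:
    per_device[pydev.canonical_name(v.device)].append(v)
  return sorted(per_device.items(), key=lambda t: t[0])
```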