author    | 2016-02-05 18:19:07 -0800
committer | 2016-02-06 08:46:41 -0800
commit    | 3fa9676bbc41826689e9b0e11a45e3fbdceae258 (patch)
tree      | dcb787db2d252ccced947e0d5cf2f7733c931668 /tensorflow
parent    | 241698b6ba6cd9b13d606a9e4603baa4f33891f2 (diff)
Consolidate the device function and device string handling in `tf.device()`.
The effect of this CL is to treat `with tf.device(device_name):` as
supplying a device function that *merges* `device_name` into the
device of ops created in that scope. (Merging is defined by
`tensorflow.python.framework.device.merge_device()`: essentially, for
each field set in `device_name`, the merge function fills in the
corresponding field of an op's device, unless that field has already
been set.) This makes it possible to compose device blocks that each
set different parts of a device, and to use device strings in
composition with device functions.
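For illustration, a minimal sketch of the merging behavior (the device
strings here are hypothetical, and the printed form follows the
canonicalization performed by `merge_device()`):

```python
import tensorflow as tf

with tf.device("/job:worker"):
  with tf.device("/gpu:0"):
    # The inner string only sets the device field, so it is merged with
    # the job field supplied by the enclosing scope.
    a = tf.constant(1.0)

print(a.device)  # "/job:worker/device:GPU:0"
```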
A secondary effect of this CL is that it causes `with
tf.device(None):` to interoperate properly with device functions. As
with other `tf.Graph` contexts, entering a `with tf.device(None):`
block now ignores all device functions currently set in the outer
context.
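A minimal sketch of that interaction (the device function and device
string below are hypothetical):

```python
import tensorflow as tf

with tf.device(lambda op: "/job:worker/device:CPU:1"):
  a = tf.constant(1.0)    # device function applies: "/job:worker/device:CPU:1"
  with tf.device(None):
    b = tf.constant(2.0)  # outer device functions are ignored: b.device == ""
  c = tf.constant(3.0)    # device function applies again
```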
This CL makes some breaking changes to corner cases in the
`tf.device()`, `tf.Graph`, `tf.Operation`, and `tf.Tensor` APIs:
* Within a `with tf.device(device_string):` scope, the given device
string will now be *merged* into the device for ops created in that
scope. See the implementation of
`tensorflow.python.framework.device.merge_device()` for details.
Previously, device strings were maintained in a single "default
device" field, rather than a stack, so device strings from outer
contexts would be completely ignored. To obtain the previous
behavior, use `with tf.device(None), tf.device(device_string):`
instead (see the sketch after this list).
* Within a `with tf.Graph.device(None):` scope, no device functions
from the outer context will be executed.
Previously, the `None` applied only to the device string, and all
device functions would be applied unconditionally.
* The `tf.Graph.get_default_device()` method is removed, because it no
longer has a well-defined meaning. To create a no-op device scope,
you can simply use `with tf.device(""):`.
* The `tf.Operation.device` and `tf.Tensor.device` properties now
return an empty string when no device has been set for an op. This
makes it easier to write code like `with tf.device(op.device):`,
which is robust to whether or not `op` has a device: if no device has
been set, the scope is a no-op.
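As referenced above, a short sketch of the workaround for the previous
behavior and of the new no-op device scopes (the device strings are
placeholders for illustration):

```python
import tensorflow as tf

with tf.device("/job:worker"):
  # Previous (non-merging) behavior: clear the device stack with None,
  # then apply the desired device string.
  with tf.device(None), tf.device("/gpu:0"):
    a = tf.constant(1.0)  # "/device:GPU:0", not "/job:worker/device:GPU:0"

# Ops without a device now report "" rather than None, so their device
# can be re-applied as a no-op scope.
b = tf.constant(2.0)
assert b.device == ""
with tf.device(b.device):  # equivalent to `with tf.device(""):`
  c = tf.constant(3.0)     # still has no device assigned
```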
Change: 114003979
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/python/BUILD | 12
-rw-r--r-- | tensorflow/python/client/graph_util_test.py | 84
-rw-r--r-- | tensorflow/python/framework/device.py | 11
-rw-r--r-- | tensorflow/python/framework/device_test.py | 37
-rw-r--r-- | tensorflow/python/framework/ops.py | 90
-rw-r--r-- | tensorflow/python/framework/ops_test.py | 122
-rw-r--r-- | tensorflow/python/framework/test_util.py | 13
-rw-r--r-- | tensorflow/python/kernel_tests/control_flow_ops_py_test.py | 6
-rw-r--r-- | tensorflow/python/kernel_tests/init_ops_test.py | 8
-rw-r--r-- | tensorflow/python/ops/control_flow_ops.py | 3
-rw-r--r-- | tensorflow/python/ops/gradients.py | 2
-rw-r--r-- | tensorflow/python/ops/gradients_test.py | 6
-rw-r--r-- | tensorflow/python/ops/numerics.py | 2
-rw-r--r-- | tensorflow/python/ops/op_def_library_test.py | 4
-rw-r--r-- | tensorflow/python/ops/state_ops.py | 2
-rw-r--r-- | tensorflow/python/training/moving_averages_test.py | 12
-rw-r--r-- | tensorflow/python/training/saver.py | 4
17 files changed, 274 insertions, 144 deletions
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 114aba2434..8bb9f892ef 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -343,6 +343,18 @@ py_test( ) py_test( + name = "framework_device_test", + srcs = ["framework/device_test.py"], + main = "framework/device_test.py", + srcs_version = "PY2AND3", + deps = [ + ":framework_test_lib", + ":platform_test", + "//tensorflow/core:protos_all_py", + ], +) + +py_test( name = "framework_tensor_shape_div_test", srcs = ["framework/tensor_shape_div_test.py"], main = "framework/tensor_shape_div_test.py", diff --git a/tensorflow/python/client/graph_util_test.py b/tensorflow/python/client/graph_util_test.py index 73265361cd..49379ffab1 100644 --- a/tensorflow/python/client/graph_util_test.py +++ b/tensorflow/python/client/graph_util_test.py @@ -31,10 +31,9 @@ from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import math_ops # pylint: enable=unused-import from tensorflow.python.ops import state_ops -from tensorflow.python.platform import googletest -class DeviceFunctionsTest(googletest.TestCase): +class DeviceFunctionsTest(tf.test.TestCase): def testPinToCpu(self): with ops.Graph().as_default() as g, g.device(graph_util.pin_to_cpu): @@ -48,14 +47,14 @@ class DeviceFunctionsTest(googletest.TestCase): [[0, 1, 2], [2, 3]], [[12, 23, 34], [1, 2]]) dynamic_stitch_float_result = data_flow_ops.dynamic_stitch( [[0, 1, 2], [2, 3]], [[12.0, 23.0, 34.0], [1.0, 2.0]]) - self.assertEqual(const_a.device, "/device:CPU:0") - self.assertEqual(const_b.device, "/device:CPU:0") - self.assertEqual(add_c.device, "/device:CPU:0") - self.assertEqual(var_v.device, "/device:CPU:0") - self.assertEqual(assign_c_to_v.device, "/device:CPU:0") - self.assertEqual(const_string.device, "/device:CPU:0") - self.assertEqual(dynamic_stitch_int_result.device, "/device:CPU:0") - self.assertEqual(dynamic_stitch_float_result.device, "/device:CPU:0") + self.assertDeviceEqual(const_a.device, "/device:CPU:0") + self.assertDeviceEqual(const_b.device, "/device:CPU:0") + self.assertDeviceEqual(add_c.device, "/device:CPU:0") + self.assertDeviceEqual(var_v.device, "/device:CPU:0") + self.assertDeviceEqual(assign_c_to_v.device, "/device:CPU:0") + self.assertDeviceEqual(const_string.device, "/device:CPU:0") + self.assertDeviceEqual(dynamic_stitch_int_result.device, "/device:CPU:0") + self.assertDeviceEqual(dynamic_stitch_float_result.device, "/device:CPU:0") def testPinRequiredOpsOnCPU(self): with ops.Graph().as_default() as g, g.device( @@ -70,12 +69,12 @@ class DeviceFunctionsTest(googletest.TestCase): dynamic_stitch_float_result = data_flow_ops.dynamic_stitch( [[0, 1, 2], [2, 3]], [[12.0, 23.0, 34.0], [1.0, 2.0]]) # Non-variable ops shuld not specify a device - self.assertEqual(const_a.device, None) - self.assertEqual(const_b.device, None) - self.assertEqual(add_c.device, None) + self.assertDeviceEqual(const_a.device, None) + self.assertDeviceEqual(const_b.device, None) + self.assertDeviceEqual(add_c.device, None) # Variable ops specify a device - self.assertEqual(var_v.device, "/device:CPU:0") - self.assertEqual(assign_c_to_v.device, "/device:CPU:0") + self.assertDeviceEqual(var_v.device, "/device:CPU:0") + self.assertDeviceEqual(assign_c_to_v.device, "/device:CPU:0") def testTwoDeviceFunctions(self): with ops.Graph().as_default() as g: @@ -90,13 +89,28 @@ class DeviceFunctionsTest(googletest.TestCase): var_5 = state_ops.variable_op([1], dtype=dtypes.float32) var_6 = state_ops.variable_op([1], dtype=dtypes.float32) - 
self.assertEqual(var_0.device, None) - self.assertEqual(var_1.device, "/device:CPU:0") - self.assertEqual(var_2.device, None) - self.assertEqual(var_3.device, None) - self.assertEqual(var_4.device, "/device:CPU:0") - self.assertEqual(var_5.device, "/device:GPU:0") - self.assertEqual(var_6.device, "/device:CPU:0") + self.assertDeviceEqual(var_0.device, None) + self.assertDeviceEqual(var_1.device, "/device:CPU:0") + self.assertDeviceEqual(var_2.device, None) + self.assertDeviceEqual(var_3.device, None) + self.assertDeviceEqual(var_4.device, "/device:CPU:0") + self.assertDeviceEqual(var_5.device, "/device:GPU:0") + self.assertDeviceEqual(var_6.device, "/device:CPU:0") + + def testNestedDeviceFunctions(self): + with tf.Graph().as_default(): + var_0 = tf.Variable(0) + with tf.device(graph_util.pin_variables_on_cpu): + var_1 = tf.Variable(1) + with tf.device(lambda op: "/gpu:0"): + var_2 = tf.Variable(2) + with tf.device("/gpu:0"): # Implicit merging device function. + var_3 = tf.Variable(3) + + self.assertDeviceEqual(var_0.device, None) + self.assertDeviceEqual(var_1.device, "/device:CPU:0") + self.assertDeviceEqual(var_2.device, "/device:GPU:0") + self.assertDeviceEqual(var_3.device, "/device:GPU:0") def testExplicitDevice(self): with ops.Graph().as_default() as g: @@ -112,12 +126,12 @@ class DeviceFunctionsTest(googletest.TestCase): with g.device("/job:ps"): const_5 = constant_op.constant(5.0) - self.assertEqual(const_0.device, None) - self.assertEqual(const_1.device, "/device:GPU:0") - self.assertEqual(const_2.device, "/device:GPU:1") - self.assertEqual(const_3.device, "/device:CPU:0") - self.assertEqual(const_4.device, "/device:CPU:1") - self.assertEqual(const_5.device, "/job:ps") + self.assertDeviceEqual(const_0.device, None) + self.assertDeviceEqual(const_1.device, "/device:GPU:0") + self.assertDeviceEqual(const_2.device, "/device:GPU:1") + self.assertDeviceEqual(const_3.device, "/device:CPU:0") + self.assertDeviceEqual(const_4.device, "/device:CPU:1") + self.assertDeviceEqual(const_5.device, "/job:ps") def testDefaultDevice(self): with ops.Graph().as_default() as g, g.device( @@ -135,12 +149,12 @@ class DeviceFunctionsTest(googletest.TestCase): with g.device("/replica:0"): const_5 = constant_op.constant(5.0) - self.assertEqual(const_0.device, "/job:ps") - self.assertEqual(const_1.device, "/device:GPU:0") - self.assertEqual(const_2.device, "/device:GPU:1") - self.assertEqual(const_3.device, "/device:CPU:0") - self.assertEqual(const_4.device, "/device:CPU:1") - self.assertEqual(const_5.device, "/replica:0") + self.assertDeviceEqual(const_0.device, "/job:ps") + self.assertDeviceEqual(const_1.device, "/device:GPU:0") + self.assertDeviceEqual(const_2.device, "/device:GPU:1") + self.assertDeviceEqual(const_3.device, "/device:CPU:0") + self.assertDeviceEqual(const_4.device, "/device:CPU:1") + self.assertDeviceEqual(const_5.device, "/replica:0") def testExtractSubGraph(self): graph_def = tf.GraphDef() @@ -172,4 +186,4 @@ class DeviceFunctionsTest(googletest.TestCase): if __name__ == "__main__": - googletest.main() + tf.test.main() diff --git a/tensorflow/python/framework/device.py b/tensorflow/python/framework/device.py index 557b9544fb..37557343aa 100644 --- a/tensorflow/python/framework/device.py +++ b/tensorflow/python/framework/device.py @@ -203,6 +203,17 @@ def check_valid(spec): from_string(spec) +def canonical_name(device): + """Returns a canonical name for the given device or device name.""" + if device is None: + return "" + if isinstance(device, Device): + return device.to_string() + 
else: + device = from_string(device) + return device.to_string() + + def merge_device(spec): """Returns a device function that merges devices specifications. diff --git a/tensorflow/python/framework/device_test.py b/tensorflow/python/framework/device_test.py index ded5bec1d6..f3c22acd7a 100644 --- a/tensorflow/python/framework/device_test.py +++ b/tensorflow/python/framework/device_test.py @@ -29,9 +29,9 @@ class DeviceTest(test_util.TensorFlowTestCase): def testEmpty(self): d = device.Device() - self.assertEquals("", d.ToString()) + self.assertEquals("", d.to_string()) d.parse_from_string("") - self.assertEquals("", d.ToString()) + self.assertEquals("", d.to_string()) def testConstructor(self): d = device.Device(job="j", replica=0, task=1, @@ -117,23 +117,46 @@ class DeviceTest(test_util.TensorFlowTestCase): d.merge_from(device.from_string("/job:muu/device:MyFunnyDevice:2")) self.assertEquals("/job:muu/task:1/device:MyFunnyDevice:2", d.to_string()) + def testCanonicalName(self): + self.assertEqual("/job:foo/replica:0", + device.canonical_name("/job:foo/replica:0")) + self.assertEqual("/job:foo/replica:0", + device.canonical_name("/replica:0/job:foo")) + + self.assertEqual("/job:foo/replica:0/task:0", + device.canonical_name("/job:foo/replica:0/task:0")) + self.assertEqual("/job:foo/replica:0/task:0", + device.canonical_name("/job:foo/task:0/replica:0")) + + self.assertEqual("/device:CPU:0", + device.canonical_name("/device:CPU:0")) + self.assertEqual("/device:GPU:2", + device.canonical_name("/device:GPU:2")) + + self.assertEqual("/job:foo/replica:0/task:0/device:GPU:0", + device.canonical_name( + "/job:foo/replica:0/task:0/gpu:0")) + self.assertEqual("/job:foo/replica:0/task:0/device:GPU:0", + device.canonical_name( + "/gpu:0/task:0/replica:0/job:foo")) + def testCheckValid(self): - device.CheckValid("/job:foo/replica:0") + device.check_valid("/job:foo/replica:0") with self.assertRaises(Exception) as e: - device.CheckValid("/job:j/replica:foo") + device.check_valid("/job:j/replica:foo") self.assertTrue("invalid literal for int" in e.exception.message) with self.assertRaises(Exception) as e: - device.CheckValid("/job:j/task:bar") + device.check_valid("/job:j/task:bar") self.assertTrue("invalid literal for int" in e.exception.message) with self.assertRaises(Exception) as e: - device.CheckValid("/bar:muu/baz:2") + device.check_valid("/bar:muu/baz:2") self.assertTrue("Unknown attribute: 'bar'" in e.exception.message) with self.assertRaises(Exception) as e: - device.CheckValid("/cpu:0/gpu:2") + device.check_valid("/cpu:0/gpu:2") self.assertTrue("Cannot specify multiple device" in e.exception.message) diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 981a70f40d..17a6d19a0f 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -1135,10 +1135,10 @@ class Operation(object): Returns: The string name of the device to which this op has been - assigned, or None if it has not been assigned to a device. + assigned, or an empty string if it has not been assigned to a + device. """ - dev = self._node_def.device - return None if not dev else dev + return self._node_def.device def _set_device(self, device): """Set the device of this operation. @@ -1737,7 +1737,6 @@ class Graph(object): @@get_tensor_by_name @@get_operations - @@get_default_device @@seed @@unique_name @@version @@ -1757,8 +1756,6 @@ class Graph(object): self._name_stack = ("", "") # Maps a name used in the graph to the next id to use for that name. 
self._names_in_use = {} - # Default device applied to new ops. - self._default_device = None # Functions that will be applied to choose a device if none is specified. self._device_function_stack = [] # Default original_op applied to new ops. @@ -2013,8 +2010,7 @@ class Graph(object): else: name = self.unique_name(name) - node_def = _NodeDef( - op_type, name, device=self._default_device or None, attrs=attrs) + node_def = _NodeDef(op_type, name, device=None, attrs=attrs) # Apply a kernel label if one has been specified for this op_type. try: @@ -2046,6 +2042,8 @@ class Graph(object): # We apply here because the result can depend on the Operation's # signature, which is computed in the Operation constructor. for device_function in reversed(self._device_function_stack): + if device_function is None: + break ret._set_device(device_function(ret)) return ret @@ -2489,54 +2487,6 @@ class Graph(object): else: return name - def _set_default_device(self, dev): - """Set the default device properties. - - Args: - dev: string or Device. - """ - self._default_device = _device_string(dev) - - def get_default_device(self): - """Returns the default device. - - Returns: - A string. - """ - return self._default_device - - def _push_default_device_function(self, device_function): - """Pushes the given function onto the stack of device functions. - - See `Graph.device` for more details. - - Args: - device_function: The function to be pushed onto the stack of device - functions. - """ - self._device_function_stack.append(device_function) - - def _pop_default_device_function(self, device_function): - """Pops the given function from the stack of device functions. - - See `Graph.device` for more details. - - Args: - device_function: The function to be popped from the stack of device - functions. - - Raises: - ValueError: if the device_function to be popped is not top of the stack, - or if the stack is empty. - """ - if not self._device_function_stack: - raise ValueError("Tried to pop, but the device function stack is empty") - if self._device_function_stack[-1] is not device_function: - raise ValueError("Tried to pop device function, but it was not on top " - "of the stack") - - self._device_function_stack.pop() - @contextlib.contextmanager def device(self, device_name_or_function): """Returns a context manager that specifies the default device to use. @@ -2545,12 +2495,14 @@ class Graph(object): string, a device function, or None: * If it is a device name string, all operations constructed in - this context will be assigned to the device with that name. + this context will be assigned to the device with that name, unless + overridden by a nested `device()` context. * If it is a function, it will be treated as function from Operation objects to device name strings, and invoked each time a new Operation is created. The Operation will be assigned to the device with the returned name. - * If it is None, the default device will be cleared. + * If it is None, all `device()` invocations from the enclosing context + will be ignored. For example: @@ -2583,19 +2535,17 @@ class Graph(object): A context manager that specifies the default device to use for newly created ops. 
""" - if callable(device_name_or_function): - try: - self._push_default_device_function(device_name_or_function) - yield - finally: - self._pop_default_device_function(device_name_or_function) + if (device_name_or_function is not None + and not callable(device_name_or_function)): + device_function = pydev.merge_device(device_name_or_function) else: - try: - old_dev = self.get_default_device() - self._set_default_device(_device_string(device_name_or_function)) - yield - finally: - self._set_default_device(old_dev) + device_function = device_name_or_function + + try: + self._device_function_stack.append(device_function) + yield + finally: + self._device_function_stack.pop() class _ControlDependenciesController(object): """Context manager for `control_dependencies()`.""" diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index 7676b0c239..e5fa618675 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -256,7 +256,7 @@ class CreateOpTest(test_util.TensorFlowTestCase): def testNodeDefArgs(self): g = ops.Graph() op1 = g.create_op("const", [], [dtypes.float32], None, name="myop1") - with g.device("/device:GPU"): + with g.device("/device:GPU:0"): op2 = g.create_op("add", [], [dtypes.float32, dtypes.string], None, @@ -267,11 +267,11 @@ class CreateOpTest(test_util.TensorFlowTestCase): [dtypes.float32, dtypes.int32], None, name="myop3") - self.assertEqual(None, op1.device) - self.assertEqual("/device:GPU", op2.device) - self.assertEqual(None, op3.device) + self.assertDeviceEqual(None, op1.device) + self.assertDeviceEqual("/device:GPU:0", op2.device) + self.assertDeviceEqual(None, op3.device) self.assertProtoEquals("name:'myop1' op:'const'", op1.node_def) - self.assertProtoEquals("name:'myop2' op:'add' device:'/device:GPU'", + self.assertProtoEquals("name:'myop2' op:'add' device:'/device:GPU:0'", op2.node_def) self.assertProtoEquals( "name:'myop3' input:'myop1' input:'myop2:1' input:'myop2' op:'foo'", @@ -305,7 +305,7 @@ class ApplyOpTest(test_util.TensorFlowTestCase): def testNodeDefArgs(self): g = ops.Graph() t1 = _apply_op(g, "const", [], [dtypes.float32], name="myop1") - with g.device("/device:GPU"): + with g.device("/device:GPU:0"): t2 = _apply_op(g, "add", [], [dtypes.float32, dtypes.string], @@ -322,7 +322,7 @@ class ApplyOpTest(test_util.TensorFlowTestCase): self.assertEqual("myop3", t3[0]._as_node_def_input()) # Validate that we got the right ops as well self.assertProtoEquals("name:'myop1' op:'const'", t1.op.node_def) - self.assertProtoEquals("name:'myop2' op:'add' device:'/device:GPU'", + self.assertProtoEquals("name:'myop2' op:'add' device:'/device:GPU:0'", t2[0].op.node_def) self.assertProtoEquals( "name:'myop3' input:'myop1' input:'myop2:1' input:'myop2' op:'foo'", @@ -451,7 +451,7 @@ class DeviceTest(test_util.TensorFlowTestCase): def testNoDevice(self): g = ops.Graph() op = g.create_op("an_op", [], [dtypes.float32]) - self.assertEqual(None, op.device) + self.assertDeviceEqual(None, op.device) gd = g.as_graph_def() self.assertProtoEqualsVersion(""" node { name: "an_op" op: "an_op" } @@ -557,6 +557,59 @@ class DeviceTest(test_util.TensorFlowTestCase): device: "/job:ps/device:CPU:0" } """, gd) + def testNestingWithDeviceStrings(self): + g = ops.Graph() + + with g.device("/device:GPU:0"): + g.create_op("an_op", [], [dtypes.float32]) + with g.device("/job:worker"): + g.create_op("an_op", [], [dtypes.float32]) + with g.device("/device:CPU:0"): + g.create_op("an_op", [], [dtypes.float32]) + with 
g.device("/job:ps"): + g.create_op("an_op", [], [dtypes.float32]) + with g.device(""): + g.create_op("an_op", [], [dtypes.float32]) + + gd = g.as_graph_def() + self.assertProtoEqualsVersion(""" + node { name: "an_op" op: "an_op" + device: "/device:GPU:0" } + node { name: "an_op_1" op: "an_op" + device: "/job:worker/device:GPU:0" } + node { name: "an_op_2" op: "an_op" + device: "/job:worker/device:CPU:0" } + node { name: "an_op_3" op: "an_op" + device: "/job:ps/device:CPU:0" } + node { name: "an_op_4" op: "an_op" + device: "/job:ps/device:CPU:0" } + """, gd) + + def testNestingWithDeviceStringWildcard(self): + g = ops.Graph() + + with g.device("/device:GPU:7"): + g.create_op("an_op", [], [dtypes.float32]) + with g.device("/device:GPU:*"): + g.create_op("an_op", [], [dtypes.float32]) + + with g.device("/device:CPU:*"): + g.create_op("an_op", [], [dtypes.float32]) + with g.device("/device:CPU:5"): + g.create_op("an_op", [], [dtypes.float32]) + + gd = g.as_graph_def() + self.assertProtoEqualsVersion(""" + node { name: "an_op" op: "an_op" + device: "/device:GPU:7" } + node { name: "an_op_1" op: "an_op" + device: "/device:GPU:7" } + node { name: "an_op_2" op: "an_op" + device: "/device:CPU:*" } + node { name: "an_op_3" op: "an_op" + device: "/device:CPU:5" } + """, gd) + def testNoneClearsDefault(self): g = ops.Graph() with g.device("/job:worker/replica:2/device:CPU:1"): @@ -573,6 +626,59 @@ class DeviceTest(test_util.TensorFlowTestCase): device: "/job:worker/replica:2/device:CPU:1" } """, gd) + def testNoneIgnoresOuterDeviceFunction(self): + g = ops.Graph() + with g.device(lambda op: "/job:worker/replica:2/device:CPU:1"): + g.create_op("an_op", [], [dtypes.float32]) + with g.device(None): + g.create_op("an_op", [], [dtypes.float32]) + g.create_op("an_op", [], [dtypes.float32]) + gd = g.as_graph_def() + self.assertProtoEqualsVersion(""" + node { name: "an_op" op: "an_op" + device: "/job:worker/replica:2/device:CPU:1" } + node { name: "an_op_1" op: "an_op" } + node { name: "an_op_2" op: "an_op" + device: "/job:worker/replica:2/device:CPU:1" } + """, gd) + + def _overwritingDeviceFunction(self, unused_op): + # This device function unconditionally overwrites the device of ops. + # + # NOTE(mrry): Writing device functions like this is not + # recommended. Instead, in most cases you should use + # `pydev.merge_device("/job:ps")` or simply `"/job:ps"` as the + # argument to `tf.device()` and the device component will be merged in. + return "/job:overwrite" + + def testOverwritingBehavior(self): + g = ops.Graph() + with g.device(self._overwritingDeviceFunction): + g.create_op("an_op", [], [dtypes.float32]) + with g.device("/job:ps"): # Will be overwritten. + g.create_op("an_op", [], [dtypes.float32]) + with g.device(pydev.merge_device("/job:ps")): # Will be overwritten. 
+ g.create_op("an_op", [], [dtypes.float32]) + with g.device(None): # Disables overwriting device function + with g.device("/job:ps"): + g.create_op("an_op", [], [dtypes.float32]) + with g.device(None): # Disables overwriting device function + with g.device(pydev.merge_device("/job:ps")): + g.create_op("an_op", [], [dtypes.float32]) + gd = g.as_graph_def() + self.assertProtoEqualsVersion(""" + node { name: "an_op" op: "an_op" + device: "/job:overwrite" } + node { name: "an_op_1" op: "an_op" + device: "/job:overwrite" } + node { name: "an_op_2" op: "an_op" + device: "/job:overwrite" } + node { name: "an_op_3" op: "an_op" + device: "/job:ps" } + node { name: "an_op_4" op: "an_op" + device: "/job:ps" } + """, gd) + class ObjectWithName(object): diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 9eea140d8b..c004e8eb2e 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -37,6 +37,7 @@ from tensorflow.core.framework import graph_pb2 from tensorflow.python import pywrap_tensorflow from tensorflow.python.client import graph_util from tensorflow.python.client import session +from tensorflow.python.framework import device as pydev from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import versions @@ -524,6 +525,18 @@ class TensorFlowTestCase(googletest.TestCase): raise TypeError("tf_tensor must be a Tensor") self.assertAllEqual(np_array.shape, tf_tensor.get_shape().as_list()) + def assertDeviceEqual(self, device1, device2): + """Asserts that the two given devices are the same. + + Args: + device1: A string device name or TensorFlow `Device` object. + device2: A string device name or TensorFlow `Device` object. + """ + device1 = pydev.canonical_name(device1) + device2 = pydev.canonical_name(device2) + self.assertEqual(device1, device2, + "Devices %s and %s are not equal" % (device1, device2)) + # Fix Python 3 compatibility issues if six.PY3: # Silence a deprecation warning diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index c315959572..da97d8f436 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -416,7 +416,7 @@ class ControlFlowTest(tf.test.TestCase): for op in x.graph.get_operations(): if op.name == "cond/Add/Switch": - self.assertEqual(op.device, "/cpu:0") + self.assertDeviceEqual(op.device, "/cpu:0") def _testCond_1(self, use_gpu): with self.test_session(use_gpu=use_gpu): @@ -1390,14 +1390,14 @@ class ControlFlowTest(tf.test.TestCase): vnod = tf.Variable([0.0]) with_vnod_dep = control_flow_ops.with_dependencies([vnod.initializer], vnod) - self.assertEquals(None, with_vnod_dep.device) + self.assertDeviceEqual(None, with_vnod_dep.device) # device set on tensor, default device on graph => default device on dep. 
vdef = tf.Variable([0.0]) with tf.device("/job:worker/gpu:1"): with_vdef_dep = control_flow_ops.with_dependencies([vdef.initializer], vdef) - self.assertEquals("/job:worker/gpu:1", with_vdef_dep.device) + self.assertDeviceEqual("/job:worker/gpu:1", with_vdef_dep.device) def testGroup(self): with self.test_session() as sess: diff --git a/tensorflow/python/kernel_tests/init_ops_test.py b/tensorflow/python/kernel_tests/init_ops_test.py index 145058579e..28aa3eccc3 100644 --- a/tensorflow/python/kernel_tests/init_ops_test.py +++ b/tensorflow/python/kernel_tests/init_ops_test.py @@ -288,15 +288,15 @@ class DeviceTest(tf.test.TestCase): def testNoDevice(self): with tf.Graph().as_default(): var = tf.Variable([[1.0, 1.0]]) - self.assertEqual(None, var.device) - self.assertEqual(None, var.initializer.device) + self.assertDeviceEqual(None, var.device) + self.assertDeviceEqual(None, var.initializer.device) def testDevice(self): with tf.Graph().as_default(): with tf.device("/job:ps"): var = tf.Variable([[1.0, 1.0]]) - self.assertEqual("/job:ps", var.device) - self.assertEqual("/job:ps", var.initializer.device) + self.assertDeviceEqual("/job:ps", var.device) + self.assertDeviceEqual("/job:ps", var.initializer.device) if __name__ == "__main__": diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index 6d3df749dc..e932a6e626 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -1612,8 +1612,7 @@ def with_dependencies(dependencies, output_tensor, name=None): """ with ops.op_scope(dependencies + [output_tensor], name, "control_dependency") as name: - with ops.device(output_tensor.device - or ops.get_default_graph().get_default_device()): + with ops.device(output_tensor.device): with ops.control_dependencies(dependencies): output_tensor = ops.convert_to_tensor_or_indexed_slices(output_tensor) if isinstance(output_tensor, ops.Tensor): diff --git a/tensorflow/python/ops/gradients.py b/tensorflow/python/ops/gradients.py index 78ecc98859..a495baebc7 100644 --- a/tensorflow/python/ops/gradients.py +++ b/tensorflow/python/ops/gradients.py @@ -158,7 +158,7 @@ def _GetGradsDevice(op, colocate_gradients_with_ops): if colocate_gradients_with_ops and op.device: return op.device else: - return op.graph.get_default_device() + return "" def _PendingCount(graph, to_ops, from_ops): diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py index e5a828e7bb..e194dabee4 100644 --- a/tensorflow/python/ops/gradients_test.py +++ b/tensorflow/python/ops/gradients_test.py @@ -167,7 +167,7 @@ class GradientsTest(test_util.TensorFlowTestCase): with g.device("/gpu:0"): wx = math_ops.matmul(w, x) gw = gradients.gradients(wx, [w], colocate_gradients_with_ops=True)[0] - self.assertEquals("/gpu:0", gw.device) + self.assertDeviceEqual("/gpu:0", gw.device) def testColocateGradientsWithAggregation(self): with ops.Graph().as_default() as g: @@ -180,9 +180,9 @@ class GradientsTest(test_util.TensorFlowTestCase): with g.device("/gpu:0"): z = wx + wy gw1 = gradients.gradients(z, [w], colocate_gradients_with_ops=True)[0] - self.assertEquals("/gpu:1", gw1.device) + self.assertDeviceEqual("/gpu:1", gw1.device) gw2 = gradients.gradients(z, [w], colocate_gradients_with_ops=False)[0] - self.assertEquals(None, gw2.device) + self.assertDeviceEqual(None, gw2.device) def testBoundaryStop(self): # Test that we don't differentiate 'x'. 
The gradient function for 'x' is diff --git a/tensorflow/python/ops/numerics.py b/tensorflow/python/ops/numerics.py index 8525da35f5..ebe7af3678 100644 --- a/tensorflow/python/ops/numerics.py +++ b/tensorflow/python/ops/numerics.py @@ -38,7 +38,7 @@ def verify_tensor_all_finite(t, msg, name=None): """ with ops.op_scope([t], name, "VerifyFinite") as name: t = ops.convert_to_tensor(t, name="t") - with ops.device(t.device or t.graph.get_default_device()): + with ops.device(t.device): verify_input = array_ops.check_numerics(t, message=msg) out = control_flow_ops.with_dependencies([verify_input], t) return out diff --git a/tensorflow/python/ops/op_def_library_test.py b/tensorflow/python/ops/op_def_library_test.py index 0f733b786a..8dcef974e8 100644 --- a/tensorflow/python/ops/op_def_library_test.py +++ b/tensorflow/python/ops/op_def_library_test.py @@ -1387,14 +1387,14 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): "Input 'a' of 'RefIn' Op requires l-value input") def testSpecifyDevice(self): - with self._g.device("ADevice"): + with self._g.device("/job:ADevice"): self._lib.apply_op("Simple", a=3) # We look at the whole graph here to make sure the Const op is also given # the specified device. graph_def = self._g.as_graph_def() self.assertEqual(len(graph_def.node), 2) for node in graph_def.node: - self.assertEqual(node.device, "ADevice") + self.assertDeviceEqual(node.device, "/job:ADevice") def testStructuredOutputSingleList(self): self._add_op("name: 'SimpleStruct' " diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py index 5d0f89037d..b775adcfe4 100644 --- a/tensorflow/python/ops/state_ops.py +++ b/tensorflow/python/ops/state_ops.py @@ -161,7 +161,7 @@ def init_variable(v, init, name="init"): """ with ops.op_scope([v, init], None, v.op.name + "/"): with ops.name_scope(name) as scope: - with ops.device(v.device or ops.get_default_graph().get_default_device()): + with ops.device(v.device): if callable(init): assert v.get_shape().is_fully_defined(), "Variable shape unknown." 
# TODO(mrry): Convert to v.shape when the property and diff --git a/tensorflow/python/training/moving_averages_test.py b/tensorflow/python/training/moving_averages_test.py index b5650135a6..5b739b15d4 100644 --- a/tensorflow/python/training/moving_averages_test.py +++ b/tensorflow/python/training/moving_averages_test.py @@ -179,17 +179,17 @@ class ExponentialMovingAverageTest(tf.test.TestCase): self.assertEqual(ema.average_name(tensor2), ema.average(tensor2).op.name) def testAverageVariablesDeviceAssignment(self): - with tf.device("dev_v0"): + with tf.device("/job:dev_v0"): v0 = tf.Variable(10.0, name="v0") - with tf.device("dev_v1"): + with tf.device("/job:dev_v1"): v1 = state_ops.variable_op(shape=[1], dtype=tf.float32, name="v1") tensor2 = v0 + v1 ema = tf.train.ExponentialMovingAverage(0.25, name="foo_avg") - with tf.device("default"): + with tf.device("/job:default"): ema.apply([v0, v1, tensor2]) - self.assertEqual("dev_v0", ema.average(v0).device) - self.assertEqual("dev_v1", ema.average(v1).device) - self.assertEqual("default", ema.average(tensor2).device) + self.assertDeviceEqual("/job:dev_v0", ema.average(v0).device) + self.assertDeviceEqual("/job:dev_v1", ema.average(v1).device) + self.assertDeviceEqual("/job:default", ema.average(tensor2).device) if __name__ == "__main__": diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py index 2ce65aefdf..34ad21bbfb 100644 --- a/tensorflow/python/training/saver.py +++ b/tensorflow/python/training/saver.py @@ -30,6 +30,7 @@ from google.protobuf import text_format from tensorflow.python.client import graph_util from tensorflow.python.client import session +from tensorflow.python.framework import device as pydev from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import constant_op @@ -252,7 +253,8 @@ class BaseSaverBuilder(object): """ per_device = collections.defaultdict(lambda: []) for var_to_save in vars_to_save: - per_device[var_to_save.var.device].append(var_to_save) + canonical_device = pydev.canonical_name(var_to_save.var.device) + per_device[canonical_device].append(var_to_save) return sorted(per_device.items(), key=lambda t: t[0]) def _VarListToDict(self, var_list): |