diff options
author | Olivia Nordquist <nolivia@google.com> | 2017-07-18 11:34:24 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-07-18 11:40:59 -0700 |
commit | 9867119831afa821aa53206cc5b7ae773db4284d (patch) | |
tree | a50ffe4e4e06d8b932721b7d501c39e7b3c6b559 | |
parent | 1558c10c8d07c78a58d74d0fcf53c1d2e2505d5e (diff) |
Implementing set_device for the C API
PiperOrigin-RevId: 162379684
-rw-r--r-- | tensorflow/c/python_api.cc | 5 | ||||
-rw-r--r-- | tensorflow/c/python_api.h | 2 | ||||
-rw-r--r-- | tensorflow/core/graph/graph.cc | 5 | ||||
-rw-r--r-- | tensorflow/core/graph/graph.h | 4 | ||||
-rw-r--r-- | tensorflow/python/client/session_clusterspec_prop_test.py | 8 | ||||
-rw-r--r-- | tensorflow/python/client/session_test.py | 13 | ||||
-rw-r--r-- | tensorflow/python/framework/ops.py | 8 | ||||
-rw-r--r-- | tensorflow/python/ops/math_ops_test.py | 11 |
8 files changed, 21 insertions, 35 deletions
diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc index 91e4e3cbb9..adca6c7625 100644 --- a/tensorflow/c/python_api.cc +++ b/tensorflow/c/python_api.cc @@ -25,4 +25,9 @@ void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input) { graph->graph.AddControlEdge(&input->node, &op->node); } +void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) { + mutex_lock l(graph->mu); + op->node.set_requested_device(device); +} + } // namespace tensorflow diff --git a/tensorflow/c/python_api.h b/tensorflow/c/python_api.h index bb2b0fd8fd..e1a55d7755 100644 --- a/tensorflow/c/python_api.h +++ b/tensorflow/c/python_api.h @@ -25,6 +25,8 @@ namespace tensorflow { void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input); +void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device); + } // namespace tensorflow #endif // THIRD_PARTY_TENSORFLOW_C_PYTHON_API_H_ diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index a840ef39d2..f6586f0519 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -184,6 +184,11 @@ void Node::ClearAttr(const string& name) { (*props_->node_def.mutable_attr()).erase(name); } +void Node::set_requested_device(const string& device) { + MaybeCopyOnWrite(); + props_->node_def.set_device(device); +} + Status Node::input_edge(int idx, const Edge** e) const { if (idx < 0 || idx >= num_inputs()) { return errors::InvalidArgument("Invalid input_edge index: ", idx, ", Node ", diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h index 565f430455..78a0e8fd79 100644 --- a/tensorflow/core/graph/graph.h +++ b/tensorflow/core/graph/graph.h @@ -101,6 +101,10 @@ class Node { // use assigned_device_name() below. const string& requested_device() const; + // This changes the user requested device but not necessarily the device that + // on which the operation will run. + void set_requested_device(const string& device); + // This gives the device the runtime has assigned this node to. If // you want the device the user requested, use def().device() instead. // TODO(josh11b): Validate that the assigned_device, if not empty: diff --git a/tensorflow/python/client/session_clusterspec_prop_test.py b/tensorflow/python/client/session_clusterspec_prop_test.py index 37cbf8062c..f40d3f1872 100644 --- a/tensorflow/python/client/session_clusterspec_prop_test.py +++ b/tensorflow/python/client/session_clusterspec_prop_test.py @@ -49,7 +49,6 @@ ops.RegisterShape('ConstructionFails')(common_shapes.unknown_shape) class SessionClusterSpecPropagationTest(test_util.TensorFlowTestCase): - @test_util.disable_c_api # Operation._set_device doesn't work with C API def testClusterSpecPropagationSimple(self): server1 = server_lib.Server.create_local_server() server2 = server_lib.Server.create_local_server() @@ -65,7 +64,6 @@ class SessionClusterSpecPropagationTest(test_util.TensorFlowTestCase): output = sess.run(const) self.assertEqual(17, output) - @test_util.disable_c_api # Operation._set_device doesn't work with C API def testClusterSpecPropagationWorker2Placement(self): server1 = server_lib.Server.create_local_server() server2 = server_lib.Server.create_local_server() @@ -93,7 +91,6 @@ class SessionClusterSpecPropagationTest(test_util.TensorFlowTestCase): dev_stats.device and 'Const' == node_stats.node_name ])) - @test_util.disable_c_api # Operation._set_device doesn't work with C API def testClusterSpecPropagationWorker1Placement(self): server1 = server_lib.Server.create_local_server() server2 = server_lib.Server.create_local_server() @@ -110,7 +107,6 @@ class SessionClusterSpecPropagationTest(test_util.TensorFlowTestCase): output = sess.run(const) self.assertEqual(17, output) - @test_util.disable_c_api # Operation._set_device doesn't work with C API def testCanonicalDeviceNames(self): server1 = server_lib.Server.create_local_server() server2 = server_lib.Server.create_local_server() @@ -139,7 +135,6 @@ class SessionClusterSpecPropagationTest(test_util.TensorFlowTestCase): dev_stats.device and 'Const' == node_stats.node_name ])) - @test_util.disable_c_api # Operation._set_device doesn't work with C API def testLegacyDeviceNames(self): server1 = server_lib.Server.create_local_server() server2 = server_lib.Server.create_local_server() @@ -167,7 +162,6 @@ class SessionClusterSpecPropagationTest(test_util.TensorFlowTestCase): dev_stats.device and 'Const' == node_stats.node_name ])) - @test_util.disable_c_api # Operation._set_device doesn't work with C API def testClusterSpecPropagationThreeServers2Graphs(self): """Boots 3 servers, creates 2 sessions, ensures appropriate operations. @@ -229,7 +223,6 @@ class SessionClusterSpecPropagationTest(test_util.TensorFlowTestCase): self.assertAllEqual(expected_ones, sess2.run(var2)) self.assertAllEqual(expected_ones + expected_ones, sess1.run(var1)) - @test_util.disable_c_api # Operation._set_device doesn't work with C API def testClusterSpecPropagationThreeServers(self): """Boots 3 servers, creates 2 sessions, ensures appropriate operations. @@ -284,7 +277,6 @@ class SessionClusterSpecPropagationTest(test_util.TensorFlowTestCase): self.assertAllEqual(expected_ones, sess2.run(var)) self.assertAllEqual(expected_ones + expected_ones, sess1.run(var)) - @test_util.disable_c_api # Operation._set_device doesn't work with C API def testClusterSpecPropagationThreeServersOneCluster(self): """Boots 3 servers, ensures appropriate communication across workers. diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index 83c52b7cf7..61d411b6f9 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -64,7 +64,6 @@ ops.RegisterShape('ConstructionFails')(common_shapes.unknown_shape) class SessionTest(test_util.TensorFlowTestCase): - @test_util.disable_c_api # Operation._set_device doesn't work with C API def testUseExistingGraph(self): with ops.Graph().as_default() as g, ops.device('/cpu:0'): a = constant_op.constant(6.0, shape=[1, 1]) @@ -74,7 +73,6 @@ class SessionTest(test_util.TensorFlowTestCase): result = c.eval() self.assertAllEqual(result, [[42.0]]) - @test_util.disable_c_api # Operation._set_device doesn't work with C API def testUseDefaultGraph(self): with ops.Graph().as_default(), ops.device('/cpu:0'): a = constant_op.constant(6.0, shape=[1, 1]) @@ -879,7 +877,6 @@ class SessionTest(test_util.TensorFlowTestCase): v_val = v.eval() self.assertAllEqual([[6.0, 6.0, 6.0]], v_val) - @test_util.disable_c_api # Operation._set_device doesn't work with C API def testExtendWithGroupBy(self): with session.Session() as s: a = constant_op.constant(1.0, shape=[1, 2]) @@ -1091,7 +1088,6 @@ class SessionTest(test_util.TensorFlowTestCase): with self.assertRaisesRegexp(RuntimeError, 'The Session graph is empty.'): sess.run({}) - @test_util.disable_c_api # Operation._set_device doesn't work with C API def testNotEntered(self): # pylint: disable=protected-access self.assertEqual(ops._default_session_stack.get_default(), None) @@ -1107,7 +1103,6 @@ class SessionTest(test_util.TensorFlowTestCase): ValueError, lambda e: 'No default session is registered.' in str(e)): c_2.eval() - @test_util.disable_c_api # Operation._set_device doesn't work with C API def testInteractive(self): with ops.device('/cpu:0'): sess = session.InteractiveSession() @@ -1120,7 +1115,6 @@ class SessionTest(test_util.TensorFlowTestCase): self.assertAllEqual([[24.0]], e.eval()) sess.close() - @test_util.disable_c_api # Operation._set_device doesn't work with C API def testInteractivePlacePrunedGraph(self): sess = session.InteractiveSession() @@ -1142,7 +1136,6 @@ class SessionTest(test_util.TensorFlowTestCase): a.eval() sess.close() - @test_util.disable_c_api # Operation._set_device doesn't work with C API def testDefaultSessionPlacePrunedGraph(self): sess = session.Session() @@ -1164,7 +1157,6 @@ class SessionTest(test_util.TensorFlowTestCase): sess.close() - @test_util.disable_c_api # Operation._set_device doesn't work with C API def testSharedGraph(self): with ops.Graph().as_default() as g, ops.device('/cpu:0'): a = constant_op.constant(1.0, shape=[1, 2]) @@ -1423,7 +1415,6 @@ class SessionTest(test_util.TensorFlowTestCase): with self.assertRaisesRegexp(TypeError, 'Cannot interpret feed_dict'): sess.run(a, feed_dict={'a': [2.0]}) - @test_util.disable_c_api # Operation._set_device doesn't work with C API def testPerStepTrace(self): run_options = config_pb2.RunOptions( trace_level=config_pb2.RunOptions.FULL_TRACE) @@ -1444,7 +1435,6 @@ class SessionTest(test_util.TensorFlowTestCase): self.assertTrue(run_metadata.HasField('step_stats')) self.assertEquals(len(run_metadata.step_stats.dev_stats), 1) - @test_util.disable_c_api # Operation._set_device doesn't work with C API def testRunOptionsRunMetadata(self): run_options = config_pb2.RunOptions( trace_level=config_pb2.RunOptions.FULL_TRACE) @@ -1480,7 +1470,6 @@ class SessionTest(test_util.TensorFlowTestCase): with self.assertRaisesRegexp(ValueError, 'may not be fed'): sess.run(reshaped_tensor, feed_dict={new_shape: [3, 7]}) - @test_util.disable_c_api # Operation._set_device doesn't work with C API def testInferShapesFalse(self): with ops.Graph().as_default(), ops.device('/cpu:0'): a = constant_op.constant([[1, 2]]) @@ -1489,7 +1478,6 @@ class SessionTest(test_util.TensorFlowTestCase): # Avoid lint error regarding 'unused' var a. self.assertTrue(a == a) - @test_util.disable_c_api # Operation._set_device doesn't work with C API def testInferShapesTrue(self): config = config_pb2.ConfigProto( graph_options=config_pb2.GraphOptions(infer_shapes=True)) @@ -1500,7 +1488,6 @@ class SessionTest(test_util.TensorFlowTestCase): # Avoid lint error regarding 'unused' var a. self.assertTrue(a == a) - @test_util.disable_c_api # Operation._set_device doesn't work with C API def testBuildCostModel(self): run_options = config_pb2.RunOptions() config = config_pb2.ConfigProto( diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 8216b65c02..e70716d316 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -1373,7 +1373,7 @@ class Operation(object): """ if self._graph._c_graph: # pylint: disable=protected-access # TODO(iga): Remove this assert after converting to C API by default. - # Just being a bit paranoid here. + # Just being a bit paranoid here assert self._node_def.device == c_api.TF_OperationDevice(self._c_op) return c_api.TF_OperationDevice(self._c_op) else: @@ -1427,8 +1427,10 @@ class Operation(object): Args: device: string or device.. The device to set. """ - assert not self._graph._c_graph, ( # pylint: disable=protected-access - "Operation._set_device doesn't work with C API") + if _USE_C_API: + c_api.SetRequestedDevice( + self._graph._c_graph, self._c_op, _device_string(device)) # pylint: disable=protected-access + # TODO(nolivia): remove this line when switch to C api self._node_def.device = _device_string(device) def _add_input(self, tensor, dtype=None): diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py index 9683603785..617d2305bd 100644 --- a/tensorflow/python/ops/math_ops_test.py +++ b/tensorflow/python/ops/math_ops_test.py @@ -224,7 +224,6 @@ class ApproximateEqualTest(test_util.TensorFlowTestCase): class ScalarMulTest(test_util.TensorFlowTestCase): - @test_util.disable_c_api # Operation._set_device doesn't work with C API def testAcceptsRefs(self): var = variables.Variable(10) result = math_ops.scalar_mul(3, var) @@ -327,7 +326,6 @@ class DivAndModTest(test_util.TensorFlowTestCase): divs = np.arange(-3, 0, .25).reshape(1, 12) return nums, divs - @test_util.disable_c_api # Operation._set_device doesn't work with C API def testFloorModInt(self): nums, divs = self.intTestData() with self.test_session(): @@ -337,7 +335,6 @@ class DivAndModTest(test_util.TensorFlowTestCase): np_result = nums % divs self.assertAllEqual(tf_result, np_result) - @test_util.disable_c_api # Operation._set_device doesn't work with C API def testFloorModFloat(self): nums, divs = self.floatTestData() with self.test_session(): @@ -349,7 +346,6 @@ class DivAndModTest(test_util.TensorFlowTestCase): # % array_ops.constant(divs)).eval() # self.assertAllEqual(tf2_result, tf_result) - @test_util.disable_c_api # Operation._set_device doesn't work with C API def testTruncateModInt(self): nums, divs = self.intTestData() with self.test_session(): @@ -357,7 +353,6 @@ class DivAndModTest(test_util.TensorFlowTestCase): np_result = np.fmod(nums, divs) self.assertAllEqual(tf_result, np_result) - @test_util.disable_c_api # Operation._set_device doesn't work with C API def testTruncateModFloat(self): nums, divs = self.floatTestData() with self.test_session(): @@ -365,7 +360,6 @@ class DivAndModTest(test_util.TensorFlowTestCase): np_result = np.fmod(nums, divs) self.assertAllEqual(tf_result, np_result) - @test_util.disable_c_api # Operation._set_device doesn't work with C API def testDivideInt(self): nums, divs = self.intTestData() with self.test_session(): @@ -377,14 +371,12 @@ class DivAndModTest(test_util.TensorFlowTestCase): # // array_ops.constant(divs)).eval() # self.assertAllEqual(tf2_result, tf_result) - @test_util.disable_c_api # Operation._set_device doesn't work with C API def testDivideName(self): with self.test_session(): op = math_ops.divide( array_ops.constant(3), array_ops.constant(4), name="my_cool_divide") self.assertEqual(op.name, "my_cool_divide:0") - @test_util.disable_c_api # Operation._set_device doesn't work with C API def testRealDiv(self): nums, divs = self.floatTestData() with self.test_session(): @@ -392,14 +384,12 @@ class DivAndModTest(test_util.TensorFlowTestCase): np_result = np.divide(nums, divs) self.assertAllEqual(tf_result, np_result) - @test_util.disable_c_api # Operation._set_device doesn't work with C API def testComplexDiv(self): foo = array_ops.constant([1. + 3.j]) with self.test_session(): _ = math_ops.divide(foo, 1.).eval() _ = math_ops.div(foo, 2.).eval() - @test_util.disable_c_api # Operation._set_device doesn't work with C API def testFloorDivGrad(self): with self.test_session(): a = variables.Variable(2.) @@ -414,7 +404,6 @@ class DivAndModTest(test_util.TensorFlowTestCase): self.assertAllEqual([None if x is None else x.eval() for x in c_grad], [None, None]) - @test_util.disable_c_api # Operation._set_device doesn't work with C API def testConsistent(self): nums, divs = self.intTestData() with self.test_session(): |