aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Olivia Nordquist <nolivia@google.com>2017-07-18 11:34:24 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-18 11:40:59 -0700
commit9867119831afa821aa53206cc5b7ae773db4284d (patch)
treea50ffe4e4e06d8b932721b7d501c39e7b3c6b559
parent1558c10c8d07c78a58d74d0fcf53c1d2e2505d5e (diff)
Implementing set_device for the C API
PiperOrigin-RevId: 162379684
-rw-r--r--tensorflow/c/python_api.cc5
-rw-r--r--tensorflow/c/python_api.h2
-rw-r--r--tensorflow/core/graph/graph.cc5
-rw-r--r--tensorflow/core/graph/graph.h4
-rw-r--r--tensorflow/python/client/session_clusterspec_prop_test.py8
-rw-r--r--tensorflow/python/client/session_test.py13
-rw-r--r--tensorflow/python/framework/ops.py8
-rw-r--r--tensorflow/python/ops/math_ops_test.py11
8 files changed, 21 insertions, 35 deletions
diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc
index 91e4e3cbb9..adca6c7625 100644
--- a/tensorflow/c/python_api.cc
+++ b/tensorflow/c/python_api.cc
@@ -25,4 +25,9 @@ void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input) {
graph->graph.AddControlEdge(&input->node, &op->node);
}
+void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) {
+ mutex_lock l(graph->mu);
+ op->node.set_requested_device(device);
+}
+
} // namespace tensorflow
diff --git a/tensorflow/c/python_api.h b/tensorflow/c/python_api.h
index bb2b0fd8fd..e1a55d7755 100644
--- a/tensorflow/c/python_api.h
+++ b/tensorflow/c/python_api.h
@@ -25,6 +25,8 @@ namespace tensorflow {
void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input);
+void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device);
+
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_C_PYTHON_API_H_
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc
index a840ef39d2..f6586f0519 100644
--- a/tensorflow/core/graph/graph.cc
+++ b/tensorflow/core/graph/graph.cc
@@ -184,6 +184,11 @@ void Node::ClearAttr(const string& name) {
(*props_->node_def.mutable_attr()).erase(name);
}
+void Node::set_requested_device(const string& device) {
+ MaybeCopyOnWrite();
+ props_->node_def.set_device(device);
+}
+
Status Node::input_edge(int idx, const Edge** e) const {
if (idx < 0 || idx >= num_inputs()) {
return errors::InvalidArgument("Invalid input_edge index: ", idx, ", Node ",
diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h
index 565f430455..78a0e8fd79 100644
--- a/tensorflow/core/graph/graph.h
+++ b/tensorflow/core/graph/graph.h
@@ -101,6 +101,10 @@ class Node {
// use assigned_device_name() below.
const string& requested_device() const;
+ // This changes the user requested device but not necessarily the device that
+ // on which the operation will run.
+ void set_requested_device(const string& device);
+
// This gives the device the runtime has assigned this node to. If
// you want the device the user requested, use def().device() instead.
// TODO(josh11b): Validate that the assigned_device, if not empty:
diff --git a/tensorflow/python/client/session_clusterspec_prop_test.py b/tensorflow/python/client/session_clusterspec_prop_test.py
index 37cbf8062c..f40d3f1872 100644
--- a/tensorflow/python/client/session_clusterspec_prop_test.py
+++ b/tensorflow/python/client/session_clusterspec_prop_test.py
@@ -49,7 +49,6 @@ ops.RegisterShape('ConstructionFails')(common_shapes.unknown_shape)
class SessionClusterSpecPropagationTest(test_util.TensorFlowTestCase):
- @test_util.disable_c_api # Operation._set_device doesn't work with C API
def testClusterSpecPropagationSimple(self):
server1 = server_lib.Server.create_local_server()
server2 = server_lib.Server.create_local_server()
@@ -65,7 +64,6 @@ class SessionClusterSpecPropagationTest(test_util.TensorFlowTestCase):
output = sess.run(const)
self.assertEqual(17, output)
- @test_util.disable_c_api # Operation._set_device doesn't work with C API
def testClusterSpecPropagationWorker2Placement(self):
server1 = server_lib.Server.create_local_server()
server2 = server_lib.Server.create_local_server()
@@ -93,7 +91,6 @@ class SessionClusterSpecPropagationTest(test_util.TensorFlowTestCase):
dev_stats.device and 'Const' == node_stats.node_name
]))
- @test_util.disable_c_api # Operation._set_device doesn't work with C API
def testClusterSpecPropagationWorker1Placement(self):
server1 = server_lib.Server.create_local_server()
server2 = server_lib.Server.create_local_server()
@@ -110,7 +107,6 @@ class SessionClusterSpecPropagationTest(test_util.TensorFlowTestCase):
output = sess.run(const)
self.assertEqual(17, output)
- @test_util.disable_c_api # Operation._set_device doesn't work with C API
def testCanonicalDeviceNames(self):
server1 = server_lib.Server.create_local_server()
server2 = server_lib.Server.create_local_server()
@@ -139,7 +135,6 @@ class SessionClusterSpecPropagationTest(test_util.TensorFlowTestCase):
dev_stats.device and 'Const' == node_stats.node_name
]))
- @test_util.disable_c_api # Operation._set_device doesn't work with C API
def testLegacyDeviceNames(self):
server1 = server_lib.Server.create_local_server()
server2 = server_lib.Server.create_local_server()
@@ -167,7 +162,6 @@ class SessionClusterSpecPropagationTest(test_util.TensorFlowTestCase):
dev_stats.device and 'Const' == node_stats.node_name
]))
- @test_util.disable_c_api # Operation._set_device doesn't work with C API
def testClusterSpecPropagationThreeServers2Graphs(self):
"""Boots 3 servers, creates 2 sessions, ensures appropriate operations.
@@ -229,7 +223,6 @@ class SessionClusterSpecPropagationTest(test_util.TensorFlowTestCase):
self.assertAllEqual(expected_ones, sess2.run(var2))
self.assertAllEqual(expected_ones + expected_ones, sess1.run(var1))
- @test_util.disable_c_api # Operation._set_device doesn't work with C API
def testClusterSpecPropagationThreeServers(self):
"""Boots 3 servers, creates 2 sessions, ensures appropriate operations.
@@ -284,7 +277,6 @@ class SessionClusterSpecPropagationTest(test_util.TensorFlowTestCase):
self.assertAllEqual(expected_ones, sess2.run(var))
self.assertAllEqual(expected_ones + expected_ones, sess1.run(var))
- @test_util.disable_c_api # Operation._set_device doesn't work with C API
def testClusterSpecPropagationThreeServersOneCluster(self):
"""Boots 3 servers, ensures appropriate communication across workers.
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index 83c52b7cf7..61d411b6f9 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -64,7 +64,6 @@ ops.RegisterShape('ConstructionFails')(common_shapes.unknown_shape)
class SessionTest(test_util.TensorFlowTestCase):
- @test_util.disable_c_api # Operation._set_device doesn't work with C API
def testUseExistingGraph(self):
with ops.Graph().as_default() as g, ops.device('/cpu:0'):
a = constant_op.constant(6.0, shape=[1, 1])
@@ -74,7 +73,6 @@ class SessionTest(test_util.TensorFlowTestCase):
result = c.eval()
self.assertAllEqual(result, [[42.0]])
- @test_util.disable_c_api # Operation._set_device doesn't work with C API
def testUseDefaultGraph(self):
with ops.Graph().as_default(), ops.device('/cpu:0'):
a = constant_op.constant(6.0, shape=[1, 1])
@@ -879,7 +877,6 @@ class SessionTest(test_util.TensorFlowTestCase):
v_val = v.eval()
self.assertAllEqual([[6.0, 6.0, 6.0]], v_val)
- @test_util.disable_c_api # Operation._set_device doesn't work with C API
def testExtendWithGroupBy(self):
with session.Session() as s:
a = constant_op.constant(1.0, shape=[1, 2])
@@ -1091,7 +1088,6 @@ class SessionTest(test_util.TensorFlowTestCase):
with self.assertRaisesRegexp(RuntimeError, 'The Session graph is empty.'):
sess.run({})
- @test_util.disable_c_api # Operation._set_device doesn't work with C API
def testNotEntered(self):
# pylint: disable=protected-access
self.assertEqual(ops._default_session_stack.get_default(), None)
@@ -1107,7 +1103,6 @@ class SessionTest(test_util.TensorFlowTestCase):
ValueError, lambda e: 'No default session is registered.' in str(e)):
c_2.eval()
- @test_util.disable_c_api # Operation._set_device doesn't work with C API
def testInteractive(self):
with ops.device('/cpu:0'):
sess = session.InteractiveSession()
@@ -1120,7 +1115,6 @@ class SessionTest(test_util.TensorFlowTestCase):
self.assertAllEqual([[24.0]], e.eval())
sess.close()
- @test_util.disable_c_api # Operation._set_device doesn't work with C API
def testInteractivePlacePrunedGraph(self):
sess = session.InteractiveSession()
@@ -1142,7 +1136,6 @@ class SessionTest(test_util.TensorFlowTestCase):
a.eval()
sess.close()
- @test_util.disable_c_api # Operation._set_device doesn't work with C API
def testDefaultSessionPlacePrunedGraph(self):
sess = session.Session()
@@ -1164,7 +1157,6 @@ class SessionTest(test_util.TensorFlowTestCase):
sess.close()
- @test_util.disable_c_api # Operation._set_device doesn't work with C API
def testSharedGraph(self):
with ops.Graph().as_default() as g, ops.device('/cpu:0'):
a = constant_op.constant(1.0, shape=[1, 2])
@@ -1423,7 +1415,6 @@ class SessionTest(test_util.TensorFlowTestCase):
with self.assertRaisesRegexp(TypeError, 'Cannot interpret feed_dict'):
sess.run(a, feed_dict={'a': [2.0]})
- @test_util.disable_c_api # Operation._set_device doesn't work with C API
def testPerStepTrace(self):
run_options = config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE)
@@ -1444,7 +1435,6 @@ class SessionTest(test_util.TensorFlowTestCase):
self.assertTrue(run_metadata.HasField('step_stats'))
self.assertEquals(len(run_metadata.step_stats.dev_stats), 1)
- @test_util.disable_c_api # Operation._set_device doesn't work with C API
def testRunOptionsRunMetadata(self):
run_options = config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE)
@@ -1480,7 +1470,6 @@ class SessionTest(test_util.TensorFlowTestCase):
with self.assertRaisesRegexp(ValueError, 'may not be fed'):
sess.run(reshaped_tensor, feed_dict={new_shape: [3, 7]})
- @test_util.disable_c_api # Operation._set_device doesn't work with C API
def testInferShapesFalse(self):
with ops.Graph().as_default(), ops.device('/cpu:0'):
a = constant_op.constant([[1, 2]])
@@ -1489,7 +1478,6 @@ class SessionTest(test_util.TensorFlowTestCase):
# Avoid lint error regarding 'unused' var a.
self.assertTrue(a == a)
- @test_util.disable_c_api # Operation._set_device doesn't work with C API
def testInferShapesTrue(self):
config = config_pb2.ConfigProto(
graph_options=config_pb2.GraphOptions(infer_shapes=True))
@@ -1500,7 +1488,6 @@ class SessionTest(test_util.TensorFlowTestCase):
# Avoid lint error regarding 'unused' var a.
self.assertTrue(a == a)
- @test_util.disable_c_api # Operation._set_device doesn't work with C API
def testBuildCostModel(self):
run_options = config_pb2.RunOptions()
config = config_pb2.ConfigProto(
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 8216b65c02..e70716d316 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -1373,7 +1373,7 @@ class Operation(object):
"""
if self._graph._c_graph: # pylint: disable=protected-access
# TODO(iga): Remove this assert after converting to C API by default.
- # Just being a bit paranoid here.
+ # Just being a bit paranoid here
assert self._node_def.device == c_api.TF_OperationDevice(self._c_op)
return c_api.TF_OperationDevice(self._c_op)
else:
@@ -1427,8 +1427,10 @@ class Operation(object):
Args:
device: string or device.. The device to set.
"""
- assert not self._graph._c_graph, ( # pylint: disable=protected-access
- "Operation._set_device doesn't work with C API")
+ if _USE_C_API:
+ c_api.SetRequestedDevice(
+ self._graph._c_graph, self._c_op, _device_string(device)) # pylint: disable=protected-access
+ # TODO(nolivia): remove this line when switch to C api
self._node_def.device = _device_string(device)
def _add_input(self, tensor, dtype=None):
diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py
index 9683603785..617d2305bd 100644
--- a/tensorflow/python/ops/math_ops_test.py
+++ b/tensorflow/python/ops/math_ops_test.py
@@ -224,7 +224,6 @@ class ApproximateEqualTest(test_util.TensorFlowTestCase):
class ScalarMulTest(test_util.TensorFlowTestCase):
- @test_util.disable_c_api # Operation._set_device doesn't work with C API
def testAcceptsRefs(self):
var = variables.Variable(10)
result = math_ops.scalar_mul(3, var)
@@ -327,7 +326,6 @@ class DivAndModTest(test_util.TensorFlowTestCase):
divs = np.arange(-3, 0, .25).reshape(1, 12)
return nums, divs
- @test_util.disable_c_api # Operation._set_device doesn't work with C API
def testFloorModInt(self):
nums, divs = self.intTestData()
with self.test_session():
@@ -337,7 +335,6 @@ class DivAndModTest(test_util.TensorFlowTestCase):
np_result = nums % divs
self.assertAllEqual(tf_result, np_result)
- @test_util.disable_c_api # Operation._set_device doesn't work with C API
def testFloorModFloat(self):
nums, divs = self.floatTestData()
with self.test_session():
@@ -349,7 +346,6 @@ class DivAndModTest(test_util.TensorFlowTestCase):
# % array_ops.constant(divs)).eval()
# self.assertAllEqual(tf2_result, tf_result)
- @test_util.disable_c_api # Operation._set_device doesn't work with C API
def testTruncateModInt(self):
nums, divs = self.intTestData()
with self.test_session():
@@ -357,7 +353,6 @@ class DivAndModTest(test_util.TensorFlowTestCase):
np_result = np.fmod(nums, divs)
self.assertAllEqual(tf_result, np_result)
- @test_util.disable_c_api # Operation._set_device doesn't work with C API
def testTruncateModFloat(self):
nums, divs = self.floatTestData()
with self.test_session():
@@ -365,7 +360,6 @@ class DivAndModTest(test_util.TensorFlowTestCase):
np_result = np.fmod(nums, divs)
self.assertAllEqual(tf_result, np_result)
- @test_util.disable_c_api # Operation._set_device doesn't work with C API
def testDivideInt(self):
nums, divs = self.intTestData()
with self.test_session():
@@ -377,14 +371,12 @@ class DivAndModTest(test_util.TensorFlowTestCase):
# // array_ops.constant(divs)).eval()
# self.assertAllEqual(tf2_result, tf_result)
- @test_util.disable_c_api # Operation._set_device doesn't work with C API
def testDivideName(self):
with self.test_session():
op = math_ops.divide(
array_ops.constant(3), array_ops.constant(4), name="my_cool_divide")
self.assertEqual(op.name, "my_cool_divide:0")
- @test_util.disable_c_api # Operation._set_device doesn't work with C API
def testRealDiv(self):
nums, divs = self.floatTestData()
with self.test_session():
@@ -392,14 +384,12 @@ class DivAndModTest(test_util.TensorFlowTestCase):
np_result = np.divide(nums, divs)
self.assertAllEqual(tf_result, np_result)
- @test_util.disable_c_api # Operation._set_device doesn't work with C API
def testComplexDiv(self):
foo = array_ops.constant([1. + 3.j])
with self.test_session():
_ = math_ops.divide(foo, 1.).eval()
_ = math_ops.div(foo, 2.).eval()
- @test_util.disable_c_api # Operation._set_device doesn't work with C API
def testFloorDivGrad(self):
with self.test_session():
a = variables.Variable(2.)
@@ -414,7 +404,6 @@ class DivAndModTest(test_util.TensorFlowTestCase):
self.assertAllEqual([None if x is None else x.eval()
for x in c_grad], [None, None])
- @test_util.disable_c_api # Operation._set_device doesn't work with C API
def testConsistent(self):
nums, divs = self.intTestData()
with self.test_session():