aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-10-31 11:00:05 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-10-31 12:07:05 -0700
commitc5ccfe7e1b34ccc648a27bbf401c2a68568dde3a (patch)
treeaac8c6bc6c188b5fda859ecb24acbe4219899f6e
parent1962804adc32d9bbdf0512b968c32e4cd86ae791 (diff)
Changes the simple placer to be aware of resource handles.
Change: 137730576
-rw-r--r--tensorflow/core/common_runtime/simple_placer.cc3
-rw-r--r--tensorflow/core/common_runtime/simple_placer_test.cc59
2 files changed, 61 insertions, 1 deletions
diff --git a/tensorflow/core/common_runtime/simple_placer.cc b/tensorflow/core/common_runtime/simple_placer.cc
index 617151ad7d..fda429b52a 100644
--- a/tensorflow/core/common_runtime/simple_placer.cc
+++ b/tensorflow/core/common_runtime/simple_placer.cc
@@ -645,7 +645,8 @@ Status SimplePlacer::Run() {
// edge from the source of that edge to `node`.
for (const auto& edge : node->in_edges()) {
if (!edge->IsControlEdge() &&
- IsRefType(node->input_type(edge->dst_input()))) {
+ (IsRefType(node->input_type(edge->dst_input())) ||
+ node->input_type(edge->dst_input()) == DT_RESOURCE)) {
// If both the source node and this node have paritally
// specified a device, then 'node's device should be
// cleared: the reference edge forces 'node' to be on the
diff --git a/tensorflow/core/common_runtime/simple_placer_test.cc b/tensorflow/core/common_runtime/simple_placer_test.cc
index 148fc973dd..0c4edc28ec 100644
--- a/tensorflow/core/common_runtime/simple_placer_test.cc
+++ b/tensorflow/core/common_runtime/simple_placer_test.cc
@@ -472,6 +472,65 @@ TEST_F(SimplePlacerTest, TestReferenceConnection) {
TF_EXPECT_OK(ReferenceTestHelper("VariableGPU", "AssignGPU", DEVICE_GPU));
}
+// Handle-using dummy variable ops.
+REGISTER_OP("TestHandleVariable").Output("o: resource");
+REGISTER_KERNEL_BUILDER(Name("TestHandleVariable").Device(DEVICE_CPU), DummyOp);
+REGISTER_KERNEL_BUILDER(Name("TestHandleVariable").Device(DEVICE_GPU), DummyOp);
+
+REGISTER_OP("HandleVariableCPU").Output("o: resource");
+REGISTER_KERNEL_BUILDER(Name("HandleVariableCPU").Device(DEVICE_CPU), DummyOp);
+
+REGISTER_OP("HandleVariableGPU").Output("o: resource");
+REGISTER_KERNEL_BUILDER(Name("HandleVariableGPU").Device(DEVICE_GPU), DummyOp);
+
+REGISTER_OP("TestHandleAssign").Input("i: resource").Input("v: float");
+REGISTER_KERNEL_BUILDER(Name("TestHandleAssign").Device(DEVICE_CPU), DummyOp);
+REGISTER_KERNEL_BUILDER(Name("TestHandleAssign").Device(DEVICE_GPU), DummyOp);
+
+REGISTER_OP("HandleAssignCPU").Input("i: resource").Input("v: float");
+REGISTER_KERNEL_BUILDER(Name("HandleAssignCPU").Device(DEVICE_CPU), DummyOp);
+
+REGISTER_OP("HandleAssignGPU").Input("i: resource").Input("v: float");
+REGISTER_KERNEL_BUILDER(Name("HandleAssignGPU").Device(DEVICE_GPU), DummyOp);
+
+// Tests all combinations of resource handles and ops using them.
+TEST_F(SimplePlacerTest, TestResourceHandle) {
+ auto handle_test = [this](const string& var_op_name,
+ const string& use_op_name, DeviceType device) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
+ Node* var = ops::SourceOp(var_op_name, b.opts().WithName("var"));
+ ops::BinaryOp(use_op_name, var, input, b.opts().WithName("assign"));
+ TF_EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ TF_RETURN_IF_ERROR(Place(&g));
+
+ EXPECT_COLOCATED(g, "var", "assign");
+ EXPECT_DEVICE_TYPE(g, "var", device);
+ EXPECT_DEVICE_TYPE(g, "assign", device);
+ return Status::OK();
+ };
+ TF_EXPECT_OK(
+ handle_test("TestHandleVariable", "TestHandleAssign", DEVICE_GPU));
+ TF_EXPECT_OK(
+ handle_test("TestHandleVariable", "HandleAssignCPU", DEVICE_CPU));
+ TF_EXPECT_OK(
+ handle_test("TestHandleVariable", "HandleAssignGPU", DEVICE_GPU));
+ TF_EXPECT_OK(
+ handle_test("HandleVariableCPU", "TestHandleAssign", DEVICE_CPU));
+ TF_EXPECT_OK(handle_test("HandleVariableCPU", "HandleAssignCPU", DEVICE_CPU));
+ TF_EXPECT_OK(handle_test("HandleVariableGPU", "HandleAssignGPU", DEVICE_GPU));
+ TF_EXPECT_OK(
+ handle_test("HandleVariableGPU", "TestHandleAssign", DEVICE_GPU));
+ EXPECT_FALSE(
+ handle_test("HandleVariableGPU", "HandleAssignCPU", DEVICE_CPU).ok());
+ EXPECT_FALSE(
+ handle_test("HandleVariableCPU", "HandleAssignGPU", DEVICE_CPU).ok());
+}
+
// Test that an assignment of an operator to the wrong device
// is ignored when it could never be satisfied (due to reference
// edges, for example).