diff options
author | 2016-10-31 11:00:05 -0800 | |
---|---|---|
committer | 2016-10-31 12:07:05 -0700 | |
commit | c5ccfe7e1b34ccc648a27bbf401c2a68568dde3a (patch) | |
tree | aac8c6bc6c188b5fda859ecb24acbe4219899f6e | |
parent | 1962804adc32d9bbdf0512b968c32e4cd86ae791 (diff) |
Changes the simple placer to be aware of resource handles.
Change: 137730576
-rw-r--r-- | tensorflow/core/common_runtime/simple_placer.cc | 3 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/simple_placer_test.cc | 59 |
2 files changed, 61 insertions, 1 deletions
diff --git a/tensorflow/core/common_runtime/simple_placer.cc b/tensorflow/core/common_runtime/simple_placer.cc index 617151ad7d..fda429b52a 100644 --- a/tensorflow/core/common_runtime/simple_placer.cc +++ b/tensorflow/core/common_runtime/simple_placer.cc @@ -645,7 +645,8 @@ Status SimplePlacer::Run() { // edge from the source of that edge to `node`. for (const auto& edge : node->in_edges()) { if (!edge->IsControlEdge() && - IsRefType(node->input_type(edge->dst_input()))) { + (IsRefType(node->input_type(edge->dst_input())) || + node->input_type(edge->dst_input()) == DT_RESOURCE)) { // If both the source node and this node have paritally // specified a device, then 'node's device should be // cleared: the reference edge forces 'node' to be on the diff --git a/tensorflow/core/common_runtime/simple_placer_test.cc b/tensorflow/core/common_runtime/simple_placer_test.cc index 148fc973dd..0c4edc28ec 100644 --- a/tensorflow/core/common_runtime/simple_placer_test.cc +++ b/tensorflow/core/common_runtime/simple_placer_test.cc @@ -472,6 +472,65 @@ TEST_F(SimplePlacerTest, TestReferenceConnection) { TF_EXPECT_OK(ReferenceTestHelper("VariableGPU", "AssignGPU", DEVICE_GPU)); } +// Handle-using dummy variable ops. +REGISTER_OP("TestHandleVariable").Output("o: resource"); +REGISTER_KERNEL_BUILDER(Name("TestHandleVariable").Device(DEVICE_CPU), DummyOp); +REGISTER_KERNEL_BUILDER(Name("TestHandleVariable").Device(DEVICE_GPU), DummyOp); + +REGISTER_OP("HandleVariableCPU").Output("o: resource"); +REGISTER_KERNEL_BUILDER(Name("HandleVariableCPU").Device(DEVICE_CPU), DummyOp); + +REGISTER_OP("HandleVariableGPU").Output("o: resource"); +REGISTER_KERNEL_BUILDER(Name("HandleVariableGPU").Device(DEVICE_GPU), DummyOp); + +REGISTER_OP("TestHandleAssign").Input("i: resource").Input("v: float"); +REGISTER_KERNEL_BUILDER(Name("TestHandleAssign").Device(DEVICE_CPU), DummyOp); +REGISTER_KERNEL_BUILDER(Name("TestHandleAssign").Device(DEVICE_GPU), DummyOp); + +REGISTER_OP("HandleAssignCPU").Input("i: resource").Input("v: float"); +REGISTER_KERNEL_BUILDER(Name("HandleAssignCPU").Device(DEVICE_CPU), DummyOp); + +REGISTER_OP("HandleAssignGPU").Input("i: resource").Input("v: float"); +REGISTER_KERNEL_BUILDER(Name("HandleAssignGPU").Device(DEVICE_GPU), DummyOp); + +// Tests all combinations of resource handles and ops using them. +TEST_F(SimplePlacerTest, TestResourceHandle) { + auto handle_test = [this](const string& var_op_name, + const string& use_op_name, DeviceType device) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* input = ops::SourceOp("TestInput", b.opts().WithName("in")); + Node* var = ops::SourceOp(var_op_name, b.opts().WithName("var")); + ops::BinaryOp(use_op_name, var, input, b.opts().WithName("assign")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + + TF_RETURN_IF_ERROR(Place(&g)); + + EXPECT_COLOCATED(g, "var", "assign"); + EXPECT_DEVICE_TYPE(g, "var", device); + EXPECT_DEVICE_TYPE(g, "assign", device); + return Status::OK(); + }; + TF_EXPECT_OK( + handle_test("TestHandleVariable", "TestHandleAssign", DEVICE_GPU)); + TF_EXPECT_OK( + handle_test("TestHandleVariable", "HandleAssignCPU", DEVICE_CPU)); + TF_EXPECT_OK( + handle_test("TestHandleVariable", "HandleAssignGPU", DEVICE_GPU)); + TF_EXPECT_OK( + handle_test("HandleVariableCPU", "TestHandleAssign", DEVICE_CPU)); + TF_EXPECT_OK(handle_test("HandleVariableCPU", "HandleAssignCPU", DEVICE_CPU)); + TF_EXPECT_OK(handle_test("HandleVariableGPU", "HandleAssignGPU", DEVICE_GPU)); + TF_EXPECT_OK( + handle_test("HandleVariableGPU", "TestHandleAssign", DEVICE_GPU)); + EXPECT_FALSE( + handle_test("HandleVariableGPU", "HandleAssignCPU", DEVICE_CPU).ok()); + EXPECT_FALSE( + handle_test("HandleVariableCPU", "HandleAssignGPU", DEVICE_CPU).ok()); +} + // Test that an assignment of an operator to the wrong device // is ignored when it could never be satisfied (due to reference // edges, for example). |