diff options
-rw-r--r-- | tensorflow/core/grappler/costs/virtual_scheduler.cc | 29 | ||||
-rw-r--r-- | tensorflow/core/grappler/costs/virtual_scheduler_test.cc | 171 | ||||
-rw-r--r-- | tensorflow/core/grappler/grappler_item.cc | 15 |
3 files changed, 212 insertions, 3 deletions
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc index d5625ae58f..2ab3a9144c 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc @@ -154,6 +154,16 @@ Status VirtualScheduler::Init() { name_to_node[node->name()] = node; } + // TODO(dyoon): Instead of identifying _Send node here manually, add _Send + // to _Recv as control dependency when creating GrapplerItem. + std::unordered_map<string, const NodeDef*> name_to_send; + for (const auto& node : graph.node()) { + if (node.op() == "_Send") { + const auto& attr = node.attr(); + name_to_send[attr.at("tensor_name").s()] = &node; + } + } + // To reuse _Recv ops. std::unordered_map<RecvNodeDescriptor, const NodeDef*, RecvNodeDescritorHash, RecvNodeDescriptorEqual> @@ -164,7 +174,17 @@ Status VirtualScheduler::Init() { for (const auto* curr_node : nodes) { auto& curr_node_state = GetNodeStateOrCreateIt(curr_node); const string curr_node_device = DeviceName(curr_node); - for (const string& input_node_name : curr_node->input()) { + std::vector<string> inputs; + if (IsRecv(*curr_node)) { + const auto& attr = curr_node->attr(); + const NodeDef* send = name_to_send[attr.at("tensor_name").s()]; + inputs = {send->name()}; + } else { + for (const string& input : curr_node->input()) { + inputs.push_back(input); + } + } + for (const string& input_node_name : inputs) { // Note that input_node_name may be in <prefix><node_name>:<port_num> // format, where <prefix> (e.g., "^" for control dependency) and // ":<port_num>" may be omitted. NodeName() extracts only the node_name. @@ -219,7 +239,7 @@ Status VirtualScheduler::Init() { // Default case: node without inputs are ready at time 0. const bool has_no_inputs = curr_node->input().empty(); - if (given_as_feed || has_no_inputs) { + if (!IsRecv(*curr_node) && (given_as_feed || has_no_inputs)) { curr_node_state.time_ready = Costs::Duration(); ready_nodes_->AddNode(curr_node); VLOG(3) << "Added ready node: " << curr_node->name(); @@ -254,7 +274,10 @@ void VirtualScheduler::MaybeUpdateInputOutput(const NodeDef* node) { // This method is called when NodeState is created and adds input and output // properties for a few exceptional cases that GraphProperties cannot provide // input/output properties. - if (IsSend(*node) || IsRecv(*node)) { + if ((IsSend(*node) || IsRecv(*node)) && node->attr().count(kAttrInputSrc)) { + // _Send and _Recv ops created from VirtualScheduler have kAttrInputSrc + // attr; normal _Send and _Recv ops (from the input graph) do not have that + // attr. auto& node_state = node_map_[node]; auto& inputs = node_state.input_properties; auto& outputs = node_state.output_properties; diff --git a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc index d291a04308..40548b5a07 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc @@ -265,6 +265,127 @@ class VirtualSchedulerTest : public ::testing::Test { dependency_["z4"] = {"bn"}; } + void CreateGrapplerItemWithSendRecv() { + const string gdef_ascii = R"EOF( +node { + name: "Const" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:CPU:0" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 3.1415 + } + } + } +} +node { + name: "Send" + op: "_Send" + input: "Const" + device: "/job:localhost/replica:0/task:0/device:CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "client_terminated" + value { + b: false + } + } + attr { + key: "recv_device" + value { + s: "/job:localhost/replica:0/task:0/device:CPU:0" + } + } + attr { + key: "send_device" + value { + s: "/job:localhost/replica:0/task:0/device:CPU:0" + } + } + attr { + key: "send_device_incarnation" + value { + i: 0 + } + } + attr { + key: "tensor_name" + value { + s: "test" + } + } +} +node { + name: "Recv" + op: "_Recv" + device: "/job:localhost/replica:0/task:0/device:CPU:0" + attr { + key: "client_terminated" + value { + b: false + } + } + attr { + key: "recv_device" + value { + s: "/job:localhost/replica:0/task:0/device:CPU:0" + } + } + attr { + key: "send_device" + value { + s: "/job:localhost/replica:0/task:0/device:CPU:0" + } + } + attr { + key: "send_device_incarnation" + value { + i: 0 + } + } + attr { + key: "tensor_name" + value { + s: "test" + } + } + attr { + key: "tensor_type" + value { + type: DT_FLOAT + } + } +} +library { +} +versions { + producer: 24 +} + )EOF"; + + grappler_item_.reset(new GrapplerItem); + CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, + &grappler_item_->graph)); + grappler_item_->id = "test_graph"; + grappler_item_->fetch = {"Recv"}; + } + // A simple while loop void CreateGrapplerItemWithLoop() { // Test graph produced in python using: @@ -743,6 +864,7 @@ versions { do { OpContext op_context = scheduler_->GetCurrNode(); ops_executed[op_context.name] = op_context; + std::cout << op_context.name << std::endl; Costs node_costs = SimplePredictCosts(op_context); @@ -1530,5 +1652,54 @@ TEST_F(VirtualSchedulerTest, InterDeviceTransfer) { EXPECT_EQ(get_output_size(recv_op_names[-1]), 4); EXPECT_EQ(get_output_size(send_op_names[-1]), 4); } + +TEST_F(VirtualSchedulerTest, GraphWithSendRecv) { + // Init. + CreateGrapplerItemWithSendRecv(); + InitScheduler(); + + // Run the scheduler. + auto ops_executed = RunScheduler(""); + + EXPECT_GT(ops_executed.count("Const"), 0); + EXPECT_GT(ops_executed.count("Send"), 0); + EXPECT_GT(ops_executed.count("Recv"), 0); +} + +TEST_F(VirtualSchedulerTest, GraphWithSendRecvDifferentDevice) { + // Init. + CreateGrapplerItemWithSendRecv(); + // Change Recv node's device so that Send and Recv are placed on different + // devices. + auto& graph = grappler_item_->graph; + const string recv_device = kCPU1; + for (int i = 0; i < graph.node_size(); i++) { + auto* node = graph.mutable_node(i); + if (node->name() == "Recv") { + node->set_device(recv_device); + auto* attr = node->mutable_attr(); + (*attr)["recv_device"].set_s(recv_device); + } else if (node->name() == "Send") { + auto* attr = node->mutable_attr(); + (*attr)["recv_device"].set_s(recv_device); + } + } + InitScheduler(); + + // Run the scheduler. + auto ops_executed = RunScheduler(""); + + // Expect Const, Send, Recv, and VirtualScheduler created Send and Recv ops. + EXPECT_GT(ops_executed.count("Const"), 0); + EXPECT_GT(ops_executed.count("Send"), 0); + EXPECT_GT(ops_executed.count("Send_Send_0_from_/job_localhost/replica_0/" + "task_0/cpu_0_to_/job_localhost" + "/replica_0/task_0/cpu_1"), + 0); + EXPECT_GT(ops_executed.count( + "Recv_Send_0_on_/job_localhost/replica_0/task_0/cpu_1"), + 0); + EXPECT_GT(ops_executed.count("Recv"), 0); +} } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/grappler_item.cc b/tensorflow/core/grappler/grappler_item.cc index 94412eb198..844a1fa328 100644 --- a/tensorflow/core/grappler/grappler_item.cc +++ b/tensorflow/core/grappler/grappler_item.cc @@ -19,6 +19,7 @@ limitations under the License. #include <unordered_set> #include <vector> +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/utils.h" @@ -117,8 +118,13 @@ std::vector<const NodeDef*> ComputeTransitiveFanin( bool* ill_formed) { *ill_formed = false; std::unordered_map<string, const NodeDef*> name_to_node; + std::unordered_map<string, const NodeDef*> name_to_send; for (const auto& node : graph.node()) { name_to_node[node.name()] = &node; + if (node.op() == "_Send") { + const auto& attr = node.attr(); + name_to_send[attr.at("tensor_name").s()] = &node; + } } std::vector<const NodeDef*> queue; @@ -150,6 +156,15 @@ std::vector<const NodeDef*> ComputeTransitiveFanin( } queue.push_back(in); } + if (node->op() == "_Recv") { + const auto& attr = node->attr(); + const NodeDef* send = name_to_send[attr.at("tensor_name").s()]; + if (send) { + queue.push_back(send); + } + // Subgraph after partitioning may have either _Send or _Recv, not both. + // So, we do not set ill_formed for missing _Send. + } } return result; } |