-rw-r--r--  tensorflow/core/grappler/costs/virtual_scheduler.cc        29
-rw-r--r--  tensorflow/core/grappler/costs/virtual_scheduler_test.cc  171
-rw-r--r--  tensorflow/core/grappler/grappler_item.cc                  15
3 files changed, 212 insertions(+), 3 deletions(-)
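
In brief, as read from the hunks below: a pre-partitioned graph may already contain _Send/_Recv pairs, but a _Recv carries no explicit input edge, so the VirtualScheduler would treat it as ready at time 0 and ComputeTransitiveFanin would drop its producer from the fanin. The change looks up the matching _Send through the tensor_name attr and treats it as an implicit input of the _Recv. Hedged, standalone sketches of the technique follow each file's hunks.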
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc
index d5625ae58f..2ab3a9144c 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc
@@ -154,6 +154,16 @@ Status VirtualScheduler::Init() {
name_to_node[node->name()] = node;
}
+ // TODO(dyoon): Instead of identifying _Send nodes here manually, add each
+ // _Send as a control dependency of its matching _Recv when creating the
+ // GrapplerItem.
+ std::unordered_map<string, const NodeDef*> name_to_send;
+ for (const auto& node : graph.node()) {
+ if (node.op() == "_Send") {
+ const auto& attr = node.attr();
+ name_to_send[attr.at("tensor_name").s()] = &node;
+ }
+ }
+
// To reuse _Recv ops.
std::unordered_map<RecvNodeDescriptor, const NodeDef*, RecvNodeDescritorHash,
RecvNodeDescriptorEqual>
@@ -164,7 +174,17 @@ Status VirtualScheduler::Init() {
for (const auto* curr_node : nodes) {
auto& curr_node_state = GetNodeStateOrCreateIt(curr_node);
const string curr_node_device = DeviceName(curr_node);
- for (const string& input_node_name : curr_node->input()) {
+ std::vector<string> inputs;
+ if (IsRecv(*curr_node)) {
+ const auto& attr = curr_node->attr();
+ auto it = name_to_send.find(attr.at("tensor_name").s());
+ // A partitioned subgraph may contain a _Recv without its matching _Send;
+ // guard the lookup instead of dereferencing a null map entry.
+ if (it != name_to_send.end()) inputs = {it->second->name()};
+ } else {
+ for (const string& input : curr_node->input()) {
+ inputs.push_back(input);
+ }
+ }
+ for (const string& input_node_name : inputs) {
// Note that input_node_name may be in <prefix><node_name>:<port_num>
// format, where <prefix> (e.g., "^" for control dependency) and
// ":<port_num>" may be omitted. NodeName() extracts only the node_name.
@@ -219,7 +239,7 @@ Status VirtualScheduler::Init() {
// Default case: nodes without inputs are ready at time 0.
const bool has_no_inputs = curr_node->input().empty();
- if (given_as_feed || has_no_inputs) {
+ if (!IsRecv(*curr_node) && (given_as_feed || has_no_inputs)) {
curr_node_state.time_ready = Costs::Duration();
ready_nodes_->AddNode(curr_node);
VLOG(3) << "Added ready node: " << curr_node->name();
@@ -254,7 +274,10 @@ void VirtualScheduler::MaybeUpdateInputOutput(const NodeDef* node) {
// This method is called when NodeState is created and adds input and output
// properties for a few exceptional cases that GraphProperties cannot provide
// input/output properties.
- if (IsSend(*node) || IsRecv(*node)) {
+ if ((IsSend(*node) || IsRecv(*node)) && node->attr().count(kAttrInputSrc)) {
+ // _Send and _Recv ops created by the VirtualScheduler carry the
+ // kAttrInputSrc attr; normal _Send and _Recv ops (from the input graph)
+ // do not.
auto& node_state = node_map_[node];
auto& inputs = node_state.input_properties;
auto& outputs = node_state.output_properties;
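
To make the scheduler hunks above easier to follow, here is a minimal standalone sketch of the technique, assuming nothing from TensorFlow: Node, EffectiveInputs, and their fields are illustrative stand-ins for NodeDef and its tensor_name attr, not the real API. It indexes _Send nodes by tensor_name and resolves a _Recv's input to its matching _Send.

#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

struct Node {
  std::string name;
  std::string op;           // e.g. "_Send", "_Recv"
  std::string tensor_name;  // stand-in for attr().at("tensor_name").s()
  std::vector<std::string> inputs;
};

// A _Recv's only effective input is the _Send producing the same
// tensor_name; every other node keeps its declared inputs.
std::vector<std::string> EffectiveInputs(
    const Node& node,
    const std::unordered_map<std::string, const Node*>& name_to_send) {
  if (node.op == "_Recv") {
    auto it = name_to_send.find(node.tensor_name);
    if (it != name_to_send.end()) return {it->second->name};
    return {};  // a partitioned subgraph may lack the matching _Send
  }
  return node.inputs;
}

int main() {
  const std::vector<Node> graph = {
      {"Const", "Const", "", {}},
      {"Send", "_Send", "test", {"Const"}},
      {"Recv", "_Recv", "test", {}},  // no explicit input in the GraphDef
  };
  std::unordered_map<std::string, const Node*> name_to_send;
  for (const auto& n : graph)
    if (n.op == "_Send") name_to_send[n.tensor_name] = &n;
  for (const auto& n : graph) {
    std::cout << n.name << " <-";
    for (const auto& in : EffectiveInputs(n, name_to_send))
      std::cout << ' ' << in;
    std::cout << '\n';
  }
}

Running this prints "Recv <- Send": the _Recv now depends on its producer even though the GraphDef lists no input for it, which is exactly why the ready-at-time-0 check above must exclude _Recv nodes.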
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
index d291a04308..40548b5a07 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
@@ -265,6 +265,127 @@ class VirtualSchedulerTest : public ::testing::Test {
dependency_["z4"] = {"bn"};
}
+ void CreateGrapplerItemWithSendRecv() {
+ const string gdef_ascii = R"EOF(
+node {
+ name: "Const"
+ op: "Const"
+ device: "/job:localhost/replica:0/task:0/device:CPU:0"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 3.1415
+ }
+ }
+ }
+}
+node {
+ name: "Send"
+ op: "_Send"
+ input: "Const"
+ device: "/job:localhost/replica:0/task:0/device:CPU:0"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "client_terminated"
+ value {
+ b: false
+ }
+ }
+ attr {
+ key: "recv_device"
+ value {
+ s: "/job:localhost/replica:0/task:0/device:CPU:0"
+ }
+ }
+ attr {
+ key: "send_device"
+ value {
+ s: "/job:localhost/replica:0/task:0/device:CPU:0"
+ }
+ }
+ attr {
+ key: "send_device_incarnation"
+ value {
+ i: 0
+ }
+ }
+ attr {
+ key: "tensor_name"
+ value {
+ s: "test"
+ }
+ }
+}
+node {
+ name: "Recv"
+ op: "_Recv"
+ device: "/job:localhost/replica:0/task:0/device:CPU:0"
+ attr {
+ key: "client_terminated"
+ value {
+ b: false
+ }
+ }
+ attr {
+ key: "recv_device"
+ value {
+ s: "/job:localhost/replica:0/task:0/device:CPU:0"
+ }
+ }
+ attr {
+ key: "send_device"
+ value {
+ s: "/job:localhost/replica:0/task:0/device:CPU:0"
+ }
+ }
+ attr {
+ key: "send_device_incarnation"
+ value {
+ i: 0
+ }
+ }
+ attr {
+ key: "tensor_name"
+ value {
+ s: "test"
+ }
+ }
+ attr {
+ key: "tensor_type"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+library {
+}
+versions {
+ producer: 24
+}
+ )EOF";
+
+ grappler_item_.reset(new GrapplerItem);
+ CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii,
+ &grappler_item_->graph));
+ grappler_item_->id = "test_graph";
+ grappler_item_->fetch = {"Recv"};
+ }
+
// A simple while loop
void CreateGrapplerItemWithLoop() {
// Test graph produced in python using:
@@ -743,6 +864,7 @@ versions {
do {
OpContext op_context = scheduler_->GetCurrNode();
ops_executed[op_context.name] = op_context;
+      VLOG(3) << "Scheduling: " << op_context.name;
Costs node_costs = SimplePredictCosts(op_context);
@@ -1530,5 +1652,54 @@ TEST_F(VirtualSchedulerTest, InterDeviceTransfer) {
EXPECT_EQ(get_output_size(recv_op_names[-1]), 4);
EXPECT_EQ(get_output_size(send_op_names[-1]), 4);
}
+
+TEST_F(VirtualSchedulerTest, GraphWithSendRecv) {
+ // Init.
+ CreateGrapplerItemWithSendRecv();
+ InitScheduler();
+
+ // Run the scheduler.
+ auto ops_executed = RunScheduler("");
+
+ EXPECT_GT(ops_executed.count("Const"), 0);
+ EXPECT_GT(ops_executed.count("Send"), 0);
+ EXPECT_GT(ops_executed.count("Recv"), 0);
+}
+
+TEST_F(VirtualSchedulerTest, GraphWithSendRecvDifferentDevice) {
+ // Init.
+ CreateGrapplerItemWithSendRecv();
+ // Change Recv node's device so that Send and Recv are placed on different
+ // devices.
+ auto& graph = grappler_item_->graph;
+ const string recv_device = kCPU1;
+ for (int i = 0; i < graph.node_size(); i++) {
+ auto* node = graph.mutable_node(i);
+ if (node->name() == "Recv") {
+ node->set_device(recv_device);
+ auto* attr = node->mutable_attr();
+ (*attr)["recv_device"].set_s(recv_device);
+ } else if (node->name() == "Send") {
+ auto* attr = node->mutable_attr();
+ (*attr)["recv_device"].set_s(recv_device);
+ }
+ }
+ InitScheduler();
+
+ // Run the scheduler.
+ auto ops_executed = RunScheduler("");
+
+ // Expect Const, Send, and Recv, plus the extra Send and Recv ops that the
+ // VirtualScheduler creates for the cross-device transfer.
+ EXPECT_GT(ops_executed.count("Const"), 0);
+ EXPECT_GT(ops_executed.count("Send"), 0);
+ EXPECT_GT(ops_executed.count("Send_Send_0_from_/job_localhost/replica_0/"
+ "task_0/cpu_0_to_/job_localhost"
+ "/replica_0/task_0/cpu_1"),
+ 0);
+ EXPECT_GT(ops_executed.count(
+ "Recv_Send_0_on_/job_localhost/replica_0/task_0/cpu_1"),
+ 0);
+ EXPECT_GT(ops_executed.count("Recv"), 0);
+}
} // end namespace grappler
} // end namespace tensorflow
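
A short, hedged sketch of the device rewrite performed by GraphWithSendRecvDifferentDevice above, assuming a TensorFlow build environment; MoveRecvToDevice is an illustrative helper name, not a grappler API. The design point is that recv_device must be updated on both nodes: the _Recv's placement and the _Send's notion of where its tensor goes have to agree.

#include <string>
#include "tensorflow/core/framework/graph.pb.h"

void MoveRecvToDevice(tensorflow::GraphDef* graph, const std::string& device) {
  for (int i = 0; i < graph->node_size(); ++i) {
    auto* node = graph->mutable_node(i);
    if (node->name() == "Recv") {
      // The receiver moves to the new device...
      node->set_device(device);
      (*node->mutable_attr())["recv_device"].set_s(device);
    } else if (node->name() == "Send") {
      // ...and the sender must agree on where the tensor is going.
      (*node->mutable_attr())["recv_device"].set_s(device);
    }
  }
}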
diff --git a/tensorflow/core/grappler/grappler_item.cc b/tensorflow/core/grappler/grappler_item.cc
index 94412eb198..844a1fa328 100644
--- a/tensorflow/core/grappler/grappler_item.cc
+++ b/tensorflow/core/grappler/grappler_item.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <unordered_set>
#include <vector>
+#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
@@ -117,8 +118,13 @@ std::vector<const NodeDef*> ComputeTransitiveFanin(
bool* ill_formed) {
*ill_formed = false;
std::unordered_map<string, const NodeDef*> name_to_node;
+ std::unordered_map<string, const NodeDef*> name_to_send;
for (const auto& node : graph.node()) {
name_to_node[node.name()] = &node;
+ if (node.op() == "_Send") {
+ const auto& attr = node.attr();
+ name_to_send[attr.at("tensor_name").s()] = &node;
+ }
}
std::vector<const NodeDef*> queue;
@@ -150,6 +156,15 @@ std::vector<const NodeDef*> ComputeTransitiveFanin(
}
queue.push_back(in);
}
+ if (node->op() == "_Recv") {
+ const auto& attr = node->attr();
+ const NodeDef* send = name_to_send[attr.at("tensor_name").s()];
+ if (send) {
+ queue.push_back(send);
+ }
+      // A partitioned subgraph may contain either the _Send or the _Recv
+      // side of a pair, not both, so we do not set ill_formed for a
+      // missing _Send.
+ }
}
return result;
}
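
Finally, a standalone sketch of the fanin rule added above, with illustrative types rather than the grappler ones: a plain depth-first walk over declared inputs, plus one extra edge from each _Recv to the _Send carrying the same tensor_name when it exists.

#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

struct Node {
  std::string name, op, tensor_name;
  std::vector<std::string> inputs;
};

std::vector<const Node*> TransitiveFanin(const std::vector<Node>& graph,
                                         const std::string& fetch) {
  std::unordered_map<std::string, const Node*> name_to_node, name_to_send;
  for (const auto& n : graph) {
    name_to_node[n.name] = &n;
    if (n.op == "_Send") name_to_send[n.tensor_name] = &n;
  }
  std::vector<const Node*> result, queue;
  std::unordered_set<const Node*> visited;
  auto fetch_it = name_to_node.find(fetch);
  if (fetch_it != name_to_node.end()) queue.push_back(fetch_it->second);
  while (!queue.empty()) {
    const Node* node = queue.back();
    queue.pop_back();
    if (!visited.insert(node).second) continue;  // already expanded
    result.push_back(node);
    for (const auto& in : node->inputs) {
      auto it = name_to_node.find(in);
      if (it != name_to_node.end()) queue.push_back(it->second);
    }
    if (node->op == "_Recv") {
      // A partitioned subgraph may hold the _Recv side only, so a missing
      // _Send is tolerated here rather than flagged as ill-formed.
      auto it = name_to_send.find(node->tensor_name);
      if (it != name_to_send.end()) queue.push_back(it->second);
    }
  }
  return result;
}

With the Const/Send/Recv graph from the test textproto, TransitiveFanin(graph, "Recv") returns all three nodes; without the _Recv-to-_Send edge it would return only the Recv node.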