aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/grappler/costs/virtual_scheduler.cc48
-rw-r--r--tensorflow/core/grappler/costs/virtual_scheduler.h3
-rw-r--r--tensorflow/core/grappler/costs/virtual_scheduler_test.cc132
3 files changed, 172 insertions, 11 deletions
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc
index 15ebef188f..8b51bb9096 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc
@@ -56,6 +56,28 @@ Costs CombineCosts(const Costs& left, const Costs& right) {
<< " max_per_op_streaming=" << result.max_per_op_streaming;
return result;
}
+
+// Key to the cached _Recv ops map, and its hash and predicate structures.
+struct RecvNodeDescriptor {
+ const NodeDef* node;
+ const int port_num;
+ const string& device;
+};
+
+struct RecvNodeDescritorHash {
+ std::size_t operator()(const RecvNodeDescriptor& recv_node) const {
+ return std::hash<const NodeDef*>()(recv_node.node) ^
+ std::hash<int>()(recv_node.port_num) ^
+ std::hash<string>()(recv_node.device);
+ }
+};
+
+struct RecvNodeDescriptorEqual {
+ bool operator()(const RecvNodeDescriptor& a,
+ const RecvNodeDescriptor& b) const {
+ return a.node == b.node && a.port_num == b.port_num && a.device == b.device;
+ }
+};
} // namespace
VirtualScheduler::VirtualScheduler(const GrapplerItem* grappler_item,
@@ -109,6 +131,11 @@ Status VirtualScheduler::Init() {
name_to_node[node->name()] = node;
}
+ // To reuse _Recv ops.
+ std::unordered_map<RecvNodeDescriptor, const NodeDef*, RecvNodeDescritorHash,
+ RecvNodeDescriptorEqual>
+ cached_recv_nodes;
+
// Build node_map; for each node, create its NodeState and connect its inputs
// and outputs.
for (const auto* curr_node : nodes) {
@@ -131,12 +158,14 @@ Status VirtualScheduler::Init() {
auto& input_node_state = GetNodeStateOrCreateIt(input_node);
input_node_state.outputs[input_node_port_num].push_back(curr_node);
} else {
- if (cached_recv_nodes_.count(input_node) > 0 &&
- cached_recv_nodes_[input_node].count(curr_node_device) > 0) {
+ RecvNodeDescriptor recv_node = {.node = input_node,
+ .port_num = input_node_port_num,
+ .device = curr_node_device};
+ auto it = cached_recv_nodes.find(recv_node);
+ if (it != cached_recv_nodes.end()) {
// Different device, but found an already-cached copy (a _Recv op);
// connect the _Recv to curr_node.
- const auto* recv_op =
- cached_recv_nodes_[input_node][curr_node_device];
+ const NodeDef* recv_op = it->second;
// recv_op's output port is hard-coded to zero.
curr_node_state.inputs.push_back(std::make_pair(recv_op, 0));
auto& input_node_state = node_map_.at(recv_op);
@@ -156,7 +185,7 @@ Status VirtualScheduler::Init() {
input_node_state.outputs[input_node_port_num].push_back(send);
// Cache the _Recv op for future use.
- cached_recv_nodes_[input_node][curr_node_device] = recv;
+ cached_recv_nodes[recv_node] = recv;
}
}
}
@@ -269,11 +298,13 @@ std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv(
// input names, attrs, etc.
auto input_node_port_num = NodePosition(input_name);
+ const string port_num_string =
+ input_node_port_num >= 0 ? std::to_string(input_node_port_num) : "minus1";
// _Send op.
auto* send = new NodeDef();
- send->set_name("Send " + from->name() + " from " + DeviceName(from) + " to " +
- DeviceName(to));
+ send->set_name("Send " + from->name() + ":" + port_num_string + " from " +
+ DeviceName(from) + " to " + DeviceName(to));
send->set_op("_Send");
send->add_input(from->name());
send->set_device(ChannelDeviceName(from, to));
@@ -284,7 +315,8 @@ std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv(
// _Recv op.
auto* recv = new NodeDef();
- recv->set_name("Recv " + from->name() + " on " + DeviceName(to));
+ recv->set_name("Recv " + from->name() + ":" + port_num_string + " on " +
+ DeviceName(to));
recv->set_op("_Recv");
recv->add_input(send->name());
recv->set_device(DeviceName(to));
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.h b/tensorflow/core/grappler/costs/virtual_scheduler.h
index d00766d9fa..e9abecb122 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler.h
+++ b/tensorflow/core/grappler/costs/virtual_scheduler.h
@@ -254,9 +254,6 @@ class VirtualScheduler {
// Pool of NodeDefs for SendRecv and Identity ops created.
std::vector<std::unique_ptr<NodeDef>> additional_nodes_;
- // Cache of nodes transferred to another device.
- std::unordered_map<const NodeDef*, std::unordered_map<string, const NodeDef*>>
- cached_recv_nodes_;
// Stats:
std::map<string, int> op_counts_; // Op counts with key with input shape.
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
index 9743db33db..29e3db1b74 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
@@ -36,6 +36,7 @@ class TestVirtualScheduler : public VirtualScheduler {
FRIEND_TEST(VirtualSchedulerTest, ControlDependency);
FRIEND_TEST(VirtualSchedulerTest, ComplexDependency);
FRIEND_TEST(VirtualSchedulerTest, Variable);
+ FRIEND_TEST(VirtualSchedulerTest, InterDeviceTransfer);
};
class VirtualSchedulerTest : public ::testing::Test {
@@ -43,6 +44,7 @@ class VirtualSchedulerTest : public ::testing::Test {
NodeDef node1_, node2_, node3_, node4_, node5_, node6_;
const string kCPU0 = "/job:localhost/replica:0/task:0/cpu:0";
+ const string kCPU1 = "/job:localhost/replica:0/task:0/cpu:1";
DeviceProperties GetDummyCPUDevice() {
// Create CPU with 2 cores, 4 Ghz freq, 2 GB/s mem bandwidth.
@@ -74,6 +76,7 @@ class VirtualSchedulerTest : public ::testing::Test {
// IMPORTANT: Device is not actually ever used in the test case since
// force_cpu_type is defaulted to "Haswell"
devices[kCPU0] = cpu_device;
+ devices[kCPU1] = cpu_device;
cluster_.reset(new VirtualCluster(devices));
placer_.reset(new VirtualPlacer(cluster_.get()));
}
@@ -255,6 +258,64 @@ class VirtualSchedulerTest : public ::testing::Test {
dependency_["z4"] = {"bn"};
}
+ // Create a fused batchnorm on one device, and send the outputs to another
+ // device.
+ void CreateGrapplerItemWithInterDeviceTransfers() {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0);
+
+ // Create a FusedBatchNorm op that has multiple output ports.
+ auto x = tensorflow::ops::RandomUniform(
+ s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
+ auto scale = tensorflow::ops::RandomUniform(s.WithOpName("scale"),
+ {depth_in_}, DT_FLOAT);
+ auto offset = tensorflow::ops::RandomUniform(s.WithOpName("offset"),
+ {depth_in_}, DT_FLOAT);
+ auto mean =
+ tensorflow::ops::RandomUniform(s.WithOpName("mean"), {0}, DT_FLOAT);
+ auto var =
+ tensorflow::ops::RandomUniform(s.WithOpName("var"), {0}, DT_FLOAT);
+
+ auto batch_norm = tensorflow::ops::FusedBatchNorm(
+ s.WithOpName("bn"), x, scale, offset, mean, var,
+ ops::FusedBatchNorm::IsTraining(true).Epsilon(0.1f));
+ auto y = batch_norm.y;
+ auto batch_mean = batch_norm.batch_mean;
+ auto batch_var = batch_norm.batch_variance;
+
+ // Copy FusedBatchNorm's outputs to CPU1.
+ // y1 and y2 take the same tensor, so there should be only 1 Send and Recv.
+ auto y1 =
+ tensorflow::ops::Identity(s.WithOpName("y1").WithDevice(kCPU1), y);
+ auto y2 =
+ tensorflow::ops::Identity(s.WithOpName("y2").WithDevice(kCPU1), y);
+ // batch_mean1 and batch_var1 take different output ports, so each will
+ // initiate Send/Recv.
+ auto batch_mean1 = tensorflow::ops::Identity(
+ s.WithOpName("batch_mean1").WithDevice(kCPU1), batch_mean);
+ auto batch_var1 = tensorflow::ops::Identity(
+ s.WithOpName("batch_var1").WithDevice(kCPU1), batch_var);
+ // This is a control dependency.
+ auto control_dep = tensorflow::ops::NoOp(s.WithOpName("control_dep")
+ .WithControlDependencies(y)
+ .WithDevice(kCPU1));
+
+ GraphDef def;
+ TF_CHECK_OK(s.ToGraphDef(&def));
+
+ grappler_item_.reset(new GrapplerItem);
+ grappler_item_->id = "test_conv2d_graph";
+ grappler_item_->graph = def;
+ grappler_item_->fetch = {"y1", "y2", "batch_mean1", "batch_var1",
+ "control_dep"};
+
+ dependency_["bn"] = {"x", "mean", "var"};
+ dependency_["y1"] = {"bn"};
+ dependency_["y2"] = {"bn"};
+ dependency_["batch_mean1"] = {"bn"};
+ dependency_["batch_var1"] = {"bn"};
+ dependency_["control_dep"] = {"bn"};
+ }
+
// Call this after creating grappler_item_ and setting up dependency_.
void InitScheduler() {
scheduler_.reset(new TestVirtualScheduler(
@@ -803,5 +864,76 @@ TEST_F(VirtualSchedulerTest, Variable) {
ValidateMemoryUsageSnapshot({"x"}, 0 /* port_num_expected */,
cpu_state.mem_usage_snapshot_at_peak);
}
+
+TEST_F(VirtualSchedulerTest, InterDeviceTransfer) {
+ // Init.
+ CreateGrapplerItemWithInterDeviceTransfers();
+ InitScheduler();
+
+ // Run the scheduler.
+ auto ops_executed = RunScheduler("");
+
+ // Helper lambda to extract port num from _Send and _Recv op name.
+ auto get_port_num = [](const string& name) -> int {
+ if (name.find("bn:0") != std::string::npos) {
+ return 0;
+ } else if (name.find("bn:1") != std::string::npos) {
+ return 1;
+ } else if (name.find("bn:2") != std::string::npos) {
+ return 2;
+ } else if (name.find("bn:minus1") != std::string::npos) {
+ return -1;
+ }
+ return -999;
+ };
+
+ // Reorganize ops_executed for further testing.
+ std::unordered_map<string, int> op_count;
+ std::unordered_map<int, string> recv_op_names;
+ std::unordered_map<int, string> send_op_names;
+ for (const auto& x : ops_executed) {
+ const auto& name = x.first;
+ const auto& node_info = x.second;
+ const auto& op = node_info.op_info.op();
+ if (op == "_Recv") {
+ recv_op_names[get_port_num(name)] = name;
+ } else if (op == "_Send") {
+ send_op_names[get_port_num(name)] = name;
+ }
+ op_count[op]++;
+ }
+
+ // Same number of _Send and _Recv.
+ EXPECT_EQ(op_count.at("_Send"), op_count.at("_Recv"));
+
+ // Expect 4 Sends and Recvs each: ports 0, 1, and 2, and control dependency.
+ EXPECT_EQ(op_count.at("_Recv"), 4);
+ EXPECT_EQ(op_count.at("_Send"), 4);
+
+ // Helper lambda for extracting output Tensor size.
+ auto get_output_size = [this, ops_executed](const string& name) -> int64 {
+ const auto& output_properties_ = ops_executed.at(name).op_info.outputs();
+ std::vector<OpInfo::TensorProperties> output_properties;
+ for (const auto& output_property : output_properties_) {
+ output_properties.push_back(output_property);
+ }
+ return scheduler_->CalculateOutputSize(output_properties, 0);
+
+ };
+
+ // Validate transfer size.
+ // Batchnorm output y is 4D vector: batch x width x width x depth.
+ int input_size = 4 * batch_size_ * width_ * height_ * depth_in_;
+ EXPECT_EQ(get_output_size(recv_op_names[0]), input_size);
+ EXPECT_EQ(get_output_size(send_op_names[0]), input_size);
+ // Mean and vars are 1-D vector with size depth_in_.
+ EXPECT_EQ(get_output_size(recv_op_names[1]), 4 * depth_in_);
+ EXPECT_EQ(get_output_size(send_op_names[1]), 4 * depth_in_);
+ EXPECT_EQ(get_output_size(recv_op_names[2]), 4 * depth_in_);
+ EXPECT_EQ(get_output_size(send_op_names[2]), 4 * depth_in_);
+ // Control dependency size is 4B.
+ EXPECT_EQ(get_output_size(recv_op_names[-1]), 4);
+ EXPECT_EQ(get_output_size(send_op_names[-1]), 4);
+}
} // end namespace grappler
} // end namespace tensorflow