diff options
-rw-r--r-- | tensorflow/core/grappler/costs/virtual_scheduler.cc | 48 |
-rw-r--r-- | tensorflow/core/grappler/costs/virtual_scheduler.h | 3 |
-rw-r--r-- | tensorflow/core/grappler/costs/virtual_scheduler_test.cc | 132 |
3 files changed, 172 insertions, 11 deletions
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc index 15ebef188f..8b51bb9096 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc @@ -56,6 +56,28 @@ Costs CombineCosts(const Costs& left, const Costs& right) { << " max_per_op_streaming=" << result.max_per_op_streaming; return result; } + +// Key to the cached _Recv ops map, and its hash and predicate structures. +struct RecvNodeDescriptor { + const NodeDef* node; + const int port_num; + const string& device; +}; + +struct RecvNodeDescritorHash { + std::size_t operator()(const RecvNodeDescriptor& recv_node) const { + return std::hash<const NodeDef*>()(recv_node.node) ^ + std::hash<int>()(recv_node.port_num) ^ + std::hash<string>()(recv_node.device); + } +}; + +struct RecvNodeDescriptorEqual { + bool operator()(const RecvNodeDescriptor& a, + const RecvNodeDescriptor& b) const { + return a.node == b.node && a.port_num == b.port_num && a.device == b.device; + } +}; } // namespace VirtualScheduler::VirtualScheduler(const GrapplerItem* grappler_item, @@ -109,6 +131,11 @@ Status VirtualScheduler::Init() { name_to_node[node->name()] = node; } + // To reuse _Recv ops. + std::unordered_map<RecvNodeDescriptor, const NodeDef*, RecvNodeDescritorHash, + RecvNodeDescriptorEqual> + cached_recv_nodes; + // Build node_map; for each node, create its NodeState and connect its inputs // and outputs. 
for (const auto* curr_node : nodes) { @@ -131,12 +158,14 @@ Status VirtualScheduler::Init() { auto& input_node_state = GetNodeStateOrCreateIt(input_node); input_node_state.outputs[input_node_port_num].push_back(curr_node); } else { - if (cached_recv_nodes_.count(input_node) > 0 && - cached_recv_nodes_[input_node].count(curr_node_device) > 0) { + RecvNodeDescriptor recv_node = {.node = input_node, + .port_num = input_node_port_num, + .device = curr_node_device}; + auto it = cached_recv_nodes.find(recv_node); + if (it != cached_recv_nodes.end()) { // Different device, but found an already-cached copy (a _Recv op); // connect the _Recv to curr_node. - const auto* recv_op = - cached_recv_nodes_[input_node][curr_node_device]; + const NodeDef* recv_op = it->second; // recv_op's output port is hard-coded to zero. curr_node_state.inputs.push_back(std::make_pair(recv_op, 0)); auto& input_node_state = node_map_.at(recv_op); @@ -156,7 +185,7 @@ Status VirtualScheduler::Init() { input_node_state.outputs[input_node_port_num].push_back(send); // Cache the _Recv op for future use. - cached_recv_nodes_[input_node][curr_node_device] = recv; + cached_recv_nodes[recv_node] = recv; } } } @@ -269,11 +298,13 @@ std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv( // input names, attrs, etc. auto input_node_port_num = NodePosition(input_name); + const string port_num_string = + input_node_port_num >= 0 ? std::to_string(input_node_port_num) : "minus1"; // _Send op. auto* send = new NodeDef(); - send->set_name("Send " + from->name() + " from " + DeviceName(from) + " to " + - DeviceName(to)); + send->set_name("Send " + from->name() + ":" + port_num_string + " from " + + DeviceName(from) + " to " + DeviceName(to)); send->set_op("_Send"); send->add_input(from->name()); send->set_device(ChannelDeviceName(from, to)); @@ -284,7 +315,8 @@ std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv( // _Recv op. 
auto* recv = new NodeDef(); - recv->set_name("Recv " + from->name() + " on " + DeviceName(to)); + recv->set_name("Recv " + from->name() + ":" + port_num_string + " on " + + DeviceName(to)); recv->set_op("_Recv"); recv->add_input(send->name()); recv->set_device(DeviceName(to)); diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.h b/tensorflow/core/grappler/costs/virtual_scheduler.h index d00766d9fa..e9abecb122 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.h +++ b/tensorflow/core/grappler/costs/virtual_scheduler.h @@ -254,9 +254,6 @@ class VirtualScheduler { // Pool of NodeDefs for SendRecv and Identity ops created. std::vector<std::unique_ptr<NodeDef>> additional_nodes_; - // Cache of nodes transferred to another device. - std::unordered_map<const NodeDef*, std::unordered_map<string, const NodeDef*>> - cached_recv_nodes_; // Stats: std::map<string, int> op_counts_; // Op counts with key with input shape. diff --git a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc index 9743db33db..29e3db1b74 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc @@ -36,6 +36,7 @@ class TestVirtualScheduler : public VirtualScheduler { FRIEND_TEST(VirtualSchedulerTest, ControlDependency); FRIEND_TEST(VirtualSchedulerTest, ComplexDependency); FRIEND_TEST(VirtualSchedulerTest, Variable); + FRIEND_TEST(VirtualSchedulerTest, InterDeviceTransfer); }; class VirtualSchedulerTest : public ::testing::Test { @@ -43,6 +44,7 @@ class VirtualSchedulerTest : public ::testing::Test { NodeDef node1_, node2_, node3_, node4_, node5_, node6_; const string kCPU0 = "/job:localhost/replica:0/task:0/cpu:0"; + const string kCPU1 = "/job:localhost/replica:0/task:0/cpu:1"; DeviceProperties GetDummyCPUDevice() { // Create CPU with 2 cores, 4 Ghz freq, 2 GB/s mem bandwidth. 
@@ -74,6 +76,7 @@ class VirtualSchedulerTest : public ::testing::Test { // IMPORTANT: Device is not actually ever used in the test case since // force_cpu_type is defaulted to "Haswell" devices[kCPU0] = cpu_device; + devices[kCPU1] = cpu_device; cluster_.reset(new VirtualCluster(devices)); placer_.reset(new VirtualPlacer(cluster_.get())); } @@ -255,6 +258,64 @@ class VirtualSchedulerTest : public ::testing::Test { dependency_["z4"] = {"bn"}; } + // Create a fused bathcnorm on one device, and send the outputs to another + // device. + void CreateGrapplerItemWithInterDeviceTransfers() { + tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0); + + // Create a FusedBatchNorm op that has multiple output ports. + auto x = tensorflow::ops::RandomUniform( + s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT); + auto scale = tensorflow::ops::RandomUniform(s.WithOpName("scale"), + {depth_in_}, DT_FLOAT); + auto offset = tensorflow::ops::RandomUniform(s.WithOpName("offset"), + {depth_in_}, DT_FLOAT); + auto mean = + tensorflow::ops::RandomUniform(s.WithOpName("mean"), {0}, DT_FLOAT); + auto var = + tensorflow::ops::RandomUniform(s.WithOpName("var"), {0}, DT_FLOAT); + + auto batch_norm = tensorflow::ops::FusedBatchNorm( + s.WithOpName("bn"), x, scale, offset, mean, var, + ops::FusedBatchNorm::IsTraining(true).Epsilon(0.1f)); + auto y = batch_norm.y; + auto batch_mean = batch_norm.batch_mean; + auto batch_var = batch_norm.batch_variance; + + // Copy FusedBatchNorm's outputs to CPU1. + // y1 and y2 take the same tensor, so there should be only 1 Send and Recv. + auto y1 = + tensorflow::ops::Identity(s.WithOpName("y1").WithDevice(kCPU1), y); + auto y2 = + tensorflow::ops::Identity(s.WithOpName("y2").WithDevice(kCPU1), y); + // batch_mean1 and batch_var1 take different output ports, so each will + // initiate Send/Recv. 
+ auto batch_mean1 = tensorflow::ops::Identity( + s.WithOpName("batch_mean1").WithDevice(kCPU1), batch_mean); + auto batch_var1 = tensorflow::ops::Identity( + s.WithOpName("batch_var1").WithDevice(kCPU1), batch_var); + // This is control dependency. + auto control_dep = tensorflow::ops::NoOp(s.WithOpName("control_dep") + .WithControlDependencies(y) + .WithDevice(kCPU1)); + + GraphDef def; + TF_CHECK_OK(s.ToGraphDef(&def)); + + grappler_item_.reset(new GrapplerItem); + grappler_item_->id = "test_conv2d_graph"; + grappler_item_->graph = def; + grappler_item_->fetch = {"y1", "y2", "batch_mean1", "batch_var1", + "control_dep"}; + + dependency_["bn"] = {"x", "mean", "var"}; + dependency_["y1"] = {"bn"}; + dependency_["y2"] = {"bn"}; + dependency_["batch_mean1"] = {"bn"}; + dependency_["batch_var1"] = {"bn"}; + dependency_["control_dep"] = {"bn"}; + } + // Call this after creating grappler_item_ and setting up dependency_. void InitScheduler() { scheduler_.reset(new TestVirtualScheduler( @@ -803,5 +864,76 @@ TEST_F(VirtualSchedulerTest, Variable) { ValidateMemoryUsageSnapshot({"x"}, 0 /* port_num_expected */, cpu_state.mem_usage_snapshot_at_peak); } + +TEST_F(VirtualSchedulerTest, InterDeviceTransfer) { + // Init. + CreateGrapplerItemWithInterDeviceTransfers(); + InitScheduler(); + + // Run the scheduler. + auto ops_executed = RunScheduler(""); + + // Helper lambda to extract port num from _Send and _Recv op name. + auto get_port_num = [](const string& name) -> int { + if (name.find("bn:0") != std::string::npos) { + return 0; + } else if (name.find("bn:1") != std::string::npos) { + return 1; + } else if (name.find("bn:2") != std::string::npos) { + return 2; + } else if (name.find("bn:minus1") != std::string::npos) { + return -1; + } + return -999; + }; + + // Reorganize ops_executed for further testing. 
+ std::unordered_map<string, int> op_count; + std::unordered_map<int, string> recv_op_names; + std::unordered_map<int, string> send_op_names; + for (const auto& x : ops_executed) { + const auto& name = x.first; + const auto& node_info = x.second; + const auto& op = node_info.op_info.op(); + if (op == "_Recv") { + recv_op_names[get_port_num(name)] = name; + } else if (op == "_Send") { + send_op_names[get_port_num(name)] = name; + } + op_count[op]++; + } + + // Same number of _Send and _Recv. + EXPECT_EQ(op_count.at("_Send"), op_count.at("_Recv")); + + // Expect 4 Send and Recvs each: port 0, 1, and, 2, and control dependency. + EXPECT_EQ(op_count.at("_Recv"), 4); + EXPECT_EQ(op_count.at("_Send"), 4); + + // Helper lambda for extracting output Tensor size. + auto get_output_size = [this, ops_executed](const string& name) -> int64 { + const auto& output_properties_ = ops_executed.at(name).op_info.outputs(); + std::vector<OpInfo::TensorProperties> output_properties; + for (const auto& output_property : output_properties_) { + output_properties.push_back(output_property); + } + return scheduler_->CalculateOutputSize(output_properties, 0); + + }; + + // Validate transfer size. + // Batchnorm output y is 4D vector: batch x width x width x depth. + int input_size = 4 * batch_size_ * width_ * height_ * depth_in_; + EXPECT_EQ(get_output_size(recv_op_names[0]), input_size); + EXPECT_EQ(get_output_size(send_op_names[0]), input_size); + // Mean and vars are 1-D vector with size depth_in_. + EXPECT_EQ(get_output_size(recv_op_names[1]), 4 * depth_in_); + EXPECT_EQ(get_output_size(send_op_names[1]), 4 * depth_in_); + EXPECT_EQ(get_output_size(recv_op_names[2]), 4 * depth_in_); + EXPECT_EQ(get_output_size(send_op_names[2]), 4 * depth_in_); + // Control dependency size is 4B. + EXPECT_EQ(get_output_size(recv_op_names[-1]), 4); + EXPECT_EQ(get_output_size(send_op_names[-1]), 4); +} } // end namespace grappler } // end namespace tensorflow |