From 5f308cb408eb46ec9af0546be6b9ae1d5166b185 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 8 Oct 2018 09:06:04 -0700 Subject: Optimize PinToHostOptimizer by adding cache, also add PinToHostOptimizer to benchmarks. original runtime: 4.83492736816 secs w/ cache runtime: 2.19033999443 secs PiperOrigin-RevId: 216195286 --- tensorflow/core/grappler/op_types.cc | 22 +-- .../grappler/optimizers/pin_to_host_optimizer.cc | 162 ++++++++++++++------- .../grappler/optimizers/pin_to_host_optimizer.h | 4 +- .../optimizers/pin_to_host_optimizer_test.cc | 76 +++++++--- 4 files changed, 179 insertions(+), 85 deletions(-) (limited to 'tensorflow/core') diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index 1b5a215987..cbf5c8e038 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -102,15 +102,19 @@ bool IsConjugateTranspose(const NodeDef& node) { } bool IsControlFlow(const NodeDef& node) { - // clang-format off - return node.op() == "ControlTrigger" || - node.op() == "Enter" || - node.op() == "Exit" || - node.op() == "LoopCond" || - node.op() == "Merge" || - node.op() == "NextIteration" || - node.op() == "Switch"; - // clang-format on + // TODO(williamchan): Add a microbenchmark to compare FlatSet vs. iterative + // string comparison. + static const gtl::FlatSet* const kControFlowOps = + CHECK_NOTNULL((new gtl::FlatSet{ + "ControlTrigger", + "Enter", + "Exit", + "LoopCond", + "Merge", + "NextIteration", + "Switch", + })); + return kControFlowOps->count(node.op()) > 0; } bool IsConv2D(const NodeDef& node) { return node.op() == "Conv2D"; } diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc index 8ed4271fa4..29a3b2b74c 100644 --- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc @@ -25,16 +25,29 @@ limitations under the License. #include "tensorflow/core/grappler/utils/topological_sort.h" #include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/str_util.h" namespace tensorflow { namespace grappler { + namespace internal { +namespace { // TODO(williamchan): Change this constant to be something smarter, maybe // dynamically determined. constexpr int64 kTensorMaxSize = 64; +struct OpDevicePortHasher { + std::size_t operator()(const std::tuple& x) const { + uint64 code = Hash64Combine(Hash64(std::get<0>(x)), Hash64(std::get<1>(x))); + + return Hash64Combine(code, hash()(std::get<2>(x))); + } +}; +using OpDevicePortOnHostMap = + gtl::FlatMap, bool, OpDevicePortHasher>; + // All the nodes that should be blacklisted and not swapped. bool IsBlacklisted(const NodeDef& node) { return @@ -82,10 +95,10 @@ Status TryFindKernelDef(const std::vector& devices, // Checks if a node's output port is host friendly. // Roughly this means checking if the output port is on Host memory. -Status IsNodeOutputPortHostFriendly(const GraphView& graph, - GraphProperties* properties, - const NodeDef& node, int port_id, - bool* is_candidate) { +Status IsNodeOutputPortHostFriendly( + const GraphView& graph, GraphProperties* properties, const NodeDef& node, + int port_id, OpDevicePortOnHostMap* op_device_outport_pinned_to_host_cache, + bool* is_candidate) { *is_candidate = false; // Make sure we are not a blacklisted op. 
@@ -117,7 +130,8 @@ Status IsNodeOutputPortHostFriendly(const GraphView& graph, for (const auto& fanin : graph.GetFanins(node, false)) { bool fanin_candidate = false; TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly( - graph, properties, *fanin.node, fanin.port_id, &fanin_candidate)); + graph, properties, *fanin.node, fanin.port_id, + op_device_outport_pinned_to_host_cache, &fanin_candidate)); if (!fanin_candidate) { return Status::OK(); } @@ -132,11 +146,22 @@ Status IsNodeOutputPortHostFriendly(const GraphView& graph, return Status::OK(); } + // Check `op_device_outport_pinned_to_host_cache` for our + // {op, device, port_id} combo to see if the arg is pinned on Host. + const std::tuple cache_key(node.op(), node.device(), + port_id); + auto it = op_device_outport_pinned_to_host_cache->find(cache_key); + if (it != op_device_outport_pinned_to_host_cache->end()) { + *is_candidate = it->second; + return Status::OK(); + } + // Check if op's output port is pinned to HostMemory. const OpDef* op = nullptr; Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op); if (!s.ok()) { LOG(WARNING) << "Could not find OpDef for : " << node.op(); + op_device_outport_pinned_to_host_cache->emplace(cache_key, false); return Status::OK(); } @@ -146,6 +171,7 @@ Status IsNodeOutputPortHostFriendly(const GraphView& graph, LOG(WARNING) << "Invalid port: " << port_id << "!\n" << node.DebugString() << "\n" << op->DebugString(); + op_device_outport_pinned_to_host_cache->emplace(cache_key, false); return Status::OK(); } @@ -155,6 +181,7 @@ Status IsNodeOutputPortHostFriendly(const GraphView& graph, &kernel); if (!s.ok()) { LOG(INFO) << "Could not find KernelDef for: " << node.op(); + op_device_outport_pinned_to_host_cache->emplace(cache_key, false); return Status::OK(); } @@ -166,22 +193,35 @@ Status IsNodeOutputPortHostFriendly(const GraphView& graph, } } + op_device_outport_pinned_to_host_cache->emplace(cache_key, *is_candidate); + return Status::OK(); } // Checks if a node's input port is Host friendly. // Roughly this means checking if the input port is on Host memory. -bool IsNodeInputPortHostFriendly(const NodeDef& node, int port_id) { +bool IsNodeInputPortHostFriendly( + const NodeDef& node, int port_id, + OpDevicePortOnHostMap* op_device_inport_pinned_to_host_cache) { // If node is on Host, assume its inputs are Host friendly. if (str_util::StrContains(node.device(), DEVICE_CPU)) { return true; } + // Check `op_device_inport_pinned_to_host_cache` for our + // {op, device, port_id} combo to see if the arg is pinned on Host. + std::tuple cache_key(node.op(), node.device(), port_id); + auto it = op_device_inport_pinned_to_host_cache->find(cache_key); + if (it != op_device_inport_pinned_to_host_cache->end()) { + return it->second; + } + // Check if op's input port is pinned to HostMemory. const OpDef* op = nullptr; Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op); if (!s.ok()) { LOG(WARNING) << "Could not find OpDef for : " << node.op(); + op_device_inport_pinned_to_host_cache->emplace(cache_key, false); return false; } const int input_arg_id = OpInputPortIdToArgId(node, *op, port_id); @@ -192,16 +232,20 @@ bool IsNodeInputPortHostFriendly(const NodeDef& node, int port_id) { {node.device().c_str(), DEVICE_GPU, DEVICE_CPU}, node, &kernel); if (!s.ok()) { LOG(INFO) << "Could not find KernelDef for: " << node.op(); + op_device_inport_pinned_to_host_cache->emplace(cache_key, false); return false; } // Check if the input_arg is pinned to Host. 
for (const string& host_memory_arg : kernel->host_memory_arg()) { if (op->input_arg(input_arg_id).name() == host_memory_arg) { + op_device_inport_pinned_to_host_cache->emplace(cache_key, true); return true; } } + op_device_inport_pinned_to_host_cache->emplace(cache_key, false); + return false; } @@ -211,18 +255,20 @@ bool IsNodeInputPortHostFriendly(const NodeDef& node, int port_id) { // 2] Check if node can run on Host. // 3] Check all input/outputs are Host "friendly" (atm, friendly means small, // ints, and pinned to Host). -Status IsNodeHostCandidate(const GraphView& graph, GraphProperties* properties, - const NodeDef& node, bool* is_candidate) { +Status IsNodeHostCandidate( + const GraphView& graph, GraphProperties* properties, const NodeDef& node, + OpDevicePortOnHostMap* op_device_outport_pinned_to_host_cache, + bool* is_candidate) { *is_candidate = false; - // Check if node already on CPU. - if (str_util::StrContains(node.device(), DEVICE_CPU)) { - *is_candidate = true; + // Skip these node types. + if (IsBlacklisted(node)) { return Status::OK(); } - // Skip these node types. - if (IsBlacklisted(node)) { + // Check if node already on CPU. + if (str_util::StrContains(node.device(), DEVICE_CPU)) { + *is_candidate = true; return Status::OK(); } @@ -232,17 +278,6 @@ Status IsNodeHostCandidate(const GraphView& graph, GraphProperties* properties, return Status::OK(); } - // Check all inputs are Host friendly. - for (const GraphView::OutputPort& fanin : - graph.GetFanins(node, /*include_controlling_nodes=*/false)) { - bool fanin_candidate = false; - TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly( - graph, properties, *fanin.node, fanin.port_id, &fanin_candidate)); - if (!fanin_candidate) { - return Status::OK(); - } - } - // Check all outputs are Host friendly. if (!properties->has_properties()) { // This is an expensive call, call it lazily. @@ -255,16 +290,42 @@ Status IsNodeHostCandidate(const GraphView& graph, GraphProperties* properties, } } + // Check all inputs are Host friendly. + for (const GraphView::OutputPort& fanin : + graph.GetFanins(node, /*include_controlling_nodes=*/false)) { + bool fanin_candidate = false; + TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly( + graph, properties, *fanin.node, fanin.port_id, + op_device_outport_pinned_to_host_cache, &fanin_candidate)); + if (!fanin_candidate) { + return Status::OK(); + } + } + *is_candidate = true; return Status::OK(); } -string TryFindHostDevice(const gtl::FlatSet& devices, - bool has_device_cpu, const string& device) { +bool IsTPUGraphDef(const GraphDef& def) { + for (const auto& node : def.node()) { + if (node.op() == "TPUCompile" || node.op() == "TPUExecute" || + node.op() == "TPUPartitionedCall") { + return true; + } + } + return false; +} +} // end namespace + +// Tries to swap `device` to a Host device from `devices`. Returns true iff +// there was a swap. +bool TrySwapToHostDevice(const gtl::FlatSet& devices, + bool has_device_cpu, string* device) { // Force this node onto the CPU. - if (device.empty() && has_device_cpu) { - return "/device:CPU:0"; - } else if (str_util::StrContains(device, DEVICE_GPU)) { + if (device->empty() && has_device_cpu) { + *device = "/device:CPU:0"; + return true; + } else if (str_util::StrContains(*device, DEVICE_GPU)) { // Sometimes the cluster can have: // devices = {"/device:CPU:0", "/device:XLA_GPU:0"} // and we need to handle them properly. 
@@ -272,27 +333,19 @@ string TryFindHostDevice(const gtl::FlatSet& devices, {std::pair("GPU", "CPU:0"), std::pair("/device", "/device:CPU:0")}) { const string device_host = - strings::StrCat(device.substr(0, device.rfind(device_match.first)), + strings::StrCat(device->substr(0, device->rfind(device_match.first)), device_match.second); if (devices.find(device_host) != devices.end()) { - return device_host; + *device = device_host; + return true; } } } - // We couldn't find an appropriate Host device, return original device. - return device; -} - -bool IsTPUGraphDef(const GraphDef& def) { - for (const auto& node : def.node()) { - if (node.op() == "TPUCompile" || node.op() == "TPUExecute" || - node.op() == "TPUPartitionedCall") { - return true; - } - } + // We couldn't find an appropriate Host device, return false. return false; } + } // end namespace internal Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, @@ -324,20 +377,26 @@ Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, // All the Const nodes, and their original devices in topological order. std::vector> const_nodes; + // Cache to map {op, device, port} -> bool on whether it is pinned to host. + internal::OpDevicePortOnHostMap op_device_outport_pinned_to_host_cache; + internal::OpDevicePortOnHostMap op_device_inport_pinned_to_host_cache; + for (auto& node : *optimized_graph->mutable_node()) { bool is_candidate = false; - TF_RETURN_IF_ERROR( - internal::IsNodeHostCandidate(graph, &properties, node, &is_candidate)); + TF_RETURN_IF_ERROR(internal::IsNodeHostCandidate( + graph, &properties, node, &op_device_outport_pinned_to_host_cache, + &is_candidate)); if (!is_candidate) { continue; } - if (IsConstant(node)) { - const_nodes.emplace_back(&node, node.device()); + const string original_device = node.device(); + const bool swapped = internal::TrySwapToHostDevice(devices, has_device_cpu, + node.mutable_device()); + // Keep track of all Const nodes that we swapped. + if (swapped && IsConstant(node)) { + const_nodes.emplace_back(&node, original_device); } - // Try and swap the device to Host. - node.set_device( - internal::TryFindHostDevice(devices, has_device_cpu, node.device())); } // Traverse all `const_nodes`, and map them back to GPU greedily. @@ -349,8 +408,9 @@ Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, // this node back onto the original device. for (const GraphView::InputPort& fanout : graph.GetFanouts(*node, false)) { // The consumer is not Host friendly, swap it back to the original device. - if (!internal::IsNodeInputPortHostFriendly(*fanout.node, - fanout.port_id)) { + if (!internal::IsNodeInputPortHostFriendly( + *fanout.node, fanout.port_id, + &op_device_inport_pinned_to_host_cache)) { node->set_device(device); break; } diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h index d557a03463..bed4a9ef95 100644 --- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h +++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h @@ -26,8 +26,8 @@ namespace tensorflow { namespace grappler { namespace internal { // Try and find an appropriate Host device in `devices` given `device`. 
-string TryFindHostDevice(const gtl::FlatSet& devices, - bool has_device_cpu, const string& device); +bool TrySwapToHostDevice(const gtl::FlatSet& devices, + bool has_device_cpu, string* device); } // end namespace internal // Optimize TensorFlow ops that should be swapped into the CPU to avoid diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc index 7c64529441..9bb030b220 100644 --- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc @@ -28,30 +28,60 @@ namespace { class PinToHostOptimizerTest : public GrapplerTest {}; -TEST_F(PinToHostOptimizerTest, TryFindHostDevice) { +TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceNoDevices) { gtl::FlatSet devices = {}; - EXPECT_EQ("ABC", internal::TryFindHostDevice(devices, false, "ABC")); - - devices = {"/device:CPU:0", "/device:XLA_GPU:0"}; - EXPECT_EQ(internal::TryFindHostDevice(devices, true, ""), "/device:CPU:0"); - EXPECT_EQ(internal::TryFindHostDevice(devices, true, "/device:XLA_GPU:0"), - "/device:CPU:0"); - EXPECT_EQ(internal::TryFindHostDevice(devices, true, "/device:XLA_GPU:*"), - "/device:CPU:0"); - - devices = {"/device:XLA_CPU:0", "/device:XLA_GPU:0"}; - EXPECT_EQ(internal::TryFindHostDevice(devices, false, ""), ""); - EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"), - "/device:XLA_CPU:0"); - EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"), - "/device:XLA_CPU:0"); - - devices = {"/device:XLA_GPU:0"}; - EXPECT_EQ(internal::TryFindHostDevice(devices, false, ""), ""); - EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"), - "/device:XLA_GPU:0"); - EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"), - "/device:XLA_GPU:*"); + + string device = "ABC"; + EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device)); + EXPECT_EQ(device, "ABC"); +} + +TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceCpuXlaGpu) { + gtl::FlatSet devices = {"/device:CPU:0", "/device:XLA_GPU:0"}; + + string device = ""; + EXPECT_TRUE(internal::TrySwapToHostDevice(devices, true, &device)); + EXPECT_EQ(device, "/device:CPU:0"); + + device = "/device:XLA_GPU:0"; + EXPECT_TRUE(internal::TrySwapToHostDevice(devices, true, &device)); + EXPECT_EQ(device, "/device:CPU:0"); + + device = "/device:XLA_GPU:*"; + EXPECT_TRUE(internal::TrySwapToHostDevice(devices, true, &device)); + EXPECT_EQ(device, "/device:CPU:0"); +} + +TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceXlaCpuXlaGpu) { + gtl::FlatSet devices = {"/device:XLA_CPU:0", "/device:XLA_GPU:0"}; + + string device = ""; + EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device)); + EXPECT_TRUE(device.empty()); + + device = "/device:XLA_GPU:0"; + EXPECT_TRUE(internal::TrySwapToHostDevice(devices, false, &device)); + EXPECT_EQ(device, "/device:XLA_CPU:0"); + + device = "/device:XLA_GPU:*"; + EXPECT_TRUE(internal::TrySwapToHostDevice(devices, false, &device)); + EXPECT_EQ(device, "/device:XLA_CPU:0"); +} + +TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceXlaGpu) { + gtl::FlatSet devices = {"/device:XLA_GPU:0"}; + + string device = ""; + EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device)); + EXPECT_TRUE(device.empty()); + + device = "/device:XLA_GPU:0"; + EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device)); + EXPECT_EQ(device, "/device:XLA_GPU:0"); + + device = "/device:XLA_GPU:*"; + 
EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device));
+  EXPECT_EQ(device, "/device:XLA_GPU:*");
 }
 
 TEST_F(PinToHostOptimizerTest, OptimizeSmallOpsToHost) {
-- 
cgit v1.2.3
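
Illustrative note (not part of the patch): the runtime drop reported in the commit message comes from memoizing the per-{op, device, port} host-pinning decision, which otherwise requires an OpDef and KernelDef lookup every time a fanin or fanout is visited. Below is a minimal standalone C++ sketch of that caching pattern; it substitutes std::unordered_map and std::hash for TensorFlow's gtl::FlatMap and Hash64/Hash64Combine, and stubs out the expensive registry lookup with a placeholder function.

#include <cstddef>
#include <iostream>
#include <string>
#include <tuple>
#include <unordered_map>

// Cache key: {op name, device string, port id}, as in the patch.
using OpDevicePort = std::tuple<std::string, std::string, int>;

// Combines the hashes of the three fields, standing in for the patch's
// OpDevicePortHasher (which uses Hash64 and Hash64Combine).
struct OpDevicePortHasher {
  std::size_t operator()(const OpDevicePort& x) const {
    const std::size_t h1 = std::hash<std::string>()(std::get<0>(x));
    const std::size_t h2 = std::hash<std::string>()(std::get<1>(x));
    const std::size_t h3 = std::hash<int>()(std::get<2>(x));
    const std::size_t c =
        h1 ^ (h2 + 0x9e3779b97f4a7c15ULL + (h1 << 6) + (h1 >> 2));
    return c ^ (h3 + 0x9e3779b97f4a7c15ULL + (c << 6) + (c >> 2));
  }
};

using OpDevicePortOnHostMap =
    std::unordered_map<OpDevicePort, bool, OpDevicePortHasher>;

// Stand-in for the expensive OpRegistry/KernelDef lookup performed in
// IsNodeOutputPortHostFriendly / IsNodeInputPortHostFriendly.
bool ExpensivePinnedToHostLookup(const std::string& op,
                                 const std::string& device, int port) {
  std::cout << "expensive lookup: {" << op << ", " << device << ", " << port
            << "}\n";
  return op == "Shape";  // Arbitrary rule, for the sketch only.
}

// Memoized query: consult the cache first; compute and store on a miss.
bool IsPortPinnedToHost(const std::string& op, const std::string& device,
                        int port, OpDevicePortOnHostMap* cache) {
  const OpDevicePort key(op, device, port);
  const auto it = cache->find(key);
  if (it != cache->end()) {
    return it->second;  // Cache hit: no registry lookup needed.
  }
  const bool pinned = ExpensivePinnedToHostLookup(op, device, port);
  cache->emplace(key, pinned);
  return pinned;
}

int main() {
  OpDevicePortOnHostMap cache;
  // The second identical query is answered from the cache; repeated queries
  // like this are what the patch avoids recomputing.
  IsPortPinnedToHost("Shape", "/device:GPU:0", 0, &cache);
  IsPortPinnedToHost("Shape", "/device:GPU:0", 0, &cache);
  return 0;
}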
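
A second standalone sketch (also not part of the patch) shows the behavioural contract of the renamed TrySwapToHostDevice: it now mutates the device string in place and returns whether a swap happened, instead of returning a possibly unchanged device string. std::set stands in for gtl::FlatSet, and the rewrite rules mirror the ones in the hunk above.

#include <iostream>
#include <set>
#include <string>
#include <utility>

// Tries to rewrite *device to a host device that exists in `devices`.
// Returns true iff a swap happened (the patch's new bool contract).
bool TrySwapToHostDevice(const std::set<std::string>& devices,
                         bool has_device_cpu, std::string* device) {
  if (device->empty() && has_device_cpu) {
    *device = "/device:CPU:0";
    return true;
  }
  if (device->find("GPU") != std::string::npos) {
    // Try progressively more general rewrites, e.g.
    // "/device:XLA_GPU:0" -> "/device:XLA_CPU:0" -> "/device:CPU:0".
    const std::pair<std::string, std::string> matches[] = {
        {"GPU", "CPU:0"}, {"/device", "/device:CPU:0"}};
    for (const auto& m : matches) {
      const std::string host =
          device->substr(0, device->rfind(m.first)) + m.second;
      if (devices.count(host) > 0) {
        *device = host;
        return true;
      }
    }
  }
  return false;  // No suitable host device; *device is left unchanged.
}

int main() {
  const std::set<std::string> devices = {"/device:XLA_CPU:0",
                                         "/device:XLA_GPU:0"};
  std::string device = "/device:XLA_GPU:0";
  if (TrySwapToHostDevice(devices, /*has_device_cpu=*/false, &device)) {
    std::cout << "swapped to " << device << "\n";  // /device:XLA_CPU:0
  }
  return 0;
}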