From 129bb5e845ccb2ab6339e85d39545800dac6ca33 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Mon, 8 Oct 2018 23:42:02 -0700
Subject: Automated rollback of commit 5f308cb408eb46ec9af0546be6b9ae1d5166b185

PiperOrigin-RevId: 216309111
---
 tensorflow/core/grappler/op_types.cc               |  22 ++-
 .../grappler/optimizers/pin_to_host_optimizer.cc   | 162 +++++++--------------
 .../grappler/optimizers/pin_to_host_optimizer.h    |   4 +-
 .../optimizers/pin_to_host_optimizer_test.cc       |  76 +++-------
 4 files changed, 85 insertions(+), 179 deletions(-)

diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index cbf5c8e038..1b5a215987 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -102,19 +102,15 @@ bool IsConjugateTranspose(const NodeDef& node) {
 }
 
 bool IsControlFlow(const NodeDef& node) {
-  // TODO(williamchan): Add a microbenchmark to compare FlatSet vs. iterative
-  // string comparison.
-  static const gtl::FlatSet<string>* const kControFlowOps =
-      CHECK_NOTNULL((new gtl::FlatSet<string>{
-          "ControlTrigger",
-          "Enter",
-          "Exit",
-          "LoopCond",
-          "Merge",
-          "NextIteration",
-          "Switch",
-      }));
-  return kControFlowOps->count(node.op()) > 0;
+  // clang-format off
+  return node.op() == "ControlTrigger" ||
+         node.op() == "Enter" ||
+         node.op() == "Exit" ||
+         node.op() == "LoopCond" ||
+         node.op() == "Merge" ||
+         node.op() == "NextIteration" ||
+         node.op() == "Switch";
+  // clang-format on
 }
 
 bool IsConv2D(const NodeDef& node) { return node.op() == "Conv2D"; }
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
index 29a3b2b74c..8ed4271fa4 100644
--- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
@@ -25,29 +25,16 @@ limitations under the License.
 #include "tensorflow/core/grappler/utils/topological_sort.h"
 #include "tensorflow/core/lib/core/error_codes.pb.h"
 #include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/hash/hash.h"
 #include "tensorflow/core/lib/strings/str_util.h"
 
 namespace tensorflow {
 namespace grappler {
-
 namespace internal {
-namespace {
 
 // TODO(williamchan): Change this constant to be something smarter, maybe
 // dynamically determined.
 constexpr int64 kTensorMaxSize = 64;
 
-struct OpDevicePortHasher {
-  std::size_t operator()(const std::tuple<string, string, int>& x) const {
-    uint64 code = Hash64Combine(Hash64(std::get<0>(x)), Hash64(std::get<1>(x)));
-
-    return Hash64Combine(code, hash<int>()(std::get<2>(x)));
-  }
-};
-using OpDevicePortOnHostMap =
-    gtl::FlatMap<std::tuple<string, string, int>, bool, OpDevicePortHasher>;
-
 // All the nodes that should be blacklisted and not swapped.
 bool IsBlacklisted(const NodeDef& node) {
   return
@@ -95,10 +82,10 @@ Status TryFindKernelDef(const std::vector<string>& devices,
 
 // Checks if a node's output port is host friendly.
 // Roughly this means checking if the output port is on Host memory.
-Status IsNodeOutputPortHostFriendly(
-    const GraphView& graph, GraphProperties* properties, const NodeDef& node,
-    int port_id, OpDevicePortOnHostMap* op_device_outport_pinned_to_host_cache,
-    bool* is_candidate) {
+Status IsNodeOutputPortHostFriendly(const GraphView& graph,
+                                    GraphProperties* properties,
+                                    const NodeDef& node, int port_id,
+                                    bool* is_candidate) {
   *is_candidate = false;
 
   // Make sure we are not a blacklisted op.
@@ -130,8 +117,7 @@ Status IsNodeOutputPortHostFriendly(
     for (const auto& fanin : graph.GetFanins(node, false)) {
       bool fanin_candidate = false;
       TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly(
-          graph, properties, *fanin.node, fanin.port_id,
-          op_device_outport_pinned_to_host_cache, &fanin_candidate));
+          graph, properties, *fanin.node, fanin.port_id, &fanin_candidate));
       if (!fanin_candidate) {
         return Status::OK();
       }
@@ -146,22 +132,11 @@ Status IsNodeOutputPortHostFriendly(
     return Status::OK();
   }
 
-  // Check `op_device_outport_pinned_to_host_cache` for our
-  // {op, device, port_id} combo to see if the arg is pinned on Host.
-  const std::tuple<string, string, int> cache_key(node.op(), node.device(),
-                                                  port_id);
-  auto it = op_device_outport_pinned_to_host_cache->find(cache_key);
-  if (it != op_device_outport_pinned_to_host_cache->end()) {
-    *is_candidate = it->second;
-    return Status::OK();
-  }
-
   // Check if op's output port is pinned to HostMemory.
   const OpDef* op = nullptr;
   Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op);
   if (!s.ok()) {
     LOG(WARNING) << "Could not find OpDef for : " << node.op();
-    op_device_outport_pinned_to_host_cache->emplace(cache_key, false);
     return Status::OK();
   }
 
@@ -171,7 +146,6 @@ Status IsNodeOutputPortHostFriendly(
     LOG(WARNING) << "Invalid port: " << port_id << "!\n"
                  << node.DebugString() << "\n"
                  << op->DebugString();
-    op_device_outport_pinned_to_host_cache->emplace(cache_key, false);
     return Status::OK();
   }
 
@@ -181,7 +155,6 @@ Status IsNodeOutputPortHostFriendly(
                        &kernel);
   if (!s.ok()) {
     LOG(INFO) << "Could not find KernelDef for: " << node.op();
-    op_device_outport_pinned_to_host_cache->emplace(cache_key, false);
     return Status::OK();
   }
 
@@ -193,35 +166,22 @@ Status IsNodeOutputPortHostFriendly(
     }
   }
 
-  op_device_outport_pinned_to_host_cache->emplace(cache_key, *is_candidate);
-
   return Status::OK();
 }
 
 // Checks if a node's input port is Host friendly.
 // Roughly this means checking if the input port is on Host memory.
-bool IsNodeInputPortHostFriendly(
-    const NodeDef& node, int port_id,
-    OpDevicePortOnHostMap* op_device_inport_pinned_to_host_cache) {
+bool IsNodeInputPortHostFriendly(const NodeDef& node, int port_id) {
   // If node is on Host, assume its inputs are Host friendly.
   if (str_util::StrContains(node.device(), DEVICE_CPU)) {
     return true;
   }
 
-  // Check `op_device_inport_pinned_to_host_cache` for our
-  // {op, device, port_id} combo to see if the arg is pinned on Host.
-  std::tuple<string, string, int> cache_key(node.op(), node.device(), port_id);
-  auto it = op_device_inport_pinned_to_host_cache->find(cache_key);
-  if (it != op_device_inport_pinned_to_host_cache->end()) {
-    return it->second;
-  }
-
   // Check if op's input port is pinned to HostMemory.
   const OpDef* op = nullptr;
   Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op);
   if (!s.ok()) {
     LOG(WARNING) << "Could not find OpDef for : " << node.op();
-    op_device_inport_pinned_to_host_cache->emplace(cache_key, false);
     return false;
   }
   const int input_arg_id = OpInputPortIdToArgId(node, *op, port_id);
@@ -232,20 +192,16 @@ bool IsNodeInputPortHostFriendly(
       {node.device().c_str(), DEVICE_GPU, DEVICE_CPU}, node, &kernel);
   if (!s.ok()) {
     LOG(INFO) << "Could not find KernelDef for: " << node.op();
-    op_device_inport_pinned_to_host_cache->emplace(cache_key, false);
     return false;
   }
 
   // Check if the input_arg is pinned to Host.
   for (const string& host_memory_arg : kernel->host_memory_arg()) {
     if (op->input_arg(input_arg_id).name() == host_memory_arg) {
-      op_device_inport_pinned_to_host_cache->emplace(cache_key, true);
       return true;
     }
   }
 
-  op_device_inport_pinned_to_host_cache->emplace(cache_key, false);
-
   return false;
 }
 
@@ -255,29 +211,38 @@ bool IsNodeInputPortHostFriendly(
 // 2] Check if node can run on Host.
 // 3] Check all input/outputs are Host "friendly" (atm, friendly means small,
 //    ints, and pinned to Host).
-Status IsNodeHostCandidate(
-    const GraphView& graph, GraphProperties* properties, const NodeDef& node,
-    OpDevicePortOnHostMap* op_device_outport_pinned_to_host_cache,
-    bool* is_candidate) {
+Status IsNodeHostCandidate(const GraphView& graph, GraphProperties* properties,
+                           const NodeDef& node, bool* is_candidate) {
   *is_candidate = false;
 
-  // Skip these node types.
-  if (IsBlacklisted(node)) {
-    return Status::OK();
-  }
-
   // Check if node already on CPU.
   if (str_util::StrContains(node.device(), DEVICE_CPU)) {
     *is_candidate = true;
     return Status::OK();
   }
 
+  // Skip these node types.
+  if (IsBlacklisted(node)) {
+    return Status::OK();
+  }
+
   // Check the node can be run on CPU.
   Status s = TryFindKernelDef({DEVICE_CPU}, node, nullptr);
   if (!s.ok()) {
     return Status::OK();
   }
 
+  // Check all inputs are Host friendly.
+  for (const GraphView::OutputPort& fanin :
+       graph.GetFanins(node, /*include_controlling_nodes=*/false)) {
+    bool fanin_candidate = false;
+    TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly(
+        graph, properties, *fanin.node, fanin.port_id, &fanin_candidate));
+    if (!fanin_candidate) {
+      return Status::OK();
+    }
+  }
+
   // Check all outputs are Host friendly.
   if (!properties->has_properties()) {
     // This is an expensive call, call it lazily.
@@ -290,42 +255,16 @@ Status IsNodeHostCandidate(
     }
   }
 
-  // Check all inputs are Host friendly.
-  for (const GraphView::OutputPort& fanin :
-       graph.GetFanins(node, /*include_controlling_nodes=*/false)) {
-    bool fanin_candidate = false;
-    TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly(
-        graph, properties, *fanin.node, fanin.port_id,
-        op_device_outport_pinned_to_host_cache, &fanin_candidate));
-    if (!fanin_candidate) {
-      return Status::OK();
-    }
-  }
-
   *is_candidate = true;
   return Status::OK();
 }
 
-bool IsTPUGraphDef(const GraphDef& def) {
-  for (const auto& node : def.node()) {
-    if (node.op() == "TPUCompile" || node.op() == "TPUExecute" ||
-        node.op() == "TPUPartitionedCall") {
-      return true;
-    }
-  }
-  return false;
-}
-}  // end namespace
-
-// Tries to swap `device` to a Host device from `devices`. Returns true iff
-// there was a swap.
-bool TrySwapToHostDevice(const gtl::FlatSet<string>& devices,
-                         bool has_device_cpu, string* device) {
+string TryFindHostDevice(const gtl::FlatSet<string>& devices,
+                         bool has_device_cpu, const string& device) {
   // Force this node onto the CPU.
-  if (device->empty() && has_device_cpu) {
-    *device = "/device:CPU:0";
-    return true;
-  } else if (str_util::StrContains(*device, DEVICE_GPU)) {
+  if (device.empty() && has_device_cpu) {
+    return "/device:CPU:0";
+  } else if (str_util::StrContains(device, DEVICE_GPU)) {
     // Sometimes the cluster can have:
     //   devices = {"/device:CPU:0", "/device:XLA_GPU:0"}
     // and we need to handle them properly.
@@ -333,19 +272,27 @@ bool TrySwapToHostDevice(const gtl::FlatSet<string>& devices,
          {std::pair<string, string>("GPU", "CPU:0"),
           std::pair<string, string>("/device", "/device:CPU:0")}) {
       const string device_host =
-          strings::StrCat(device->substr(0, device->rfind(device_match.first)),
+          strings::StrCat(device.substr(0, device.rfind(device_match.first)),
                           device_match.second);
       if (devices.find(device_host) != devices.end()) {
-        *device = device_host;
-        return true;
+        return device_host;
       }
     }
   }
 
-  // We couldn't find an appropriate Host device, return false.
-  return false;
+  // We couldn't find an appropriate Host device, return original device.
+  return device;
 }
 
+bool IsTPUGraphDef(const GraphDef& def) {
+  for (const auto& node : def.node()) {
+    if (node.op() == "TPUCompile" || node.op() == "TPUExecute" ||
+        node.op() == "TPUPartitionedCall") {
+      return true;
+    }
+  }
+  return false;
+}
 }  // end namespace internal
 
 Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
@@ -377,26 +324,20 @@ Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
   // All the Const nodes, and their original devices in topological order.
   std::vector<std::pair<NodeDef*, string>> const_nodes;
 
-  // Cache to map {op, device, port} -> bool on whether it is pinned to host.
-  internal::OpDevicePortOnHostMap op_device_outport_pinned_to_host_cache;
-  internal::OpDevicePortOnHostMap op_device_inport_pinned_to_host_cache;
-
   for (auto& node : *optimized_graph->mutable_node()) {
     bool is_candidate = false;
-    TF_RETURN_IF_ERROR(internal::IsNodeHostCandidate(
-        graph, &properties, node, &op_device_outport_pinned_to_host_cache,
-        &is_candidate));
+    TF_RETURN_IF_ERROR(
+        internal::IsNodeHostCandidate(graph, &properties, node, &is_candidate));
     if (!is_candidate) {
       continue;
    }
 
-    const string original_device = node.device();
-    const bool swapped = internal::TrySwapToHostDevice(devices, has_device_cpu,
-                                                       node.mutable_device());
-    // Keep track of all Const nodes that we swapped.
-    if (swapped && IsConstant(node)) {
-      const_nodes.emplace_back(&node, original_device);
+    if (IsConstant(node)) {
+      const_nodes.emplace_back(&node, node.device());
     }
+    // Try and swap the device to Host.
+    node.set_device(
+        internal::TryFindHostDevice(devices, has_device_cpu, node.device()));
   }
 
   // Traverse all `const_nodes`, and map them back to GPU greedily.
@@ -408,9 +349,8 @@ Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
     // this node back onto the original device.
     for (const GraphView::InputPort& fanout : graph.GetFanouts(*node, false)) {
       // The consumer is not Host friendly, swap it back to the original device.
-      if (!internal::IsNodeInputPortHostFriendly(
-              *fanout.node, fanout.port_id,
-              &op_device_inport_pinned_to_host_cache)) {
+      if (!internal::IsNodeInputPortHostFriendly(*fanout.node,
+                                                 fanout.port_id)) {
         node->set_device(device);
         break;
       }
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h
index bed4a9ef95..d557a03463 100644
--- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h
@@ -26,8 +26,8 @@ namespace tensorflow {
 namespace grappler {
 namespace internal {
 // Try and find an appropriate Host device in `devices` given `device`.
-bool TrySwapToHostDevice(const gtl::FlatSet<string>& devices,
-                         bool has_device_cpu, string* device);
+string TryFindHostDevice(const gtl::FlatSet<string>& devices,
+                         bool has_device_cpu, const string& device);
 }  // end namespace internal
 
 // Optimize TensorFlow ops that should be swapped into the CPU to avoid
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
index 9bb030b220..7c64529441 100644
--- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
@@ -28,60 +28,30 @@ namespace {
 
 class PinToHostOptimizerTest : public GrapplerTest {};
 
-TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceNoDevices) {
+TEST_F(PinToHostOptimizerTest, TryFindHostDevice) {
   gtl::FlatSet<string> devices = {};
-
-  string device = "ABC";
-  EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device));
-  EXPECT_EQ(device, "ABC");
-}
-
-TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceCpuXlaGpu) {
-  gtl::FlatSet<string> devices = {"/device:CPU:0", "/device:XLA_GPU:0"};
-
-  string device = "";
-  EXPECT_TRUE(internal::TrySwapToHostDevice(devices, true, &device));
-  EXPECT_EQ(device, "/device:CPU:0");
-
-  device = "/device:XLA_GPU:0";
-  EXPECT_TRUE(internal::TrySwapToHostDevice(devices, true, &device));
-  EXPECT_EQ(device, "/device:CPU:0");
-
-  device = "/device:XLA_GPU:*";
-  EXPECT_TRUE(internal::TrySwapToHostDevice(devices, true, &device));
-  EXPECT_EQ(device, "/device:CPU:0");
-}
-
-TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceXlaCpuXlaGpu) {
-  gtl::FlatSet<string> devices = {"/device:XLA_CPU:0", "/device:XLA_GPU:0"};
-
-  string device = "";
-  EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device));
-  EXPECT_TRUE(device.empty());
-
-  device = "/device:XLA_GPU:0";
-  EXPECT_TRUE(internal::TrySwapToHostDevice(devices, false, &device));
-  EXPECT_EQ(device, "/device:XLA_CPU:0");
-
-  device = "/device:XLA_GPU:*";
-  EXPECT_TRUE(internal::TrySwapToHostDevice(devices, false, &device));
-  EXPECT_EQ(device, "/device:XLA_CPU:0");
-}
-
-TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceXlaGpu) {
-  gtl::FlatSet<string> devices = {"/device:XLA_GPU:0"};
-
-  string device = "";
-  EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device));
-  EXPECT_TRUE(device.empty());
-
-  device = "/device:XLA_GPU:0";
-  EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device));
-  EXPECT_EQ(device, "/device:XLA_GPU:0");
-
-  device = "/device:XLA_GPU:*";
-  EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device));
-  EXPECT_EQ(device, "/device:XLA_GPU:*");
+  EXPECT_EQ("ABC", internal::TryFindHostDevice(devices, false, "ABC"));
+
+  devices = {"/device:CPU:0", "/device:XLA_GPU:0"};
+  EXPECT_EQ(internal::TryFindHostDevice(devices, true, ""), "/device:CPU:0");
+  EXPECT_EQ(internal::TryFindHostDevice(devices, true, "/device:XLA_GPU:0"),
+            "/device:CPU:0");
+  EXPECT_EQ(internal::TryFindHostDevice(devices, true, "/device:XLA_GPU:*"),
+            "/device:CPU:0");
+
+  devices = {"/device:XLA_CPU:0", "/device:XLA_GPU:0"};
+  EXPECT_EQ(internal::TryFindHostDevice(devices, false, ""), "");
+  EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"),
+            "/device:XLA_CPU:0");
+  EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"),
+            "/device:XLA_CPU:0");
+
+  devices = {"/device:XLA_GPU:0"};
+  EXPECT_EQ(internal::TryFindHostDevice(devices, false, ""), "");
+  EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"),
"/device:XLA_GPU:0"); + EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"), + "/device:XLA_GPU:*"); } TEST_F(PinToHostOptimizerTest, OptimizeSmallOpsToHost) { -- cgit v1.2.3