diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-10-03 18:00:17 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-03 18:04:36 -0700 |
commit | f7edc2d308523fa6c2d233c09e3f2da1c98e3dbc (patch) | |
tree | e7d3d039f427ef3019d27723fb3c6a4aac60f381 /tensorflow/core/grappler/optimizers | |
parent | d6e14a53835eed5eed279c83e475440f8f814f0e (diff) |
PinToHostOptimizer: Refactored code. Updated blacklist. Added recursive lookback for Identity op. This fixes many performance regressions.
PiperOrigin-RevId: 215662393
Diffstat (limited to 'tensorflow/core/grappler/optimizers')
-rw-r--r-- | tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc | 303 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc | 42 |
2 files changed, 240 insertions, 105 deletions
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc index 89eb76046e..8ed4271fa4 100644 --- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc @@ -35,13 +35,44 @@ namespace internal { // dynamically determined. constexpr int64 kTensorMaxSize = 64; -// Find KernelDef for `node`. -Status TryFindKernelDef(const NodeDef& node, const KernelDef** kdef) { - // Try find KernelDef for node.device, else GPU or CPU. - for (const DeviceType& device : - {node.device().c_str(), DEVICE_GPU, DEVICE_CPU}) { - Status s = FindKernelDef(device, node, kdef, nullptr); +// All the nodes that should be blacklisted and not swapped. +bool IsBlacklisted(const NodeDef& node) { + return + // Collective ops should not be swapped. + IsCollective(node) || + // ControlFlow ops should not be swapped. + IsControlFlow(node) || + // NoOp ops should not be swapped (due to group dependencies). + IsNoOp(node); +} + +// Check if Tensor is integer and small size. +bool IsTensorIntegerAndSmall(const OpInfo::TensorProperties& prop) { + // Check type to be int32 or int64. + if (prop.dtype() != DataType::DT_INT32 && + prop.dtype() != DataType::DT_INT64) { + return false; + } + + // Check size known and small. + const int64 size = NumCoefficients(prop.shape()); + if (size < 0 || size > kTensorMaxSize) { + return false; + } + + return true; +} + +// Find KernelDef for `node`, greedily return first found from `devices`. 
+Status TryFindKernelDef(const std::vector<DeviceType>& devices, + const NodeDef& node, const KernelDef** kdef) { + for (const DeviceType& device : devices) { + const KernelDef* kernel = nullptr; + Status s = FindKernelDef(device, node, &kernel, nullptr); if (s.ok()) { + if (kdef) { + *kdef = kernel; + } return Status::OK(); } } @@ -49,88 +80,183 @@ Status TryFindKernelDef(const NodeDef& node, const KernelDef** kdef) { return errors::NotFound("Could not find KernelDef for op: ", node.op()); } -// Check if all node's inputs are pinned to CPU memory. -bool AreAllNodeInputsPinnedToHost(const GraphView& graph, const NodeDef& node) { - // Loop through all the inputs excluding the controlling nodes. - for (const GraphView::OutputPort& fanin : graph.GetFanins(node, false)) { - // Check if (the fanin) op's device is on CPU. - if (str_util::StrContains(fanin.node->device(), DEVICE_CPU)) { - continue; - } - - // Check if (the fanin) op's output port is pinned to HostMemory. - const OpDef* fanin_odef = nullptr; - Status s = OpRegistry::Global()->LookUpOpDef(fanin.node->op(), &fanin_odef); - if (!s.ok()) { - LOG(INFO) << "Could not find OpDef for : " << fanin.node->op(); - return false; - } +// Checks if a node's output port is host friendly. +// Roughly this means checking if the output port is on Host memory. +Status IsNodeOutputPortHostFriendly(const GraphView& graph, + GraphProperties* properties, + const NodeDef& node, int port_id, + bool* is_candidate) { + *is_candidate = false; - const int output_arg_id = - OpOutputPortIdToArgId(*fanin.node, *fanin_odef, fanin.port_id); - if (output_arg_id < 0) { - LOG(WARNING) << "Invalid port: " << fanin.port_id << "!\n" - << node.DebugString() << "\n" - << fanin.node->DebugString() << "\n" - << fanin_odef->DebugString(); - return false; - } + // Make sure we are not a blacklisted op. 
+ if (IsBlacklisted(node)) { + return Status::OK(); + } - const KernelDef* fanin_kdef = nullptr; - s = TryFindKernelDef(*fanin.node, &fanin_kdef); - if (!s.ok()) { - LOG(INFO) << "Could not find KernelDef for : " << fanin.node->op(); - return false; - } + // Check to make sure we have the right properties (i.e., statically shaped). + if (!properties->has_properties()) { + // This is an expensive call, call it lazily. + TF_RETURN_IF_ERROR(properties->InferStatically( + /*assume_valid_feeds=*/false)); + } + const auto& output_properties = properties->GetOutputProperties(node.name()); + if (port_id >= output_properties.size()) { + LOG(WARNING) << "port_id=" << port_id + << " but output_properties.size()=" << output_properties.size() + << "\n" + << node.DebugString(); + return Status::OK(); + } + if (!IsTensorIntegerAndSmall(output_properties[port_id])) { + return Status::OK(); + } - bool fanin_pinned = false; - for (const string& host_memory_arg : fanin_kdef->host_memory_arg()) { - if (fanin_odef->output_arg(output_arg_id).name() == host_memory_arg) { - fanin_pinned = true; - break; + // These nodes may be optimized away downstream (even if pinned to Host), we + // should (recusively) check their source. + if (IsIdentity(node)) { + for (const auto& fanin : graph.GetFanins(node, false)) { + bool fanin_candidate = false; + TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly( + graph, properties, *fanin.node, fanin.port_id, &fanin_candidate)); + if (!fanin_candidate) { + return Status::OK(); } } + *is_candidate = true; + return Status::OK(); + } - if (!fanin_pinned) { - return false; + // Check if op's device is on CPU. + if (str_util::StrContains(node.device(), DEVICE_CPU)) { + *is_candidate = true; + return Status::OK(); + } + + // Check if op's output port is pinned to HostMemory. 
+ const OpDef* op = nullptr; + Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op); + if (!s.ok()) { + LOG(WARNING) << "Could not find OpDef for : " << node.op(); + return Status::OK(); + } + + // Map the port_id to output_arg_id. + const int output_arg_id = OpOutputPortIdToArgId(node, *op, port_id); + if (output_arg_id < 0) { + LOG(WARNING) << "Invalid port: " << port_id << "!\n" + << node.DebugString() << "\n" + << op->DebugString(); + return Status::OK(); + } + + // Find the kernel. + const KernelDef* kernel = nullptr; + s = TryFindKernelDef({node.device().c_str(), DEVICE_GPU, DEVICE_CPU}, node, + &kernel); + if (!s.ok()) { + LOG(INFO) << "Could not find KernelDef for: " << node.op(); + return Status::OK(); + } + + // Check if the output_arg is pinned to Host. + for (const string& host_memory_arg : kernel->host_memory_arg()) { + if (op->output_arg(output_arg_id).name() == host_memory_arg) { + *is_candidate = true; + break; } } - return true; + return Status::OK(); } -bool IsTensorIntegerAndSmall(const OpInfo::TensorProperties& prop) { - // Check if Tensor is integer and small size. +// Checks if a node's input port is Host friendly. +// Roughly this means checking if the input port is on Host memory. +bool IsNodeInputPortHostFriendly(const NodeDef& node, int port_id) { + // If node is on Host, assume its inputs are Host friendly. + if (str_util::StrContains(node.device(), DEVICE_CPU)) { + return true; + } - // Check type to be int32 or int64. - if (prop.dtype() != DataType::DT_INT32 && - prop.dtype() != DataType::DT_INT64) { + // Check if op's input port is pinned to HostMemory. + const OpDef* op = nullptr; + Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op); + if (!s.ok()) { + LOG(WARNING) << "Could not find OpDef for : " << node.op(); return false; } - - // Check size known and small. 
- const int64 size = NumCoefficients(prop.shape()); - if (size < 0 || size > kTensorMaxSize) { + const int input_arg_id = OpInputPortIdToArgId(node, *op, port_id); + + // Find the kernel. + const KernelDef* kernel = nullptr; + s = internal::TryFindKernelDef( + {node.device().c_str(), DEVICE_GPU, DEVICE_CPU}, node, &kernel); + if (!s.ok()) { + LOG(INFO) << "Could not find KernelDef for: " << node.op(); return false; } - return true; + // Check if the input_arg is pinned to Host. + for (const string& host_memory_arg : kernel->host_memory_arg()) { + if (op->input_arg(input_arg_id).name() == host_memory_arg) { + return true; + } + } + + return false; } -bool AreAllNodeInputsAndOutputsIntsAndSmall(const GraphProperties& properties, - const NodeDef& node) { - for (const auto& prop : properties.GetInputProperties(node.name())) { - if (!IsTensorIntegerAndSmall(prop)) { - return false; +// Checks if a node is a candidate to pin to Host. +// The rough algorithm is as follows: +// 1] Check if node is blacklisted. +// 2] Check if node can run on Host. +// 3] Check all input/outputs are Host "friendly" (atm, friendly means small, +// ints, and pinned to Host). +Status IsNodeHostCandidate(const GraphView& graph, GraphProperties* properties, + const NodeDef& node, bool* is_candidate) { + *is_candidate = false; + + // Check if node already on CPU. + if (str_util::StrContains(node.device(), DEVICE_CPU)) { + *is_candidate = true; + return Status::OK(); + } + + // Skip these node types. + if (IsBlacklisted(node)) { + return Status::OK(); + } + + // Check the node can be run on CPU. + Status s = TryFindKernelDef({DEVICE_CPU}, node, nullptr); + if (!s.ok()) { + return Status::OK(); + } + + // Check all inputs are Host friendly. 
+ for (const GraphView::OutputPort& fanin : + graph.GetFanins(node, /*include_controlling_nodes=*/false)) { + bool fanin_candidate = false; + TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly( + graph, properties, *fanin.node, fanin.port_id, &fanin_candidate)); + if (!fanin_candidate) { + return Status::OK(); } } - for (const auto& prop : properties.GetOutputProperties(node.name())) { + // Check all outputs are Host friendly. + if (!properties->has_properties()) { + // This is an expensive call, call it lazily. + TF_RETURN_IF_ERROR(properties->InferStatically( + /*assume_valid_feeds=*/false)); + } + for (const auto& prop : properties->GetOutputProperties(node.name())) { if (!IsTensorIntegerAndSmall(prop)) { - return false; + return Status::OK(); } } - return true; + + *is_candidate = true; + return Status::OK(); } string TryFindHostDevice(const gtl::FlatSet<string>& devices, @@ -167,15 +293,6 @@ bool IsTPUGraphDef(const GraphDef& def) { } return false; } - -// All the nodes that should be blacklisted and not swapped. -bool IsBlacklisted(const NodeDef& node) { - return - // Collective ops should not be swapped. - IsCollective(node) || - // NoOp breaks perf regression tests (probably due to group dependencies). - IsNoOp(node); -} } // end namespace internal Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, @@ -188,7 +305,6 @@ Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, } GraphProperties properties(item); - bool has_properties = false; GraphView graph(optimized_graph); gtl::FlatSet<string> devices; @@ -209,35 +325,10 @@ Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, std::vector<std::pair<NodeDef*, string>> const_nodes; for (auto& node : *optimized_graph->mutable_node()) { - // Check if node already on CPU. - if (str_util::StrContains(node.device(), DEVICE_CPU)) { - continue; - } - - // Skip these node types. 
- if (internal::IsBlacklisted(node)) { - continue; - } - - // Check the node can be run on CPU. - Status s = FindKernelDef(DEVICE_CPU, node, nullptr, nullptr); - if (!s.ok()) { - continue; - } - - // Check all input's are pinned to CPU. - if (!internal::AreAllNodeInputsPinnedToHost(graph, node)) { - continue; - } - - if (!has_properties) { - // This is an expensive call, call it lazily. - TF_RETURN_IF_ERROR(properties.InferStatically(false)); - has_properties = true; - } - - // Check all inputs and outputs are integers and small. - if (!internal::AreAllNodeInputsAndOutputsIntsAndSmall(properties, node)) { + bool is_candidate = false; + TF_RETURN_IF_ERROR( + internal::IsNodeHostCandidate(graph, &properties, node, &is_candidate)); + if (!is_candidate) { continue; } @@ -254,10 +345,12 @@ Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, NodeDef* node = it.first; const string& device = it.second; - // Check all the consumers of this node, if any of them are on the original - // device, swap this node back onto the original device. + // Check all the consumers of this node, if any of them are not on CPU, swap + // this node back onto the original device. for (const GraphView::InputPort& fanout : graph.GetFanouts(*node, false)) { - if (fanout.node->device() == device) { + // The consumer is not Host friendly, swap it back to the original device. 
+ if (!internal::IsNodeInputPortHostFriendly(*fanout.node, + fanout.port_id)) { node->set_device(device); break; } diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc index 173cb3fe3c..7c64529441 100644 --- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc @@ -160,6 +160,48 @@ TEST_F(PinToHostOptimizerTest, NoSwap) { EXPECT_EQ(found, 3); } +TEST_F(PinToHostOptimizerTest, Identity) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + // `a,c` is on GPU, `e` is on CPU, consequently `e` should not be swapped. + // `b` should be placed onto Host since `c` pins the input to Host memory. + Output a = + ops::Const(s.WithOpName("a").WithDevice("/device:GPU:0"), 1, {64, 64}); + Output b = ops::Const(s.WithOpName("b"), {0, 1}, {2}); + Output c = + ops::ReduceProd(s.WithOpName("c").WithDevice("/device:GPU:0"), a, b); + Output d = ops::Identity(s.WithDevice("/device:CPU:0").WithOpName("d"), c); + Output e = ops::Multiply(s.WithOpName("e"), d, d); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + GraphDef output; + PinToHostOptimizer optimizer(RewriterConfig::ON); + TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output)); + + int found = 0; + for (const NodeDef& node : output.node()) { + if (node.name() == "a" || node.name() == "c") { + EXPECT_EQ(node.device(), "/device:GPU:0"); + } else if (node.name() == "b") { + // If CUDA, then there is a GPU kernel registration that is pinned to Host + // memory. Consequently, `b` will be mapped to Host correct if there is + // a GPU kernel registered. 
+#if GOOGLE_CUDA + EXPECT_EQ(node.device(), "/device:CPU:0"); +#else + EXPECT_TRUE(node.device().empty()); +#endif + } else if (node.name() == "d") { + EXPECT_EQ(node.device(), "/device:CPU:0"); + } else if (node.name() == "e") { + EXPECT_TRUE(node.device().empty()); + } + ++found; + } + EXPECT_EQ(found, 5); +} + TEST_F(PinToHostOptimizerTest, PortIdToArgId) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output a = ops::Const(s.WithOpName("a"), 1, {1, 2, 3}); |