author     A. Unique TensorFlower <gardener@tensorflow.org>  2018-10-03 18:00:17 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-10-03 18:04:36 -0700
commit     f7edc2d308523fa6c2d233c09e3f2da1c98e3dbc (patch)
tree       e7d3d039f427ef3019d27723fb3c6a4aac60f381 /tensorflow/core
parent     d6e14a53835eed5eed279c83e475440f8f814f0e (diff)
PinToHostOptimizer: Refactored code, updated the blacklist, and added recursive lookback for Identity ops. This fixes many performance regressions.
PiperOrigin-RevId: 215662393
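The headline change is the recursive lookback: when deciding whether a value is Host friendly, the optimizer now looks through Identity nodes to their (transitive) sources instead of judging the Identity itself, since Identity nodes may be optimized away later. A rough sketch of just that recursion, paraphrased from IsNodeOutputPortHostFriendly in the diff below (simplified signature; the device, shape, and kernel checks are elided):

// Sketch only -- paraphrases the Identity recursion added in this commit.
// An Identity's output is Host friendly iff every non-control fanin is.
Status IsOutputHostFriendly(const GraphView& graph, GraphProperties* props,
                            const NodeDef& node, int port_id, bool* ok) {
  *ok = false;
  if (IsBlacklisted(node)) return Status::OK();  // never swap these nodes
  if (IsIdentity(node)) {
    for (const auto& fanin : graph.GetFanins(node, false)) {
      bool fanin_ok = false;
      TF_RETURN_IF_ERROR(IsOutputHostFriendly(graph, props, *fanin.node,
                                              fanin.port_id, &fanin_ok));
      if (!fanin_ok) return Status::OK();  // one bad source disqualifies all
    }
    *ok = true;
    return Status::OK();
  }
  // Non-Identity nodes fall through to device/kernel checks (elided here).
  return Status::OK();
}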
Diffstat (limited to 'tensorflow/core')
-rw-r--r--  tensorflow/core/grappler/costs/graph_properties.h                 |   4
-rw-r--r--  tensorflow/core/grappler/graph_view.cc                            |  33
-rw-r--r--  tensorflow/core/grappler/graph_view.h                             |   3
-rw-r--r--  tensorflow/core/grappler/graph_view_test.cc                       |  22
-rw-r--r--  tensorflow/core/grappler/op_types.cc                              | 114
-rw-r--r--  tensorflow/core/grappler/op_types.h                               |   2
-rw-r--r--  tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc      | 303
-rw-r--r--  tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc |  42
8 files changed, 366 insertions(+), 157 deletions(-)
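For orientation before the diff: the graph_view change folds the output-port lookup into a shared helper and adds an input-port variant with the same contract (both return -1 for an invalid port_id). A hypothetical call site, using only the two functions this commit declares, might look like:

// Hypothetical usage (not part of this commit): map a consumer's input
// port back to its OpDef arg, e.g. to compare against host_memory_arg.
const OpDef* op_def = nullptr;
TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node.op(), &op_def));
const int arg_id = OpInputPortIdToArgId(node, *op_def, /*port_id=*/2);
if (arg_id < 0) {
  return errors::InvalidArgument("No input arg for port 2 on ", node.op());
}

For a ShapeN node with N = 3, for example, input ports 0 through 2 all map to arg 0 (one repeated arg carries all three tensors) and port 3 maps to -1, which is exactly what the updated tests assert.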
diff --git a/tensorflow/core/grappler/costs/graph_properties.h b/tensorflow/core/grappler/costs/graph_properties.h
index f716cd72c9..28fd7565cc 100644
--- a/tensorflow/core/grappler/costs/graph_properties.h
+++ b/tensorflow/core/grappler/costs/graph_properties.h
@@ -74,6 +74,10 @@ class GraphProperties {
   // shape information.
   void ClearInputProperties(const string& node_name);
   void ClearOutputProperties(const string& node_name);
+  // Returns true if we have *any* properties.
+  bool has_properties() const {
+    return input_properties_.size() > 0 || output_properties_.size() > 0;
+  }
 
  private:
   // Relaxes shapes <shapes_and_types>, determined from an EnqueueV2 node, into
diff --git a/tensorflow/core/grappler/graph_view.cc b/tensorflow/core/grappler/graph_view.cc
index 0b8cb5e919..de0a63fc4e 100644
--- a/tensorflow/core/grappler/graph_view.cc
+++ b/tensorflow/core/grappler/graph_view.cc
@@ -20,23 +20,25 @@ limitations under the License.
 namespace tensorflow {
 namespace grappler {
 
-int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id) {
-  for (int output_arg_id = 0; output_arg_id < op.output_arg_size();
-       ++output_arg_id) {
+namespace {
+int OpPortIdToArgId(const NodeDef& node,
+                    const protobuf::RepeatedPtrField<OpDef::ArgDef>& args,
+                    int port_id) {
+  for (int arg_id = 0; arg_id < args.size(); ++arg_id) {
     if (port_id < 0) {
       return -1;
     } else if (port_id == 0) {
-      return output_arg_id;
+      return arg_id;
     }
 
-    // Default is 1 port per output arg.
+    // Default is 1 port per arg.
     int n = 1;
 
-    const auto& output_arg = op.output_arg(output_arg_id);
-    if (!output_arg.number_attr().empty()) {
-      n = node.attr().at(output_arg.number_attr()).i();
-    } else if (!output_arg.type_list_attr().empty()) {
-      n = node.attr().at(output_arg.type_list_attr()).list().type_size();
+    const auto& arg = args.Get(arg_id);
+    if (!arg.number_attr().empty()) {
+      n = node.attr().at(arg.number_attr()).i();
+    } else if (!arg.type_list_attr().empty()) {
+      n = node.attr().at(arg.type_list_attr()).list().type_size();
     }
 
     if (n < 0) {
@@ -44,13 +46,22 @@ int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id) {
       DCHECK_GE(n, 0);
       return -1;
     } else if (port_id < n) {
-      return output_arg_id;
+      return arg_id;
     }
     port_id -= n;
   }
 
   return -1;
 }
+}  // end namespace
+
+int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id) {
+  return OpPortIdToArgId(node, op.output_arg(), port_id);
+}
+
+int OpInputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id) {
+  return OpPortIdToArgId(node, op.input_arg(), port_id);
+}
 
 GraphView::GraphView(GraphDef* graph) : graph_(graph) {
   for (int i = 0; i < graph_->node_size(); i++) {
diff --git a/tensorflow/core/grappler/graph_view.h b/tensorflow/core/grappler/graph_view.h
index ec946ca3b5..09c36a1368 100644
--- a/tensorflow/core/grappler/graph_view.h
+++ b/tensorflow/core/grappler/graph_view.h
@@ -26,7 +26,7 @@ limitations under the License.
 namespace tensorflow {
 namespace grappler {
 
-// Map a node/op's output port_id to arg_id.
+// Map a node/op's input/output port_id to arg_id.
 //
 // The port_id refers to the n-th tensor of the node, while the arg_id refers to
 // the n-th arg of the op. These two can be different if an op's arg is a list
@@ -34,6 +34,7 @@ namespace grappler {
 //
 // We return -1 for any invalid port_id (i.e., no corresponding arg_id).
 int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id);
+int OpInputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id);
 
 // A utility class to simplify the traversal of a GraphDef.
 class GraphView {
diff --git a/tensorflow/core/grappler/graph_view_test.cc b/tensorflow/core/grappler/graph_view_test.cc
index 3d7d2faf7c..f90e2c8cfc 100644
--- a/tensorflow/core/grappler/graph_view_test.cc
+++ b/tensorflow/core/grappler/graph_view_test.cc
@@ -26,7 +26,7 @@ namespace {
 
 class GraphViewTest : public ::testing::Test {};
 
-TEST_F(GraphViewTest, OpOutputPortIdToArgIdShapeN) {
+TEST_F(GraphViewTest, OpPortIdToArgIdShapeN) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
   Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
   ops::ShapeN b(s.WithOpName("b"), {a, a, a});
@@ -45,9 +45,16 @@ TEST_F(GraphViewTest, OpOutputPortIdToArgIdShapeN) {
   EXPECT_TRUE(
       OpRegistry::Global()->LookUpOpDef(b_node_def.op(), &b_op_def).ok());
 
-  EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *a_op_def, 0));
-  EXPECT_EQ(-1, OpOutputPortIdToArgId(b_node_def, *a_op_def, 1));
+  // Const has 0 inputs, 1 output.
+  EXPECT_EQ(-1, OpInputPortIdToArgId(a_node_def, *a_op_def, 0));
+  EXPECT_EQ(0, OpOutputPortIdToArgId(a_node_def, *a_op_def, 0));
+  EXPECT_EQ(-1, OpOutputPortIdToArgId(a_node_def, *a_op_def, 1));
 
+  // ShapeN has N=3 inputs and outputs.
+  EXPECT_EQ(0, OpInputPortIdToArgId(b_node_def, *b_op_def, 0));
+  EXPECT_EQ(0, OpInputPortIdToArgId(b_node_def, *b_op_def, 1));
+  EXPECT_EQ(0, OpInputPortIdToArgId(b_node_def, *b_op_def, 2));
+  EXPECT_EQ(-1, OpInputPortIdToArgId(b_node_def, *b_op_def, 3));
   EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *b_op_def, 0));
   EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *b_op_def, 1));
   EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *b_op_def, 2));
@@ -55,7 +62,7 @@ TEST_F(GraphViewTest, OpOutputPortIdToArgIdShapeN) {
   EXPECT_EQ(-1, OpOutputPortIdToArgId(b_node_def, *b_op_def, 4));
 }
 
-TEST_F(GraphViewTest, OpOutputPortIdToArgIdSparseSplit) {
+TEST_F(GraphViewTest, OpPortIdToArgIdSparseSplit) {
   for (int num_splits : {1, 2}) {
     tensorflow::Scope s = tensorflow::Scope::NewRootScope();
     Output a = ops::Const<int64>(s.WithOpName("a"), 1, {10, 10});
@@ -70,6 +77,13 @@ TEST_F(GraphViewTest, OpOutputPortIdToArgIdSparseSplit) {
     EXPECT_TRUE(
         OpRegistry::Global()->LookUpOpDef(b_node_def.op(), &b_op_def).ok());
 
+    // We have 4 inputs.
+    EXPECT_EQ(0, OpInputPortIdToArgId(b_node_def, *b_op_def, 0));
+    EXPECT_EQ(1, OpInputPortIdToArgId(b_node_def, *b_op_def, 1));
+    EXPECT_EQ(2, OpInputPortIdToArgId(b_node_def, *b_op_def, 2));
+    EXPECT_EQ(3, OpInputPortIdToArgId(b_node_def, *b_op_def, 3));
+    EXPECT_EQ(-1, OpInputPortIdToArgId(b_node_def, *b_op_def, 4));
+
     for (int port_id = 0; port_id <= num_splits * 3; ++port_id) {
       int arg_id = -1;
       if (port_id < num_splits * 3) {
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index 9f0d9dbf28..1b5a215987 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -13,14 +13,13 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include <unordered_set>
-
+#include "tensorflow/core/grappler/op_types.h"
 #include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/grappler/op_types.h"
 #include "tensorflow/core/grappler/utils.h"
 #include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
 #include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/platform/logging.h"
@@ -102,6 +101,18 @@ bool IsConjugateTranspose(const NodeDef& node) {
   return node.op() == "ConjugateTranspose";
 }
 
+bool IsControlFlow(const NodeDef& node) {
+  // clang-format off
+  return node.op() == "ControlTrigger" ||
+         node.op() == "Enter" ||
+         node.op() == "Exit" ||
+         node.op() == "LoopCond" ||
+         node.op() == "Merge" ||
+         node.op() == "NextIteration" ||
+         node.op() == "Switch";
+  // clang-format on
+}
+
 bool IsConv2D(const NodeDef& node) { return node.op() == "Conv2D"; }
 
 bool IsConv2DBackpropFilter(const NodeDef& node) {
@@ -140,26 +151,26 @@ bool IsDiv(const NodeDef& node) { return node.op() == "Div"; }
 // e.g. sqrt, exp. *is_non_decreasing is false, the function is non-increasing,
 // e.g. inv.
 bool IsElementWiseMonotonic(const NodeDef& node, bool* is_non_decreasing) {
-  static const std::unordered_set<string>* monotonic_non_decreasing_ops =
-      CHECK_NOTNULL((new std::unordered_set<string>{
+  static const gtl::FlatSet<string>* const kMonotonicNonDecreasingOps =
+      CHECK_NOTNULL((new gtl::FlatSet<string>{
          "Asinh", "Atanh", "Ceil", "Elu", "Erf", "Exp", "Expm1",
          "Floor", "Log", "Log1p", "Relu", "Relu", "Relu6", "Rint",
          "Selu", "Sigmoid", "Sign", "Sinh", "Sqrt", "Tanh",
       }));
-  static const std::unordered_set<string>* monotonic_non_increasing_ops =
-      CHECK_NOTNULL((new std::unordered_set<string>{
+  static const gtl::FlatSet<string>* const kMonotonicNonIncreasingOps =
+      CHECK_NOTNULL((new gtl::FlatSet<string>{
          "Inv", "Reciprocal", "Erfc", "Rsqrt", "Neg",
       }));
-  if (monotonic_non_decreasing_ops->count(node.op()) > 0) {
+  if (kMonotonicNonDecreasingOps->count(node.op()) > 0) {
     if (is_non_decreasing) {
       *is_non_decreasing = true;
     }
     return true;
-  } else if (monotonic_non_increasing_ops->count(node.op()) > 0) {
+  } else if (kMonotonicNonIncreasingOps->count(node.op()) > 0) {
     if (is_non_decreasing) {
       *is_non_decreasing = false;
     }
@@ -431,6 +442,38 @@ bool IsSymbolicGradient(const NodeDef& node) {
 
 bool IsTanhGrad(const NodeDef& node) { return node.op() == "TanhGrad"; }
 
+bool IsTensorArray(const NodeDef& node) {
+  static const gtl::FlatSet<string>* const kTensorArrayOps =
+      CHECK_NOTNULL((new gtl::FlatSet<string>{
+          "TensorArray",
+          "TensorArrayV2",
+          "TensorArrayV3",
+          "TensorArrayGrad",
+          "TensorArrayGradV2",
+          "TensorArrayGradV3",
+          "TensorArrayGradWithShape",
+          "TensorArrayWrite",
+          "TensorArrayWriteV2",
+          "TensorArrayWriteV3",
+          "TensorArrayRead",
+          "TensorArrayReadV2",
+          "TensorArrayReadV3",
+          "TensorArrayConcat",
+          "TensorArrayConcatV2",
+          "TensorArrayConcatV3",
+          "TensorArraySplit",
+          "TensorArraySplitV2",
+          "TensorArraySplitV3",
+          "TensorArraySize",
+          "TensorArraySizeV2",
+          "TensorArraySizeV3",
+          "TensorArrayClose",
+          "TensorArrayCloseV2",
+          "TensorArrayCloseV3",
+      }));
+  return kTensorArrayOps->count(node.op()) > 0;
+}
+
 bool IsTile(const NodeDef& node) { return node.op() == "Tile"; }
 
 bool IsTranspose(const NodeDef& node) { return node.op() == "Transpose"; }
@@ -542,30 +585,29 @@ OPDEF_PROPERTY_HELPER(Aggregate, aggregate)
 OPDEF_PROPERTY_HELPER(Commutative, commutative)
 
 bool IsInvolution(const NodeDef& node) {
-  static const std::unordered_set<string>* involution_ops =
-      CHECK_NOTNULL((new std::unordered_set<string>{
-          "Conj", "Reciprocal", "Invert", "Neg", "LogicalNot"}));
-  return involution_ops->count(node.op()) > 0;
+  static const gtl::FlatSet<string>* const kInvolutionOps =
+      CHECK_NOTNULL((new gtl::FlatSet<string>{"Conj", "Reciprocal", "Invert",
+                                              "Neg", "LogicalNot"}));
+  return kInvolutionOps->count(node.op()) > 0;
 }
 
 bool IsValueAndOrderAndShapePreserving(const NodeDef& node) {
   if (NumNonControlInputs(node) == 1 && IsAggregate(node)) {
     return true;
   }
-  static const std::unordered_set<string>*
-      value_and_order_and_shape_preserving_ops =
-          CHECK_NOTNULL((new const std::unordered_set<string>{
-              "CheckNumerics",
-              "DebugGradientIdentity",
-              "DeepCopy"
-              "Enter",
-              "Exit",
-              "PreventGradient",
-              "Print",
-              "Snapshot",
-              "StopGradient",
-          }));
-  return value_and_order_and_shape_preserving_ops->count(node.op()) > 0 ||
+  static const gtl::FlatSet<string>* const kValueAndOrderAndShapePreservingOps =
+      CHECK_NOTNULL((new const gtl::FlatSet<string>{
+          "CheckNumerics",
+          "DebugGradientIdentity",
+          "DeepCopy"
+          "Enter",
+          "Exit",
+          "PreventGradient",
+          "Print",
+          "Snapshot",
+          "StopGradient",
+      }));
+  return kValueAndOrderAndShapePreservingOps->count(node.op()) > 0 ||
          IsIdentity(node);
 }
 
@@ -573,31 +615,31 @@ bool IsValueAndOrderPreserving(const NodeDef& node) {
   if (NumNonControlInputs(node) == 1 && IsAggregate(node)) {
     return true;
   }
-  static const std::unordered_set<string>* value_and_order_preserving_ops =
-      CHECK_NOTNULL((new const std::unordered_set<string>{
+  static const gtl::FlatSet<string>* const kValueAndOrderPreservingOps =
+      CHECK_NOTNULL((new const gtl::FlatSet<string>{
          "ExpandDims",
          "Reshape",
          "Squeeze",
       }));
-  return value_and_order_preserving_ops->count(node.op()) > 0 ||
+  return kValueAndOrderPreservingOps->count(node.op()) > 0 ||
         IsValueAndOrderAndShapePreserving(node);
 }
 
 bool IsValuePreserving(const NodeDef& node) {
-  static const std::unordered_set<string>* value_preserving_ops =
-      CHECK_NOTNULL((new std::unordered_set<string>{
+  static const gtl::FlatSet<string>* const kValuePreservingOps =
+      CHECK_NOTNULL((new gtl::FlatSet<string>{
          "InvertPermutation",
          "Reverse",
         "Roll",
          "Transpose",
       }));
   return IsValueAndOrderPreserving(node) ||
-         value_preserving_ops->count(node.op()) > 0;
+         kValuePreservingOps->count(node.op()) > 0;
 }
 
 bool IsUnaryElementWise(const NodeDef& node) {
-  static const std::unordered_set<string>* element_wise_ops =
-      CHECK_NOTNULL((new std::unordered_set<string>{
+  static const gtl::FlatSet<string>* const kElementWiseOps =
+      CHECK_NOTNULL((new gtl::FlatSet<string>{
          "Abs",
          "Acos",
          "Acosh",
@@ -646,7 +688,7 @@ bool IsUnaryElementWise(const NodeDef& node) {
          "Tan"
          "Tanh",
       }));
-  return element_wise_ops->count(node.op()) > 0 ||
+  return kElementWiseOps->count(node.op()) > 0 ||
          IsValueAndOrderAndShapePreserving(node);
 }
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index 7f86a5f295..d4e0159e81 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -46,6 +46,7 @@ bool IsConjugateTranspose(const NodeDef& node);
 bool IsConcat(const NodeDef& node);
 bool IsConcatOffset(const NodeDef& node);
 bool IsConstant(const NodeDef& node);
+bool IsControlFlow(const NodeDef& node);
 bool IsConv2D(const NodeDef& node);
 bool IsConv2DBackpropFilter(const NodeDef& node);
 bool IsConv2DBackpropInput(const NodeDef& node);
@@ -151,6 +152,7 @@ bool IsSum(const NodeDef& node);
 bool IsSwitch(const NodeDef& node);
 bool IsSymbolicGradient(const NodeDef& node);
 bool IsTanhGrad(const NodeDef& node);
+bool IsTensorArray(const NodeDef& node);
 bool IsTile(const NodeDef& node);
 bool IsTranspose(const NodeDef& node);
 bool IsTruncateDiv(const NodeDef& node);
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
index 89eb76046e..8ed4271fa4 100644
--- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
@@ -35,13 +35,44 @@ namespace internal {
 // dynamically determined.
 constexpr int64 kTensorMaxSize = 64;
 
-// Find KernelDef for `node`.
-Status TryFindKernelDef(const NodeDef& node, const KernelDef** kdef) {
-  // Try find KernelDef for node.device, else GPU or CPU.
-  for (const DeviceType& device :
-       {node.device().c_str(), DEVICE_GPU, DEVICE_CPU}) {
-    Status s = FindKernelDef(device, node, kdef, nullptr);
+// All the nodes that should be blacklisted and not swapped.
+bool IsBlacklisted(const NodeDef& node) {
+  return
+      // Collective ops should not be swapped.
+      IsCollective(node) ||
+      // ControlFlow ops should not be swapped.
+      IsControlFlow(node) ||
+      // NoOp ops should not be swapped (due to group dependencies).
+      IsNoOp(node);
+}
+
+// Check if Tensor is integer and small size.
+bool IsTensorIntegerAndSmall(const OpInfo::TensorProperties& prop) {
+  // Check type to be int32 or int64.
+  if (prop.dtype() != DataType::DT_INT32 &&
+      prop.dtype() != DataType::DT_INT64) {
+    return false;
+  }
+
+  // Check size known and small.
+  const int64 size = NumCoefficients(prop.shape());
+  if (size < 0 || size > kTensorMaxSize) {
+    return false;
+  }
+
+  return true;
+}
+
+// Find KernelDef for `node`, greedily return first found from `devices`.
+Status TryFindKernelDef(const std::vector<DeviceType>& devices,
+                        const NodeDef& node, const KernelDef** kdef) {
+  for (const DeviceType& device : devices) {
+    const KernelDef* kernel = nullptr;
+    Status s = FindKernelDef(device, node, &kernel, nullptr);
     if (s.ok()) {
+      if (kdef) {
+        *kdef = kernel;
+      }
       return Status::OK();
     }
   }
@@ -49,88 +80,183 @@ Status TryFindKernelDef(const NodeDef& node, const KernelDef** kdef) {
   return errors::NotFound("Could not find KernelDef for op: ", node.op());
 }
 
-// Check if all node's inputs are pinned to CPU memory.
-bool AreAllNodeInputsPinnedToHost(const GraphView& graph, const NodeDef& node) {
-  // Loop through all the inputs excluding the controlling nodes.
-  for (const GraphView::OutputPort& fanin : graph.GetFanins(node, false)) {
-    // Check if (the fanin) op's device is on CPU.
-    if (str_util::StrContains(fanin.node->device(), DEVICE_CPU)) {
-      continue;
-    }
-
-    // Check if (the fanin) op's output port is pinned to HostMemory.
-    const OpDef* fanin_odef = nullptr;
-    Status s = OpRegistry::Global()->LookUpOpDef(fanin.node->op(), &fanin_odef);
-    if (!s.ok()) {
-      LOG(INFO) << "Could not find OpDef for : " << fanin.node->op();
-      return false;
-    }
+// Checks if a node's output port is host friendly.
+// Roughly this means checking if the output port is on Host memory.
+Status IsNodeOutputPortHostFriendly(const GraphView& graph,
+                                    GraphProperties* properties,
+                                    const NodeDef& node, int port_id,
+                                    bool* is_candidate) {
+  *is_candidate = false;
 
-    const int output_arg_id =
-        OpOutputPortIdToArgId(*fanin.node, *fanin_odef, fanin.port_id);
-    if (output_arg_id < 0) {
-      LOG(WARNING) << "Invalid port: " << fanin.port_id << "!\n"
-                   << node.DebugString() << "\n"
-                   << fanin.node->DebugString() << "\n"
-                   << fanin_odef->DebugString();
-      return false;
-    }
+  // Make sure we are not a blacklisted op.
+  if (IsBlacklisted(node)) {
+    return Status::OK();
+  }
 
-    const KernelDef* fanin_kdef = nullptr;
-    s = TryFindKernelDef(*fanin.node, &fanin_kdef);
-    if (!s.ok()) {
-      LOG(INFO) << "Could not find KernelDef for : " << fanin.node->op();
-      return false;
-    }
+  // Check to make sure we have the right properties (i.e., statically shaped).
+  if (!properties->has_properties()) {
+    // This is an expensive call, call it lazily.
+    TF_RETURN_IF_ERROR(properties->InferStatically(
+        /*assume_valid_feeds=*/false));
+  }
+  const auto& output_properties = properties->GetOutputProperties(node.name());
+  if (port_id >= output_properties.size()) {
+    LOG(WARNING) << "port_id=" << port_id
+                 << " but output_properties.size()=" << output_properties.size()
+                 << "\n"
+                 << node.DebugString();
+    return Status::OK();
+  }
+  if (!IsTensorIntegerAndSmall(output_properties[port_id])) {
+    return Status::OK();
+  }
 
-    bool fanin_pinned = false;
-    for (const string& host_memory_arg : fanin_kdef->host_memory_arg()) {
-      if (fanin_odef->output_arg(output_arg_id).name() == host_memory_arg) {
-        fanin_pinned = true;
-        break;
+  // These nodes may be optimized away downstream (even if pinned to Host), we
+  // should (recusively) check their source.
+  if (IsIdentity(node)) {
+    for (const auto& fanin : graph.GetFanins(node, false)) {
+      bool fanin_candidate = false;
+      TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly(
+          graph, properties, *fanin.node, fanin.port_id, &fanin_candidate));
+      if (!fanin_candidate) {
+        return Status::OK();
       }
     }
+    *is_candidate = true;
+    return Status::OK();
+  }
 
-    if (!fanin_pinned) {
-      return false;
+  // Check if op's device is on CPU.
+  if (str_util::StrContains(node.device(), DEVICE_CPU)) {
+    *is_candidate = true;
+    return Status::OK();
+  }
+
+  // Check if op's output port is pinned to HostMemory.
+  const OpDef* op = nullptr;
+  Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op);
+  if (!s.ok()) {
+    LOG(WARNING) << "Could not find OpDef for : " << node.op();
+    return Status::OK();
+  }
+
+  // Map the port_id to output_arg_id.
+  const int output_arg_id = OpOutputPortIdToArgId(node, *op, port_id);
+  if (output_arg_id < 0) {
+    LOG(WARNING) << "Invalid port: " << port_id << "!\n"
+                 << node.DebugString() << "\n"
+                 << op->DebugString();
+    return Status::OK();
+  }
+
+  // Find the kernel.
+  const KernelDef* kernel = nullptr;
+  s = TryFindKernelDef({node.device().c_str(), DEVICE_GPU, DEVICE_CPU}, node,
+                       &kernel);
+  if (!s.ok()) {
+    LOG(INFO) << "Could not find KernelDef for: " << node.op();
+    return Status::OK();
+  }
+
+  // Check if the output_arg is pinned to Host.
+  for (const string& host_memory_arg : kernel->host_memory_arg()) {
+    if (op->output_arg(output_arg_id).name() == host_memory_arg) {
+      *is_candidate = true;
+      break;
     }
   }
 
-  return true;
+  return Status::OK();
 }
 
-bool IsTensorIntegerAndSmall(const OpInfo::TensorProperties& prop) {
-  // Check if Tensor is integer and small size.
+// Checks if a node's input port is Host friendly.
+// Roughly this means checking if the input port is on Host memory.
+bool IsNodeInputPortHostFriendly(const NodeDef& node, int port_id) {
+  // If node is on Host, assume its inputs are Host friendly.
+  if (str_util::StrContains(node.device(), DEVICE_CPU)) {
+    return true;
+  }
 
-  // Check type to be int32 or int64.
-  if (prop.dtype() != DataType::DT_INT32 &&
-      prop.dtype() != DataType::DT_INT64) {
+  // Check if op's input port is pinned to HostMemory.
+  const OpDef* op = nullptr;
+  Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op);
+  if (!s.ok()) {
+    LOG(WARNING) << "Could not find OpDef for : " << node.op();
     return false;
   }
-
-  // Check size known and small.
-  const int64 size = NumCoefficients(prop.shape());
-  if (size < 0 || size > kTensorMaxSize) {
+  const int input_arg_id = OpInputPortIdToArgId(node, *op, port_id);
+
+  // Find the kernel.
+  const KernelDef* kernel = nullptr;
+  s = internal::TryFindKernelDef(
+      {node.device().c_str(), DEVICE_GPU, DEVICE_CPU}, node, &kernel);
+  if (!s.ok()) {
+    LOG(INFO) << "Could not find KernelDef for: " << node.op();
    return false;
   }
 
-  return true;
+  // Check if the input_arg is pinned to Host.
+  for (const string& host_memory_arg : kernel->host_memory_arg()) {
+    if (op->input_arg(input_arg_id).name() == host_memory_arg) {
+      return true;
+    }
+  }
+
+  return false;
 }
 
-bool AreAllNodeInputsAndOutputsIntsAndSmall(const GraphProperties& properties,
-                                            const NodeDef& node) {
-  for (const auto& prop : properties.GetInputProperties(node.name())) {
-    if (!IsTensorIntegerAndSmall(prop)) {
-      return false;
+// Checks if a node is a candidate to pin to Host.
+// The rough algorithm is as follows:
+// 1] Check if node is blacklisted.
+// 2] Check if node can run on Host.
+// 3] Check all input/outputs are Host "friendly" (atm, friendly means small,
+//    ints, and pinned to Host).
+Status IsNodeHostCandidate(const GraphView& graph, GraphProperties* properties,
+                           const NodeDef& node, bool* is_candidate) {
+  *is_candidate = false;
+
+  // Check if node already on CPU.
+  if (str_util::StrContains(node.device(), DEVICE_CPU)) {
+    *is_candidate = true;
+    return Status::OK();
+  }
+
+  // Skip these node types.
+  if (IsBlacklisted(node)) {
+    return Status::OK();
+  }
+
+  // Check the node can be run on CPU.
+  Status s = TryFindKernelDef({DEVICE_CPU}, node, nullptr);
+  if (!s.ok()) {
+    return Status::OK();
+  }
+
+  // Check all inputs are Host friendly.
+  for (const GraphView::OutputPort& fanin :
+       graph.GetFanins(node, /*include_controlling_nodes=*/false)) {
+    bool fanin_candidate = false;
+    TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly(
+        graph, properties, *fanin.node, fanin.port_id, &fanin_candidate));
+    if (!fanin_candidate) {
+      return Status::OK();
     }
   }
-  for (const auto& prop : properties.GetOutputProperties(node.name())) {
+
+  // Check all outputs are Host friendly.
+  if (!properties->has_properties()) {
+    // This is an expensive call, call it lazily.
+    TF_RETURN_IF_ERROR(properties->InferStatically(
+        /*assume_valid_feeds=*/false));
+  }
+  for (const auto& prop : properties->GetOutputProperties(node.name())) {
     if (!IsTensorIntegerAndSmall(prop)) {
-      return false;
+      return Status::OK();
     }
   }
-  return true;
+
+  *is_candidate = true;
+  return Status::OK();
 }
 
 string TryFindHostDevice(const gtl::FlatSet<string>& devices,
@@ -167,15 +293,6 @@ bool IsTPUGraphDef(const GraphDef& def) {
   }
   return false;
 }
-
-// All the nodes that should be blacklisted and not swapped.
-bool IsBlacklisted(const NodeDef& node) {
-  return
-      // Collective ops should not be swapped.
-      IsCollective(node) ||
-      // NoOp breaks perf regression tests (probably due to group dependencies).
-      IsNoOp(node);
-}
 }  // end namespace internal
 
 Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
@@ -188,7 +305,6 @@ Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
   }
 
   GraphProperties properties(item);
-  bool has_properties = false;
   GraphView graph(optimized_graph);
 
   gtl::FlatSet<string> devices;
@@ -209,35 +325,10 @@ Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
   std::vector<std::pair<NodeDef*, string>> const_nodes;
 
   for (auto& node : *optimized_graph->mutable_node()) {
-    // Check if node already on CPU.
-    if (str_util::StrContains(node.device(), DEVICE_CPU)) {
-      continue;
-    }
-
-    // Skip these node types.
-    if (internal::IsBlacklisted(node)) {
-      continue;
-    }
-
-    // Check the node can be run on CPU.
-    Status s = FindKernelDef(DEVICE_CPU, node, nullptr, nullptr);
-    if (!s.ok()) {
-      continue;
-    }
-
-    // Check all input's are pinned to CPU.
-    if (!internal::AreAllNodeInputsPinnedToHost(graph, node)) {
-      continue;
-    }
-
-    if (!has_properties) {
-      // This is an expensive call, call it lazily.
-      TF_RETURN_IF_ERROR(properties.InferStatically(false));
-      has_properties = true;
-    }
-
-    // Check all inputs and outputs are integers and small.
-    if (!internal::AreAllNodeInputsAndOutputsIntsAndSmall(properties, node)) {
+    bool is_candidate = false;
+    TF_RETURN_IF_ERROR(
+        internal::IsNodeHostCandidate(graph, &properties, node, &is_candidate));
+    if (!is_candidate) {
       continue;
     }
 
@@ -254,10 +345,12 @@ Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
     NodeDef* node = it.first;
     const string& device = it.second;
 
-    // Check all the consumers of this node, if any of them are on the original
-    // device, swap this node back onto the original device.
+    // Check all the consumers of this node, if any of them are not on CPU, swap
+    // this node back onto the original device.
     for (const GraphView::InputPort& fanout : graph.GetFanouts(*node, false)) {
-      if (fanout.node->device() == device) {
+      // The consumer is not Host friendly, swap it back to the original device.
+      if (!internal::IsNodeInputPortHostFriendly(*fanout.node,
+                                                 fanout.port_id)) {
         node->set_device(device);
         break;
       }
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
index 173cb3fe3c..7c64529441 100644
--- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
@@ -160,6 +160,48 @@ TEST_F(PinToHostOptimizerTest, NoSwap) {
   EXPECT_EQ(found, 3);
 }
 
+TEST_F(PinToHostOptimizerTest, Identity) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  // `a,c` is on GPU, `e` is on CPU, consequently `e` should not be swapped.
+  // `b` should be placed onto Host since `c` pins the input to Host memory.
+  Output a =
+      ops::Const(s.WithOpName("a").WithDevice("/device:GPU:0"), 1, {64, 64});
+  Output b = ops::Const(s.WithOpName("b"), {0, 1}, {2});
+  Output c =
+      ops::ReduceProd(s.WithOpName("c").WithDevice("/device:GPU:0"), a, b);
+  Output d = ops::Identity(s.WithDevice("/device:CPU:0").WithOpName("d"), c);
+  Output e = ops::Multiply(s.WithOpName("e"), d, d);
+
+  GrapplerItem item;
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+  GraphDef output;
+  PinToHostOptimizer optimizer(RewriterConfig::ON);
+  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+  int found = 0;
+  for (const NodeDef& node : output.node()) {
+    if (node.name() == "a" || node.name() == "c") {
+      EXPECT_EQ(node.device(), "/device:GPU:0");
+    } else if (node.name() == "b") {
+      // If CUDA, then there is a GPU kernel registration that is pinned to Host
+      // memory. Consequently, `b` will be mapped to Host correct if there is
+      // a GPU kernel registered.
+#if GOOGLE_CUDA
+      EXPECT_EQ(node.device(), "/device:CPU:0");
+#else
+      EXPECT_TRUE(node.device().empty());
+#endif
+    } else if (node.name() == "d") {
+      EXPECT_EQ(node.device(), "/device:CPU:0");
+    } else if (node.name() == "e") {
+      EXPECT_TRUE(node.device().empty());
+    }
+    ++found;
+  }
+  EXPECT_EQ(found, 5);
+}
+
 TEST_F(PinToHostOptimizerTest, PortIdToArgId) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
   Output a = ops::Const(s.WithOpName("a"), 1, {1, 2, 3});
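Taken together, the per-node decision in the refactored optimizer reduces to a short pipeline. The following condensed restatement of IsNodeHostCandidate is a sketch only: the helper names here are hypothetical, and the real code threads a Status, a GraphView, and lazily inferred GraphProperties through each step.

// Condensed sketch of the candidate check (hypothetical helpers):
// a node is pinned to Host only if every stage below passes.
bool ShouldPinToHost(const NodeDef& node) {
  if (OnCpuAlready(node)) return true;             // already on Host, done
  if (IsBlacklisted(node)) return false;           // collective/control-flow/NoOp
  if (!HasCpuKernel(node)) return false;           // must be runnable on Host
  if (!AllFaninsHostFriendly(node)) return false;  // recursive; sees through Identity
  return AllOutputsSmallInts(node);                // int32/int64, at most 64 elements
}

After pinning, a second pass walks each pinned node's consumers and swaps the node back to its original device if any consumer's input port is not Host friendly, which is what the new IsNodeInputPortHostFriendly check above implements.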