author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-10-03 18:00:17 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-10-03 18:04:36 -0700
commit    f7edc2d308523fa6c2d233c09e3f2da1c98e3dbc (patch)
tree      e7d3d039f427ef3019d27723fb3c6a4aac60f381 /tensorflow/core
parent    d6e14a53835eed5eed279c83e475440f8f814f0e (diff)
PinToHostOptimizer: Refactored code, updated the blacklist, and added recursive lookback for the Identity op. This fixes many performance regressions.
PiperOrigin-RevId: 215662393
Diffstat (limited to 'tensorflow/core')
-rw-r--r--  tensorflow/core/grappler/costs/graph_properties.h                 |   4
-rw-r--r--  tensorflow/core/grappler/graph_view.cc                            |  33
-rw-r--r--  tensorflow/core/grappler/graph_view.h                             |   3
-rw-r--r--  tensorflow/core/grappler/graph_view_test.cc                       |  22
-rw-r--r--  tensorflow/core/grappler/op_types.cc                              | 114
-rw-r--r--  tensorflow/core/grappler/op_types.h                               |   2
-rw-r--r--  tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc      | 303
-rw-r--r--  tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc |  42
8 files changed, 366 insertions(+), 157 deletions(-)
diff --git a/tensorflow/core/grappler/costs/graph_properties.h b/tensorflow/core/grappler/costs/graph_properties.h
index f716cd72c9..28fd7565cc 100644
--- a/tensorflow/core/grappler/costs/graph_properties.h
+++ b/tensorflow/core/grappler/costs/graph_properties.h
@@ -74,6 +74,10 @@ class GraphProperties {
// shape information.
void ClearInputProperties(const string& node_name);
void ClearOutputProperties(const string& node_name);
+ // Returns true if we have *any* properties.
+ bool has_properties() const {
+ return input_properties_.size() > 0 || output_properties_.size() > 0;
+ }
private:
// Relaxes shapes <shapes_and_types>, determined from an EnqueueV2 node, into
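
The new has_properties() accessor exists so that callers can defer the expensive InferStatically() call until shape properties are actually needed. A minimal sketch of that lazy pattern, mirroring the usage this commit adds in pin_to_host_optimizer.cc (the EnsureProperties helper name is hypothetical):

// Hypothetical helper: run whole-graph shape inference at most once, and only
// when a caller actually needs the properties.
Status EnsureProperties(GraphProperties* properties) {
  if (!properties->has_properties()) {
    TF_RETURN_IF_ERROR(
        properties->InferStatically(/*assume_valid_feeds=*/false));
  }
  return Status::OK();
}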
diff --git a/tensorflow/core/grappler/graph_view.cc b/tensorflow/core/grappler/graph_view.cc
index 0b8cb5e919..de0a63fc4e 100644
--- a/tensorflow/core/grappler/graph_view.cc
+++ b/tensorflow/core/grappler/graph_view.cc
@@ -20,23 +20,25 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
-int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id) {
- for (int output_arg_id = 0; output_arg_id < op.output_arg_size();
- ++output_arg_id) {
+namespace {
+int OpPortIdToArgId(const NodeDef& node,
+ const protobuf::RepeatedPtrField<OpDef::ArgDef>& args,
+ int port_id) {
+ for (int arg_id = 0; arg_id < args.size(); ++arg_id) {
if (port_id < 0) {
return -1;
} else if (port_id == 0) {
- return output_arg_id;
+ return arg_id;
}
- // Default is 1 port per output arg.
+ // Default is 1 port per arg.
int n = 1;
- const auto& output_arg = op.output_arg(output_arg_id);
- if (!output_arg.number_attr().empty()) {
- n = node.attr().at(output_arg.number_attr()).i();
- } else if (!output_arg.type_list_attr().empty()) {
- n = node.attr().at(output_arg.type_list_attr()).list().type_size();
+ const auto& arg = args.Get(arg_id);
+ if (!arg.number_attr().empty()) {
+ n = node.attr().at(arg.number_attr()).i();
+ } else if (!arg.type_list_attr().empty()) {
+ n = node.attr().at(arg.type_list_attr()).list().type_size();
}
if (n < 0) {
@@ -44,13 +46,22 @@ int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id) {
DCHECK_GE(n, 0);
return -1;
} else if (port_id < n) {
- return output_arg_id;
+ return arg_id;
}
port_id -= n;
}
return -1;
}
+} // end namespace
+
+int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id) {
+ return OpPortIdToArgId(node, op.output_arg(), port_id);
+}
+
+int OpInputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id) {
+ return OpPortIdToArgId(node, op.input_arg(), port_id);
+}
GraphView::GraphView(GraphDef* graph) : graph_(graph) {
for (int i = 0; i < graph_->node_size(); i++) {
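
To make the port-to-arg mapping concrete: each list-typed arg contributes one port per list element, so several consecutive port ids can map to the same arg id. A hedged sketch, assuming shape_n_node is a NodeDef for a ShapeN op with N = 3 (mirroring the test below):

// ShapeN declares a single list-typed input arg, so all three tensor ports
// map to arg 0, and any port past the end of the list maps to -1.
const OpDef* op_def = nullptr;
TF_CHECK_OK(OpRegistry::Global()->LookUpOpDef(shape_n_node.op(), &op_def));
CHECK_EQ(0, OpInputPortIdToArgId(shape_n_node, *op_def, 2));   // last element
CHECK_EQ(-1, OpInputPortIdToArgId(shape_n_node, *op_def, 3));  // out of range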
diff --git a/tensorflow/core/grappler/graph_view.h b/tensorflow/core/grappler/graph_view.h
index ec946ca3b5..09c36a1368 100644
--- a/tensorflow/core/grappler/graph_view.h
+++ b/tensorflow/core/grappler/graph_view.h
@@ -26,7 +26,7 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
-// Map a node/op's output port_id to arg_id.
+// Map a node/op's input/output port_id to arg_id.
//
// The port_id refers to the n-th tensor of the node, while the arg_id refers to
// the n-th arg of the op. These two can be different if an op's arg is a list
@@ -34,6 +34,7 @@ namespace grappler {
//
// We return -1 for any invalid port_id (i.e., no corresponding arg_id).
int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id);
+int OpInputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id);
// A utility class to simplify the traversal of a GraphDef.
class GraphView {
diff --git a/tensorflow/core/grappler/graph_view_test.cc b/tensorflow/core/grappler/graph_view_test.cc
index 3d7d2faf7c..f90e2c8cfc 100644
--- a/tensorflow/core/grappler/graph_view_test.cc
+++ b/tensorflow/core/grappler/graph_view_test.cc
@@ -26,7 +26,7 @@ namespace {
class GraphViewTest : public ::testing::Test {};
-TEST_F(GraphViewTest, OpOutputPortIdToArgIdShapeN) {
+TEST_F(GraphViewTest, OpPortIdToArgIdShapeN) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
ops::ShapeN b(s.WithOpName("b"), {a, a, a});
@@ -45,9 +45,16 @@ TEST_F(GraphViewTest, OpOutputPortIdToArgIdShapeN) {
EXPECT_TRUE(
OpRegistry::Global()->LookUpOpDef(b_node_def.op(), &b_op_def).ok());
- EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *a_op_def, 0));
- EXPECT_EQ(-1, OpOutputPortIdToArgId(b_node_def, *a_op_def, 1));
+ // Const has 0 inputs, 1 output.
+ EXPECT_EQ(-1, OpInputPortIdToArgId(a_node_def, *a_op_def, 0));
+ EXPECT_EQ(0, OpOutputPortIdToArgId(a_node_def, *a_op_def, 0));
+ EXPECT_EQ(-1, OpOutputPortIdToArgId(a_node_def, *a_op_def, 1));
+ // ShapeN has N=3 inputs and outputs.
+ EXPECT_EQ(0, OpInputPortIdToArgId(b_node_def, *b_op_def, 0));
+ EXPECT_EQ(0, OpInputPortIdToArgId(b_node_def, *b_op_def, 1));
+ EXPECT_EQ(0, OpInputPortIdToArgId(b_node_def, *b_op_def, 2));
+ EXPECT_EQ(-1, OpInputPortIdToArgId(b_node_def, *b_op_def, 3));
EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *b_op_def, 0));
EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *b_op_def, 1));
EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *b_op_def, 2));
@@ -55,7 +62,7 @@ TEST_F(GraphViewTest, OpOutputPortIdToArgIdShapeN) {
EXPECT_EQ(-1, OpOutputPortIdToArgId(b_node_def, *b_op_def, 4));
}
-TEST_F(GraphViewTest, OpOutputPortIdToArgIdSparseSplit) {
+TEST_F(GraphViewTest, OpPortIdToArgIdSparseSplit) {
for (int num_splits : {1, 2}) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output a = ops::Const<int64>(s.WithOpName("a"), 1, {10, 10});
@@ -70,6 +77,13 @@ TEST_F(GraphViewTest, OpOutputPortIdToArgIdSparseSplit) {
EXPECT_TRUE(
OpRegistry::Global()->LookUpOpDef(b_node_def.op(), &b_op_def).ok());
+ // We have 4 inputs.
+ EXPECT_EQ(0, OpInputPortIdToArgId(b_node_def, *b_op_def, 0));
+ EXPECT_EQ(1, OpInputPortIdToArgId(b_node_def, *b_op_def, 1));
+ EXPECT_EQ(2, OpInputPortIdToArgId(b_node_def, *b_op_def, 2));
+ EXPECT_EQ(3, OpInputPortIdToArgId(b_node_def, *b_op_def, 3));
+ EXPECT_EQ(-1, OpInputPortIdToArgId(b_node_def, *b_op_def, 4));
+
for (int port_id = 0; port_id <= num_splits * 3; ++port_id) {
int arg_id = -1;
if (port_id < num_splits * 3) {
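
For ops with several list-typed outputs, the mapping divides evenly across the lists: SparseSplit produces three outputs, each a list of length num_split, so output port p maps to arg p / num_split. A small illustration under the same assumptions as the test above (sparse_split_node and b_op_def stand in for the test's NodeDef/OpDef, with num_splits = 2):

// Ports 0-1 -> output_indices (arg 0), 2-3 -> output_values (arg 1),
// 4-5 -> output_shape (arg 2); port 6 has no corresponding arg.
CHECK_EQ(1, OpOutputPortIdToArgId(sparse_split_node, *b_op_def, 3));  // 3 / 2
CHECK_EQ(-1, OpOutputPortIdToArgId(sparse_split_node, *b_op_def, 6));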
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index 9f0d9dbf28..1b5a215987 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -13,14 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <unordered_set>
-
+#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
@@ -102,6 +101,18 @@ bool IsConjugateTranspose(const NodeDef& node) {
return node.op() == "ConjugateTranspose";
}
+bool IsControlFlow(const NodeDef& node) {
+ // clang-format off
+ return node.op() == "ControlTrigger" ||
+ node.op() == "Enter" ||
+ node.op() == "Exit" ||
+ node.op() == "LoopCond" ||
+ node.op() == "Merge" ||
+ node.op() == "NextIteration" ||
+ node.op() == "Switch";
+ // clang-format on
+}
+
bool IsConv2D(const NodeDef& node) { return node.op() == "Conv2D"; }
bool IsConv2DBackpropFilter(const NodeDef& node) {
@@ -140,26 +151,26 @@ bool IsDiv(const NodeDef& node) { return node.op() == "Div"; }
// e.g. sqrt, exp. *is_non_decreasing is false, the function is non-increasing,
// e.g. inv.
bool IsElementWiseMonotonic(const NodeDef& node, bool* is_non_decreasing) {
- static const std::unordered_set<string>* monotonic_non_decreasing_ops =
- CHECK_NOTNULL((new std::unordered_set<string>{
+ static const gtl::FlatSet<string>* const kMonotonicNonDecreasingOps =
+ CHECK_NOTNULL((new gtl::FlatSet<string>{
"Asinh", "Atanh", "Ceil", "Elu", "Erf", "Exp", "Expm1",
"Floor", "Log", "Log1p", "Relu", "Relu", "Relu6", "Rint",
"Selu", "Sigmoid", "Sign", "Sinh", "Sqrt", "Tanh",
}));
- static const std::unordered_set<string>* monotonic_non_increasing_ops =
- CHECK_NOTNULL((new std::unordered_set<string>{
+ static const gtl::FlatSet<string>* const kMonotonicNonIncreasingOps =
+ CHECK_NOTNULL((new gtl::FlatSet<string>{
"Inv",
"Reciprocal",
"Erfc",
"Rsqrt",
"Neg",
}));
- if (monotonic_non_decreasing_ops->count(node.op()) > 0) {
+ if (kMonotonicNonDecreasingOps->count(node.op()) > 0) {
if (is_non_decreasing) {
*is_non_decreasing = true;
}
return true;
- } else if (monotonic_non_increasing_ops->count(node.op()) > 0) {
+ } else if (kMonotonicNonIncreasingOps->count(node.op()) > 0) {
if (is_non_decreasing) {
*is_non_decreasing = false;
}
@@ -431,6 +442,38 @@ bool IsSymbolicGradient(const NodeDef& node) {
bool IsTanhGrad(const NodeDef& node) { return node.op() == "TanhGrad"; }
+bool IsTensorArray(const NodeDef& node) {
+ static const gtl::FlatSet<string>* const kTensorArrayOps =
+ CHECK_NOTNULL((new gtl::FlatSet<string>{
+ "TensorArray",
+ "TensorArrayV2",
+ "TensorArrayV3",
+ "TensorArrayGrad",
+ "TensorArrayGradV2",
+ "TensorArrayGradV3",
+ "TensorArrayGradWithShape",
+ "TensorArrayWrite",
+ "TensorArrayWriteV2",
+ "TensorArrayWriteV3",
+ "TensorArrayRead",
+ "TensorArrayReadV2",
+ "TensorArrayReadV3",
+ "TensorArrayConcat",
+ "TensorArrayConcatV2",
+ "TensorArrayConcatV3",
+ "TensorArraySplit",
+ "TensorArraySplitV2",
+ "TensorArraySplitV3",
+ "TensorArraySize",
+ "TensorArraySizeV2",
+ "TensorArraySizeV3",
+ "TensorArrayClose",
+ "TensorArrayCloseV2",
+ "TensorArrayCloseV3",
+ }));
+ return kTensorArrayOps->count(node.op()) > 0;
+}
+
bool IsTile(const NodeDef& node) { return node.op() == "Tile"; }
bool IsTranspose(const NodeDef& node) { return node.op() == "Transpose"; }
@@ -542,30 +585,29 @@ OPDEF_PROPERTY_HELPER(Aggregate, aggregate)
OPDEF_PROPERTY_HELPER(Commutative, commutative)
bool IsInvolution(const NodeDef& node) {
- static const std::unordered_set<string>* involution_ops =
- CHECK_NOTNULL((new std::unordered_set<string>{
- "Conj", "Reciprocal", "Invert", "Neg", "LogicalNot"}));
- return involution_ops->count(node.op()) > 0;
+ static const gtl::FlatSet<string>* const kInvolutionOps =
+ CHECK_NOTNULL((new gtl::FlatSet<string>{"Conj", "Reciprocal", "Invert",
+ "Neg", "LogicalNot"}));
+ return kInvolutionOps->count(node.op()) > 0;
}
bool IsValueAndOrderAndShapePreserving(const NodeDef& node) {
if (NumNonControlInputs(node) == 1 && IsAggregate(node)) {
return true;
}
- static const std::unordered_set<string>*
- value_and_order_and_shape_preserving_ops =
- CHECK_NOTNULL((new const std::unordered_set<string>{
- "CheckNumerics",
- "DebugGradientIdentity",
- "DeepCopy"
- "Enter",
- "Exit",
- "PreventGradient",
- "Print",
- "Snapshot",
- "StopGradient",
- }));
- return value_and_order_and_shape_preserving_ops->count(node.op()) > 0 ||
+ static const gtl::FlatSet<string>* const kValueAndOrderAndShapePreservingOps =
+ CHECK_NOTNULL((new const gtl::FlatSet<string>{
+ "CheckNumerics",
+ "DebugGradientIdentity",
+ "DeepCopy"
+ "Enter",
+ "Exit",
+ "PreventGradient",
+ "Print",
+ "Snapshot",
+ "StopGradient",
+ }));
+ return kValueAndOrderAndShapePreservingOps->count(node.op()) > 0 ||
IsIdentity(node);
}
@@ -573,31 +615,31 @@ bool IsValueAndOrderPreserving(const NodeDef& node) {
if (NumNonControlInputs(node) == 1 && IsAggregate(node)) {
return true;
}
- static const std::unordered_set<string>* value_and_order_preserving_ops =
- CHECK_NOTNULL((new const std::unordered_set<string>{
+ static const gtl::FlatSet<string>* const kValueAndOrderPreservingOps =
+ CHECK_NOTNULL((new const gtl::FlatSet<string>{
"ExpandDims",
"Reshape",
"Squeeze",
}));
- return value_and_order_preserving_ops->count(node.op()) > 0 ||
+ return kValueAndOrderPreservingOps->count(node.op()) > 0 ||
IsValueAndOrderAndShapePreserving(node);
}
bool IsValuePreserving(const NodeDef& node) {
- static const std::unordered_set<string>* value_preserving_ops =
- CHECK_NOTNULL((new std::unordered_set<string>{
+ static const gtl::FlatSet<string>* const kValuePreservingOps =
+ CHECK_NOTNULL((new gtl::FlatSet<string>{
"InvertPermutation",
"Reverse",
"Roll",
"Transpose",
}));
return IsValueAndOrderPreserving(node) ||
- value_preserving_ops->count(node.op()) > 0;
+ kValuePreservingOps->count(node.op()) > 0;
}
bool IsUnaryElementWise(const NodeDef& node) {
- static const std::unordered_set<string>* element_wise_ops =
- CHECK_NOTNULL((new std::unordered_set<string>{
+ static const gtl::FlatSet<string>* const kElementWiseOps =
+ CHECK_NOTNULL((new gtl::FlatSet<string>{
"Abs",
"Acos",
"Acosh",
@@ -646,7 +688,7 @@ bool IsUnaryElementWise(const NodeDef& node) {
"Tan"
"Tanh",
}));
- return element_wise_ops->count(node.op()) > 0 ||
+ return kElementWiseOps->count(node.op()) > 0 ||
IsValueAndOrderAndShapePreserving(node);
}
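
The op_types.cc refactor repeats a single pattern: each predicate keys off a function-local gtl::FlatSet<string> that is heap-allocated once and intentionally leaked, which sidesteps static-destruction-order problems and gives cheaper lookups than std::unordered_set. A minimal sketch of the pattern, using hypothetical op names:

bool IsMyFusedOp(const NodeDef& node) {
  // Allocated on first use and never freed, so the set stays valid for the
  // whole process lifetime, including shutdown paths.
  static const gtl::FlatSet<string>* const kMyFusedOps =
      CHECK_NOTNULL((new gtl::FlatSet<string>{"MyFusedOpV1", "MyFusedOpV2"}));
  return kMyFusedOps->count(node.op()) > 0;
}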
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index 7f86a5f295..d4e0159e81 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -46,6 +46,7 @@ bool IsConjugateTranspose(const NodeDef& node);
bool IsConcat(const NodeDef& node);
bool IsConcatOffset(const NodeDef& node);
bool IsConstant(const NodeDef& node);
+bool IsControlFlow(const NodeDef& node);
bool IsConv2D(const NodeDef& node);
bool IsConv2DBackpropFilter(const NodeDef& node);
bool IsConv2DBackpropInput(const NodeDef& node);
@@ -151,6 +152,7 @@ bool IsSum(const NodeDef& node);
bool IsSwitch(const NodeDef& node);
bool IsSymbolicGradient(const NodeDef& node);
bool IsTanhGrad(const NodeDef& node);
+bool IsTensorArray(const NodeDef& node);
bool IsTile(const NodeDef& node);
bool IsTranspose(const NodeDef& node);
bool IsTruncateDiv(const NodeDef& node);
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
index 89eb76046e..8ed4271fa4 100644
--- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
@@ -35,13 +35,44 @@ namespace internal {
// dynamically determined.
constexpr int64 kTensorMaxSize = 64;
-// Find KernelDef for `node`.
-Status TryFindKernelDef(const NodeDef& node, const KernelDef** kdef) {
- // Try find KernelDef for node.device, else GPU or CPU.
- for (const DeviceType& device :
- {node.device().c_str(), DEVICE_GPU, DEVICE_CPU}) {
- Status s = FindKernelDef(device, node, kdef, nullptr);
+// All the nodes that should be blacklisted and not swapped.
+bool IsBlacklisted(const NodeDef& node) {
+ return
+ // Collective ops should not be swapped.
+ IsCollective(node) ||
+ // ControlFlow ops should not be swapped.
+ IsControlFlow(node) ||
+ // NoOp ops should not be swapped (due to group dependencies).
+ IsNoOp(node);
+}
+
+// Checks whether the tensor is integer-typed (int32/int64) and small.
+bool IsTensorIntegerAndSmall(const OpInfo::TensorProperties& prop) {
+ // Check type to be int32 or int64.
+ if (prop.dtype() != DataType::DT_INT32 &&
+ prop.dtype() != DataType::DT_INT64) {
+ return false;
+ }
+
+ // Check size known and small.
+ const int64 size = NumCoefficients(prop.shape());
+ if (size < 0 || size > kTensorMaxSize) {
+ return false;
+ }
+
+ return true;
+}
+
+// Finds a KernelDef for `node`, returning the first match found across `devices`.
+Status TryFindKernelDef(const std::vector<DeviceType>& devices,
+ const NodeDef& node, const KernelDef** kdef) {
+ for (const DeviceType& device : devices) {
+ const KernelDef* kernel = nullptr;
+ Status s = FindKernelDef(device, node, &kernel, nullptr);
if (s.ok()) {
+ if (kdef) {
+ *kdef = kernel;
+ }
return Status::OK();
}
}
@@ -49,88 +80,183 @@ Status TryFindKernelDef(const NodeDef& node, const KernelDef** kdef) {
return errors::NotFound("Could not find KernelDef for op: ", node.op());
}
-// Check if all node's inputs are pinned to CPU memory.
-bool AreAllNodeInputsPinnedToHost(const GraphView& graph, const NodeDef& node) {
- // Loop through all the inputs excluding the controlling nodes.
- for (const GraphView::OutputPort& fanin : graph.GetFanins(node, false)) {
- // Check if (the fanin) op's device is on CPU.
- if (str_util::StrContains(fanin.node->device(), DEVICE_CPU)) {
- continue;
- }
-
- // Check if (the fanin) op's output port is pinned to HostMemory.
- const OpDef* fanin_odef = nullptr;
- Status s = OpRegistry::Global()->LookUpOpDef(fanin.node->op(), &fanin_odef);
- if (!s.ok()) {
- LOG(INFO) << "Could not find OpDef for : " << fanin.node->op();
- return false;
- }
+// Checks if a node's output port is host friendly.
+// Roughly this means checking if the output port is on Host memory.
+Status IsNodeOutputPortHostFriendly(const GraphView& graph,
+ GraphProperties* properties,
+ const NodeDef& node, int port_id,
+ bool* is_candidate) {
+ *is_candidate = false;
- const int output_arg_id =
- OpOutputPortIdToArgId(*fanin.node, *fanin_odef, fanin.port_id);
- if (output_arg_id < 0) {
- LOG(WARNING) << "Invalid port: " << fanin.port_id << "!\n"
- << node.DebugString() << "\n"
- << fanin.node->DebugString() << "\n"
- << fanin_odef->DebugString();
- return false;
- }
+ // Make sure we are not a blacklisted op.
+ if (IsBlacklisted(node)) {
+ return Status::OK();
+ }
- const KernelDef* fanin_kdef = nullptr;
- s = TryFindKernelDef(*fanin.node, &fanin_kdef);
- if (!s.ok()) {
- LOG(INFO) << "Could not find KernelDef for : " << fanin.node->op();
- return false;
- }
+ // Check to make sure we have the right properties (i.e., statically shaped).
+ if (!properties->has_properties()) {
+ // This is an expensive call, call it lazily.
+ TF_RETURN_IF_ERROR(properties->InferStatically(
+ /*assume_valid_feeds=*/false));
+ }
+ const auto& output_properties = properties->GetOutputProperties(node.name());
+ if (port_id >= output_properties.size()) {
+ LOG(WARNING) << "port_id=" << port_id
+ << " but output_properties.size()=" << output_properties.size()
+ << "\n"
+ << node.DebugString();
+ return Status::OK();
+ }
+ if (!IsTensorIntegerAndSmall(output_properties[port_id])) {
+ return Status::OK();
+ }
- bool fanin_pinned = false;
- for (const string& host_memory_arg : fanin_kdef->host_memory_arg()) {
- if (fanin_odef->output_arg(output_arg_id).name() == host_memory_arg) {
- fanin_pinned = true;
- break;
+ // These nodes may be optimized away downstream (even if pinned to Host), so
+ // we should (recursively) check their source.
+ if (IsIdentity(node)) {
+ for (const auto& fanin : graph.GetFanins(node, false)) {
+ bool fanin_candidate = false;
+ TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly(
+ graph, properties, *fanin.node, fanin.port_id, &fanin_candidate));
+ if (!fanin_candidate) {
+ return Status::OK();
}
}
+ *is_candidate = true;
+ return Status::OK();
+ }
- if (!fanin_pinned) {
- return false;
+ // Check if op's device is on CPU.
+ if (str_util::StrContains(node.device(), DEVICE_CPU)) {
+ *is_candidate = true;
+ return Status::OK();
+ }
+
+ // Check if op's output port is pinned to HostMemory.
+ const OpDef* op = nullptr;
+ Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op);
+ if (!s.ok()) {
+ LOG(WARNING) << "Could not find OpDef for : " << node.op();
+ return Status::OK();
+ }
+
+ // Map the port_id to output_arg_id.
+ const int output_arg_id = OpOutputPortIdToArgId(node, *op, port_id);
+ if (output_arg_id < 0) {
+ LOG(WARNING) << "Invalid port: " << port_id << "!\n"
+ << node.DebugString() << "\n"
+ << op->DebugString();
+ return Status::OK();
+ }
+
+ // Find the kernel.
+ const KernelDef* kernel = nullptr;
+ s = TryFindKernelDef({node.device().c_str(), DEVICE_GPU, DEVICE_CPU}, node,
+ &kernel);
+ if (!s.ok()) {
+ LOG(INFO) << "Could not find KernelDef for: " << node.op();
+ return Status::OK();
+ }
+
+ // Check if the output_arg is pinned to Host.
+ for (const string& host_memory_arg : kernel->host_memory_arg()) {
+ if (op->output_arg(output_arg_id).name() == host_memory_arg) {
+ *is_candidate = true;
+ break;
}
}
- return true;
+ return Status::OK();
}
-bool IsTensorIntegerAndSmall(const OpInfo::TensorProperties& prop) {
- // Check if Tensor is integer and small size.
+// Checks if a node's input port is Host friendly.
+// Roughly this means checking if the input port is on Host memory.
+bool IsNodeInputPortHostFriendly(const NodeDef& node, int port_id) {
+ // If node is on Host, assume its inputs are Host friendly.
+ if (str_util::StrContains(node.device(), DEVICE_CPU)) {
+ return true;
+ }
- // Check type to be int32 or int64.
- if (prop.dtype() != DataType::DT_INT32 &&
- prop.dtype() != DataType::DT_INT64) {
+ // Check if op's input port is pinned to HostMemory.
+ const OpDef* op = nullptr;
+ Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op);
+ if (!s.ok()) {
+ LOG(WARNING) << "Could not find OpDef for : " << node.op();
return false;
}
-
- // Check size known and small.
- const int64 size = NumCoefficients(prop.shape());
- if (size < 0 || size > kTensorMaxSize) {
+ const int input_arg_id = OpInputPortIdToArgId(node, *op, port_id);
+
+ // Find the kernel.
+ const KernelDef* kernel = nullptr;
+ s = internal::TryFindKernelDef(
+ {node.device().c_str(), DEVICE_GPU, DEVICE_CPU}, node, &kernel);
+ if (!s.ok()) {
+ LOG(INFO) << "Could not find KernelDef for: " << node.op();
return false;
}
- return true;
+ // Check if the input_arg is pinned to Host.
+ for (const string& host_memory_arg : kernel->host_memory_arg()) {
+ if (op->input_arg(input_arg_id).name() == host_memory_arg) {
+ return true;
+ }
+ }
+
+ return false;
}
-bool AreAllNodeInputsAndOutputsIntsAndSmall(const GraphProperties& properties,
- const NodeDef& node) {
- for (const auto& prop : properties.GetInputProperties(node.name())) {
- if (!IsTensorIntegerAndSmall(prop)) {
- return false;
+// Checks if a node is a candidate to pin to Host.
+// The rough algorithm is as follows:
+// 1] Check if node is blacklisted.
+// 2] Check if node can run on Host.
+// 3] Check that all inputs/outputs are Host "friendly" (currently, friendly
+// means a small int32/int64 tensor that is pinned to Host).
+Status IsNodeHostCandidate(const GraphView& graph, GraphProperties* properties,
+ const NodeDef& node, bool* is_candidate) {
+ *is_candidate = false;
+
+ // Check if node already on CPU.
+ if (str_util::StrContains(node.device(), DEVICE_CPU)) {
+ *is_candidate = true;
+ return Status::OK();
+ }
+
+ // Skip these node types.
+ if (IsBlacklisted(node)) {
+ return Status::OK();
+ }
+
+ // Check the node can be run on CPU.
+ Status s = TryFindKernelDef({DEVICE_CPU}, node, nullptr);
+ if (!s.ok()) {
+ return Status::OK();
+ }
+
+ // Check all inputs are Host friendly.
+ for (const GraphView::OutputPort& fanin :
+ graph.GetFanins(node, /*include_controlling_nodes=*/false)) {
+ bool fanin_candidate = false;
+ TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly(
+ graph, properties, *fanin.node, fanin.port_id, &fanin_candidate));
+ if (!fanin_candidate) {
+ return Status::OK();
}
}
- for (const auto& prop : properties.GetOutputProperties(node.name())) {
+ // Check all outputs are Host friendly.
+ if (!properties->has_properties()) {
+ // This is an expensive call, call it lazily.
+ TF_RETURN_IF_ERROR(properties->InferStatically(
+ /*assume_valid_feeds=*/false));
+ }
+ for (const auto& prop : properties->GetOutputProperties(node.name())) {
if (!IsTensorIntegerAndSmall(prop)) {
- return false;
+ return Status::OK();
}
}
- return true;
+
+ *is_candidate = true;
+ return Status::OK();
}
string TryFindHostDevice(const gtl::FlatSet<string>& devices,
@@ -167,15 +293,6 @@ bool IsTPUGraphDef(const GraphDef& def) {
}
return false;
}
-
-// All the nodes that should be blacklisted and not swapped.
-bool IsBlacklisted(const NodeDef& node) {
- return
- // Collective ops should not be swapped.
- IsCollective(node) ||
- // NoOp breaks perf regression tests (probably due to group dependencies).
- IsNoOp(node);
-}
} // end namespace internal
Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
@@ -188,7 +305,6 @@ Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
}
GraphProperties properties(item);
- bool has_properties = false;
GraphView graph(optimized_graph);
gtl::FlatSet<string> devices;
@@ -209,35 +325,10 @@ Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
std::vector<std::pair<NodeDef*, string>> const_nodes;
for (auto& node : *optimized_graph->mutable_node()) {
- // Check if node already on CPU.
- if (str_util::StrContains(node.device(), DEVICE_CPU)) {
- continue;
- }
-
- // Skip these node types.
- if (internal::IsBlacklisted(node)) {
- continue;
- }
-
- // Check the node can be run on CPU.
- Status s = FindKernelDef(DEVICE_CPU, node, nullptr, nullptr);
- if (!s.ok()) {
- continue;
- }
-
- // Check all input's are pinned to CPU.
- if (!internal::AreAllNodeInputsPinnedToHost(graph, node)) {
- continue;
- }
-
- if (!has_properties) {
- // This is an expensive call, call it lazily.
- TF_RETURN_IF_ERROR(properties.InferStatically(false));
- has_properties = true;
- }
-
- // Check all inputs and outputs are integers and small.
- if (!internal::AreAllNodeInputsAndOutputsIntsAndSmall(properties, node)) {
+ bool is_candidate = false;
+ TF_RETURN_IF_ERROR(
+ internal::IsNodeHostCandidate(graph, &properties, node, &is_candidate));
+ if (!is_candidate) {
continue;
}
@@ -254,10 +345,12 @@ Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
NodeDef* node = it.first;
const string& device = it.second;
- // Check all the consumers of this node, if any of them are on the original
- // device, swap this node back onto the original device.
+ // Check all the consumers of this node; if any of them are not Host
+ // friendly, swap this node back onto the original device.
for (const GraphView::InputPort& fanout : graph.GetFanouts(*node, false)) {
- if (fanout.node->device() == device) {
+ // This consumer is not Host friendly, so swap this node back to the original device.
+ if (!internal::IsNodeInputPortHostFriendly(*fanout.node,
+ fanout.port_id)) {
node->set_device(device);
break;
}
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
index 173cb3fe3c..7c64529441 100644
--- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
@@ -160,6 +160,48 @@ TEST_F(PinToHostOptimizerTest, NoSwap) {
EXPECT_EQ(found, 3);
}
+TEST_F(PinToHostOptimizerTest, Identity) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ // `a` and `c` are on GPU, `d` is on CPU; consequently `e` should not be swapped.
+ // `b` should be placed onto Host since `c` pins the input to Host memory.
+ Output a =
+ ops::Const(s.WithOpName("a").WithDevice("/device:GPU:0"), 1, {64, 64});
+ Output b = ops::Const(s.WithOpName("b"), {0, 1}, {2});
+ Output c =
+ ops::ReduceProd(s.WithOpName("c").WithDevice("/device:GPU:0"), a, b);
+ Output d = ops::Identity(s.WithDevice("/device:CPU:0").WithOpName("d"), c);
+ Output e = ops::Multiply(s.WithOpName("e"), d, d);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ GraphDef output;
+ PinToHostOptimizer optimizer(RewriterConfig::ON);
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ int found = 0;
+ for (const NodeDef& node : output.node()) {
+ if (node.name() == "a" || node.name() == "c") {
+ EXPECT_EQ(node.device(), "/device:GPU:0");
+ } else if (node.name() == "b") {
+ // If built with CUDA, then there is a GPU kernel registration that is
+ // pinned to Host memory. Consequently, `b` will be mapped to Host
+ // correctly if there is a GPU kernel registered.
+#if GOOGLE_CUDA
+ EXPECT_EQ(node.device(), "/device:CPU:0");
+#else
+ EXPECT_TRUE(node.device().empty());
+#endif
+ } else if (node.name() == "d") {
+ EXPECT_EQ(node.device(), "/device:CPU:0");
+ } else if (node.name() == "e") {
+ EXPECT_TRUE(node.device().empty());
+ }
+ ++found;
+ }
+ EXPECT_EQ(found, 5);
+}
+
TEST_F(PinToHostOptimizerTest, PortIdToArgId) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output a = ops::Const(s.WithOpName("a"), 1, {1, 2, 3});