aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler')
-rw-r--r--tensorflow/core/grappler/costs/graph_properties.h4
-rw-r--r--tensorflow/core/grappler/costs/graph_properties_test.cc6
-rw-r--r--tensorflow/core/grappler/costs/graph_properties_testdata/function_functional_while.pbtxt7
-rw-r--r--tensorflow/core/grappler/graph_view.cc35
-rw-r--r--tensorflow/core/grappler/graph_view.h3
-rw-r--r--tensorflow/core/grappler/graph_view_test.cc22
-rw-r--r--tensorflow/core/grappler/grappler_item.cc1
-rw-r--r--tensorflow/core/grappler/grappler_item.h9
-rw-r--r--tensorflow/core/grappler/grappler_item_builder.cc8
-rw-r--r--tensorflow/core/grappler/grappler_item_builder.h2
-rw-r--r--tensorflow/core/grappler/grappler_item_builder_test.cc23
-rw-r--r--tensorflow/core/grappler/op_types.cc122
-rw-r--r--tensorflow/core/grappler/op_types.h3
-rw-r--r--tensorflow/core/grappler/optimizers/BUILD2
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc16
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.cc47
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding_test.cc130
-rw-r--r--tensorflow/core/grappler/optimizers/data/BUILD78
-rw-r--r--tensorflow/core/grappler/optimizers/data/filter_fusion.cc13
-rw-r--r--tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc11
-rw-r--r--tensorflow/core/grappler/optimizers/data/function_utils.cc32
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_test_utils.cc49
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_test_utils.h36
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils.cc45
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils.h38
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils_test.cc40
-rw-r--r--tensorflow/core/grappler/optimizers/data/hoist_random_uniform.cc289
-rw-r--r--tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h55
-rw-r--r--tensorflow/core/grappler/optimizers/data/hoist_random_uniform_test.cc84
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc5
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc14
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc21
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_fusion.cc30
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_fusion_test.cc10
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_parallelization.cc3
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc13
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_vectorization.cc44
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc112
-rw-r--r--tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc2
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/BUILD13
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc33
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc40
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h24
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc2
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h15
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc19
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h44
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization_utils.cc662
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization_utils.h35
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc459
-rw-r--r--tensorflow/core/grappler/optimizers/meta_optimizer.cc53
-rw-r--r--tensorflow/core/grappler/optimizers/meta_optimizer.h4
-rw-r--r--tensorflow/core/grappler/optimizers/meta_optimizer_test.cc156
-rw-r--r--tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc401
-rw-r--r--tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h4
-rw-r--r--tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc118
-rw-r--r--tensorflow/core/grappler/optimizers/remapper.cc8
-rw-r--r--tensorflow/core/grappler/optimizers/shape_optimizer.cc12
-rw-r--r--tensorflow/core/grappler/utils.cc39
-rw-r--r--tensorflow/core/grappler/utils.h110
-rw-r--r--tensorflow/core/grappler/utils/functions.cc55
-rw-r--r--tensorflow/core/grappler/utils/functions.h5
-rw-r--r--tensorflow/core/grappler/utils_test.cc19
63 files changed, 2911 insertions, 883 deletions
diff --git a/tensorflow/core/grappler/costs/graph_properties.h b/tensorflow/core/grappler/costs/graph_properties.h
index f716cd72c9..28fd7565cc 100644
--- a/tensorflow/core/grappler/costs/graph_properties.h
+++ b/tensorflow/core/grappler/costs/graph_properties.h
@@ -74,6 +74,10 @@ class GraphProperties {
// shape information.
void ClearInputProperties(const string& node_name);
void ClearOutputProperties(const string& node_name);
+ // Returns true if we have *any* properties.
+ bool has_properties() const {
+ return input_properties_.size() > 0 || output_properties_.size() > 0;
+ }
private:
// Relaxes shapes <shapes_and_types>, determined from an EnqueueV2 node, into
diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc
index 362092a6cf..db10f586bc 100644
--- a/tensorflow/core/grappler/costs/graph_properties_test.cc
+++ b/tensorflow/core/grappler/costs/graph_properties_test.cc
@@ -1340,6 +1340,8 @@ TEST_F(GraphPropertiesTest, SymbolicShapes) {
Output zero = ops::Const(s.WithOpName("zero"), 0.0f, {});
Output g = ops::Shape(s.WithOpName("g"), c);
Output h = ops::Fill(s.WithOpName("h"), g, zero);
+ Output zero_idx = ops::Const(s.WithOpName("zero_idx"), {0}, {1});
+ Output j = ops::Sum(s.WithOpName("j"), a, zero_idx);
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
@@ -1382,6 +1384,10 @@ TEST_F(GraphPropertiesTest, SymbolicShapes) {
ASSERT_EQ(2, shape_f.dim_size());
EXPECT_EQ(shape_h.dim(0).size(), shape_c.dim(0).size());
EXPECT_EQ(shape_h.dim(1).size(), shape_c.dim(1).size());
+
+ const auto shape_j = properties.GetOutputProperties("j").at(0).shape();
+ ASSERT_EQ(1, shape_j.dim_size());
+ EXPECT_EQ(shape_j.dim(0).size(), shape_a.dim(1).size());
}
TEST_F(GraphPropertiesTest, DoNotValidateColocationConstraints) {
diff --git a/tensorflow/core/grappler/costs/graph_properties_testdata/function_functional_while.pbtxt b/tensorflow/core/grappler/costs/graph_properties_testdata/function_functional_while.pbtxt
index c94ee2f227..0ec95dd684 100644
--- a/tensorflow/core/grappler/costs/graph_properties_testdata/function_functional_while.pbtxt
+++ b/tensorflow/core/grappler/costs/graph_properties_testdata/function_functional_while.pbtxt
@@ -88,6 +88,13 @@ library {
}
}
}
+ attr {
+ key: "output_shapes"
+ value {
+ list {
+ }
+ }
+ }
}
ret {
key: "while"
diff --git a/tensorflow/core/grappler/graph_view.cc b/tensorflow/core/grappler/graph_view.cc
index 2619a9a8f3..de0a63fc4e 100644
--- a/tensorflow/core/grappler/graph_view.cc
+++ b/tensorflow/core/grappler/graph_view.cc
@@ -20,23 +20,25 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
-int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id) {
- for (int output_arg_id = 0; output_arg_id < op.output_arg_size();
- ++output_arg_id) {
+namespace {
+int OpPortIdToArgId(const NodeDef& node,
+ const protobuf::RepeatedPtrField<OpDef::ArgDef>& args,
+ int port_id) {
+ for (int arg_id = 0; arg_id < args.size(); ++arg_id) {
if (port_id < 0) {
return -1;
} else if (port_id == 0) {
- return output_arg_id;
+ return arg_id;
}
- // Default is 1 port per output arg.
+ // Default is 1 port per arg.
int n = 1;
- const auto& output_arg = op.output_arg(output_arg_id);
- if (!output_arg.number_attr().empty()) {
- n = node.attr().at(output_arg.number_attr()).i();
- } else if (!output_arg.type_list_attr().empty()) {
- n = node.attr().at(output_arg.type_list_attr()).list().type_size();
+ const auto& arg = args.Get(arg_id);
+ if (!arg.number_attr().empty()) {
+ n = node.attr().at(arg.number_attr()).i();
+ } else if (!arg.type_list_attr().empty()) {
+ n = node.attr().at(arg.type_list_attr()).list().type_size();
}
if (n < 0) {
@@ -44,13 +46,22 @@ int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id) {
DCHECK_GE(n, 0);
return -1;
} else if (port_id < n) {
- return output_arg_id;
+ return arg_id;
}
port_id -= n;
}
return -1;
}
+} // end namespace
+
+int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id) {
+ return OpPortIdToArgId(node, op.output_arg(), port_id);
+}
+
+int OpInputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id) {
+ return OpPortIdToArgId(node, op.input_arg(), port_id);
+}
GraphView::GraphView(GraphDef* graph) : graph_(graph) {
for (int i = 0; i < graph_->node_size(); i++) {
@@ -72,7 +83,7 @@ void GraphView::AddUniqueNodeOrDie(NodeDef* node) {
void GraphView::AddFanouts(NodeDef* node) {
for (int i = 0; i < node->input_size(); ++i) {
OutputPort fanin;
- string fanin_name = ParseNodeName(node->input(i), &fanin.port_id);
+ const string fanin_name = ParseNodeName(node->input(i), &fanin.port_id);
fanin.node = nodes_[fanin_name];
InputPort input;
diff --git a/tensorflow/core/grappler/graph_view.h b/tensorflow/core/grappler/graph_view.h
index ec946ca3b5..09c36a1368 100644
--- a/tensorflow/core/grappler/graph_view.h
+++ b/tensorflow/core/grappler/graph_view.h
@@ -26,7 +26,7 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
-// Map a node/op's output port_id to arg_id.
+// Map a node/op's input/output port_id to arg_id.
//
// The port_id refers to the n-th tensor of the node, while the arg_id refers to
// the n-th arg of the op. These two can be different if an op's arg is a list
@@ -34,6 +34,7 @@ namespace grappler {
//
// We return -1 for any invalid port_id (i.e., no corresponding arg_id).
int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id);
+int OpInputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id);
// A utility class to simplify the traversal of a GraphDef.
class GraphView {
diff --git a/tensorflow/core/grappler/graph_view_test.cc b/tensorflow/core/grappler/graph_view_test.cc
index 3d7d2faf7c..f90e2c8cfc 100644
--- a/tensorflow/core/grappler/graph_view_test.cc
+++ b/tensorflow/core/grappler/graph_view_test.cc
@@ -26,7 +26,7 @@ namespace {
class GraphViewTest : public ::testing::Test {};
-TEST_F(GraphViewTest, OpOutputPortIdToArgIdShapeN) {
+TEST_F(GraphViewTest, OpPortIdToArgIdShapeN) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
ops::ShapeN b(s.WithOpName("b"), {a, a, a});
@@ -45,9 +45,16 @@ TEST_F(GraphViewTest, OpOutputPortIdToArgIdShapeN) {
EXPECT_TRUE(
OpRegistry::Global()->LookUpOpDef(b_node_def.op(), &b_op_def).ok());
- EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *a_op_def, 0));
- EXPECT_EQ(-1, OpOutputPortIdToArgId(b_node_def, *a_op_def, 1));
+ // Const has 0 inputs, 1 output.
+ EXPECT_EQ(-1, OpInputPortIdToArgId(a_node_def, *a_op_def, 0));
+ EXPECT_EQ(0, OpOutputPortIdToArgId(a_node_def, *a_op_def, 0));
+ EXPECT_EQ(-1, OpOutputPortIdToArgId(a_node_def, *a_op_def, 1));
+ // ShapeN has N=3 inputs and outputs.
+ EXPECT_EQ(0, OpInputPortIdToArgId(b_node_def, *b_op_def, 0));
+ EXPECT_EQ(0, OpInputPortIdToArgId(b_node_def, *b_op_def, 1));
+ EXPECT_EQ(0, OpInputPortIdToArgId(b_node_def, *b_op_def, 2));
+ EXPECT_EQ(-1, OpInputPortIdToArgId(b_node_def, *b_op_def, 3));
EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *b_op_def, 0));
EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *b_op_def, 1));
EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *b_op_def, 2));
@@ -55,7 +62,7 @@ TEST_F(GraphViewTest, OpOutputPortIdToArgIdShapeN) {
EXPECT_EQ(-1, OpOutputPortIdToArgId(b_node_def, *b_op_def, 4));
}
-TEST_F(GraphViewTest, OpOutputPortIdToArgIdSparseSplit) {
+TEST_F(GraphViewTest, OpPortIdToArgIdSparseSplit) {
for (int num_splits : {1, 2}) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output a = ops::Const<int64>(s.WithOpName("a"), 1, {10, 10});
@@ -70,6 +77,13 @@ TEST_F(GraphViewTest, OpOutputPortIdToArgIdSparseSplit) {
EXPECT_TRUE(
OpRegistry::Global()->LookUpOpDef(b_node_def.op(), &b_op_def).ok());
+ // We have 4 inputs.
+ EXPECT_EQ(0, OpInputPortIdToArgId(b_node_def, *b_op_def, 0));
+ EXPECT_EQ(1, OpInputPortIdToArgId(b_node_def, *b_op_def, 1));
+ EXPECT_EQ(2, OpInputPortIdToArgId(b_node_def, *b_op_def, 2));
+ EXPECT_EQ(3, OpInputPortIdToArgId(b_node_def, *b_op_def, 3));
+ EXPECT_EQ(-1, OpInputPortIdToArgId(b_node_def, *b_op_def, 4));
+
for (int port_id = 0; port_id <= num_splits * 3; ++port_id) {
int arg_id = -1;
if (port_id < num_splits * 3) {
diff --git a/tensorflow/core/grappler/grappler_item.cc b/tensorflow/core/grappler/grappler_item.cc
index bbc0fedd22..2c490f3966 100644
--- a/tensorflow/core/grappler/grappler_item.cc
+++ b/tensorflow/core/grappler/grappler_item.cc
@@ -38,6 +38,7 @@ GrapplerItem::GrapplerItem(const GrapplerItem& other, GraphDef* graph_def) {
restore_op = other.restore_op;
save_restore_loc_tensor = other.save_restore_loc_tensor;
queue_runners = other.queue_runners;
+ allowed_optimizations = other.allowed_optimizations;
graph.Swap(graph_def);
}
diff --git a/tensorflow/core/grappler/grappler_item.h b/tensorflow/core/grappler/grappler_item.h
index 939e5fa046..a0748abfe6 100644
--- a/tensorflow/core/grappler/grappler_item.h
+++ b/tensorflow/core/grappler/grappler_item.h
@@ -77,6 +77,15 @@ struct GrapplerItem {
// Return a set of node names that must be preserved. This includes feed and
// fetch nodes, keep_ops, init_ops.
std::unordered_set<string> NodesToPreserve() const;
+
+ // Restrict types of optimizations that are allowed for this GrapplerItem.
+ struct AllowedOptimizations {
+ // Is it allowed to add nodes to the graph that do not have registered
+ // gradient function.
+ bool non_differentiable_rewrites = true;
+ };
+
+ AllowedOptimizations allowed_optimizations;
};
// Return the transitive fanin of a set of terminal nodes.
diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc
index 029515ad3c..369046666d 100644
--- a/tensorflow/core/grappler/grappler_item_builder.cc
+++ b/tensorflow/core/grappler/grappler_item_builder.cc
@@ -192,9 +192,13 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
const string feed_name = NodeName(feed_node);
new_item->feed.emplace_back(feed_name, Tensor());
}
+ for (const auto& fetch_node : cfg.fetch_nodes) {
+ new_item->fetch.emplace_back(NodeName(fetch_node));
+ }
- // Attempt to detect the fetch node(s).
- if (meta_graph.collection_def().count("train_op") > 0) {
+ // Attempt to detect the fetch node(s) if they were not set explicitly.
+ if (new_item->fetch.empty() &&
+ meta_graph.collection_def().count("train_op") > 0) {
const CollectionDef& nodes = meta_graph.collection_def().at("train_op");
if (nodes.has_node_list()) {
for (const auto& node : nodes.node_list().value()) {
diff --git a/tensorflow/core/grappler/grappler_item_builder.h b/tensorflow/core/grappler/grappler_item_builder.h
index aafd2fdcda..1698587f8c 100644
--- a/tensorflow/core/grappler/grappler_item_builder.h
+++ b/tensorflow/core/grappler/grappler_item_builder.h
@@ -49,6 +49,8 @@ struct ItemConfig {
bool prune_graph = false;
// Override feed nodes list.
std::set<string> feed_nodes;
+ // Override fetch nodes list.
+ std::set<string> fetch_nodes;
};
// Factory method for creating a GrapplerItem from a MetaGraphDef.
diff --git a/tensorflow/core/grappler/grappler_item_builder_test.cc b/tensorflow/core/grappler/grappler_item_builder_test.cc
index 4b90bf3038..d00981f174 100644
--- a/tensorflow/core/grappler/grappler_item_builder_test.cc
+++ b/tensorflow/core/grappler/grappler_item_builder_test.cc
@@ -313,6 +313,29 @@ TEST_F(GrapplerItemBuilderTest, FromGraphWithUnknownDimInSignatureInput) {
EXPECT_EQ(item2->feed[0].second.NumElements(), 1);
}
+TEST_F(GrapplerItemBuilderTest, ExplicitFeedAndFetch) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto x = ops::Const(s.WithOpName("x"), 0);
+ auto y = ops::Const(s.WithOpName("y"), 1);
+ auto z = ops::Add(s.WithOpName("z"), x, y);
+
+ MetaGraphDef meta_graph;
+ TF_CHECK_OK(s.ToGraphDef(meta_graph.mutable_graph_def()));
+
+ ItemConfig config;
+ config.feed_nodes.insert("x");
+ config.fetch_nodes.insert("z");
+
+ std::unique_ptr<GrapplerItem> item =
+ GrapplerItemFromMetaGraphDef("0", meta_graph, config);
+ ASSERT_TRUE(item != nullptr);
+
+ EXPECT_EQ(item->feed.size(), 1);
+ EXPECT_EQ(item->fetch.size(), 1);
+ EXPECT_EQ(item->feed[0].first, "x");
+ EXPECT_EQ(item->fetch[0], "z");
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index 3521669b63..cbf5c8e038 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -13,14 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <unordered_set>
-
+#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
@@ -102,6 +101,22 @@ bool IsConjugateTranspose(const NodeDef& node) {
return node.op() == "ConjugateTranspose";
}
+bool IsControlFlow(const NodeDef& node) {
+ // TODO(williamchan): Add a microbenchmark to compare FlatSet vs. iterative
+ // string comparison.
+ static const gtl::FlatSet<string>* const kControFlowOps =
+ CHECK_NOTNULL((new gtl::FlatSet<string>{
+ "ControlTrigger",
+ "Enter",
+ "Exit",
+ "LoopCond",
+ "Merge",
+ "NextIteration",
+ "Switch",
+ }));
+ return kControFlowOps->count(node.op()) > 0;
+}
+
bool IsConv2D(const NodeDef& node) { return node.op() == "Conv2D"; }
bool IsConv2DBackpropFilter(const NodeDef& node) {
@@ -140,26 +155,26 @@ bool IsDiv(const NodeDef& node) { return node.op() == "Div"; }
// e.g. sqrt, exp. *is_non_decreasing is false, the function is non-increasing,
// e.g. inv.
bool IsElementWiseMonotonic(const NodeDef& node, bool* is_non_decreasing) {
- static const std::unordered_set<string>* monotonic_non_decreasing_ops =
- CHECK_NOTNULL((new std::unordered_set<string>{
+ static const gtl::FlatSet<string>* const kMonotonicNonDecreasingOps =
+ CHECK_NOTNULL((new gtl::FlatSet<string>{
"Asinh", "Atanh", "Ceil", "Elu", "Erf", "Exp", "Expm1",
"Floor", "Log", "Log1p", "Relu", "Relu", "Relu6", "Rint",
"Selu", "Sigmoid", "Sign", "Sinh", "Sqrt", "Tanh",
}));
- static const std::unordered_set<string>* monotonic_non_increasing_ops =
- CHECK_NOTNULL((new std::unordered_set<string>{
+ static const gtl::FlatSet<string>* const kMonotonicNonIncreasingOps =
+ CHECK_NOTNULL((new gtl::FlatSet<string>{
"Inv",
"Reciprocal",
"Erfc",
"Rsqrt",
"Neg",
}));
- if (monotonic_non_decreasing_ops->count(node.op()) > 0) {
+ if (kMonotonicNonDecreasingOps->count(node.op()) > 0) {
if (is_non_decreasing) {
*is_non_decreasing = true;
}
return true;
- } else if (monotonic_non_increasing_ops->count(node.op()) > 0) {
+ } else if (kMonotonicNonIncreasingOps->count(node.op()) > 0) {
if (is_non_decreasing) {
*is_non_decreasing = false;
}
@@ -425,8 +440,44 @@ bool IsSwitch(const NodeDef& node) {
return op == "Switch" || op == "RefSwitch";
}
+bool IsSymbolicGradient(const NodeDef& node) {
+ return node.op() == "SymbolicGradient";
+}
+
bool IsTanhGrad(const NodeDef& node) { return node.op() == "TanhGrad"; }
+bool IsTensorArray(const NodeDef& node) {
+ static const gtl::FlatSet<string>* const kTensorArrayOps =
+ CHECK_NOTNULL((new gtl::FlatSet<string>{
+ "TensorArray",
+ "TensorArrayV2",
+ "TensorArrayV3",
+ "TensorArrayGrad",
+ "TensorArrayGradV2",
+ "TensorArrayGradV3",
+ "TensorArrayGradWithShape",
+ "TensorArrayWrite",
+ "TensorArrayWriteV2",
+ "TensorArrayWriteV3",
+ "TensorArrayRead",
+ "TensorArrayReadV2",
+ "TensorArrayReadV3",
+ "TensorArrayConcat",
+ "TensorArrayConcatV2",
+ "TensorArrayConcatV3",
+ "TensorArraySplit",
+ "TensorArraySplitV2",
+ "TensorArraySplitV3",
+ "TensorArraySize",
+ "TensorArraySizeV2",
+ "TensorArraySizeV3",
+ "TensorArrayClose",
+ "TensorArrayCloseV2",
+ "TensorArrayCloseV3",
+ }));
+ return kTensorArrayOps->count(node.op()) > 0;
+}
+
bool IsTile(const NodeDef& node) { return node.op() == "Tile"; }
bool IsTranspose(const NodeDef& node) { return node.op() == "Transpose"; }
@@ -538,30 +589,29 @@ OPDEF_PROPERTY_HELPER(Aggregate, aggregate)
OPDEF_PROPERTY_HELPER(Commutative, commutative)
bool IsInvolution(const NodeDef& node) {
- static const std::unordered_set<string>* involution_ops =
- CHECK_NOTNULL((new std::unordered_set<string>{
- "Conj", "Reciprocal", "Invert", "Neg", "LogicalNot"}));
- return involution_ops->count(node.op()) > 0;
+ static const gtl::FlatSet<string>* const kInvolutionOps =
+ CHECK_NOTNULL((new gtl::FlatSet<string>{"Conj", "Reciprocal", "Invert",
+ "Neg", "LogicalNot"}));
+ return kInvolutionOps->count(node.op()) > 0;
}
bool IsValueAndOrderAndShapePreserving(const NodeDef& node) {
if (NumNonControlInputs(node) == 1 && IsAggregate(node)) {
return true;
}
- static const std::unordered_set<string>*
- value_and_order_and_shape_preserving_ops =
- CHECK_NOTNULL((new const std::unordered_set<string>{
- "CheckNumerics",
- "DebugGradientIdentity",
- "DeepCopy"
- "Enter",
- "Exit",
- "PreventGradient",
- "Print",
- "Snapshot",
- "StopGradient",
- }));
- return value_and_order_and_shape_preserving_ops->count(node.op()) > 0 ||
+ static const gtl::FlatSet<string>* const kValueAndOrderAndShapePreservingOps =
+ CHECK_NOTNULL((new const gtl::FlatSet<string>{
+ "CheckNumerics",
+ "DebugGradientIdentity",
+ "DeepCopy"
+ "Enter",
+ "Exit",
+ "PreventGradient",
+ "Print",
+ "Snapshot",
+ "StopGradient",
+ }));
+ return kValueAndOrderAndShapePreservingOps->count(node.op()) > 0 ||
IsIdentity(node);
}
@@ -569,31 +619,31 @@ bool IsValueAndOrderPreserving(const NodeDef& node) {
if (NumNonControlInputs(node) == 1 && IsAggregate(node)) {
return true;
}
- static const std::unordered_set<string>* value_and_order_preserving_ops =
- CHECK_NOTNULL((new const std::unordered_set<string>{
+ static const gtl::FlatSet<string>* const kValueAndOrderPreservingOps =
+ CHECK_NOTNULL((new const gtl::FlatSet<string>{
"ExpandDims",
"Reshape",
"Squeeze",
}));
- return value_and_order_preserving_ops->count(node.op()) > 0 ||
+ return kValueAndOrderPreservingOps->count(node.op()) > 0 ||
IsValueAndOrderAndShapePreserving(node);
}
bool IsValuePreserving(const NodeDef& node) {
- static const std::unordered_set<string>* value_preserving_ops =
- CHECK_NOTNULL((new std::unordered_set<string>{
+ static const gtl::FlatSet<string>* const kValuePreservingOps =
+ CHECK_NOTNULL((new gtl::FlatSet<string>{
"InvertPermutation",
"Reverse",
"Roll",
"Transpose",
}));
return IsValueAndOrderPreserving(node) ||
- value_preserving_ops->count(node.op()) > 0;
+ kValuePreservingOps->count(node.op()) > 0;
}
bool IsUnaryElementWise(const NodeDef& node) {
- static const std::unordered_set<string>* element_wise_ops =
- CHECK_NOTNULL((new std::unordered_set<string>{
+ static const gtl::FlatSet<string>* const kElementWiseOps =
+ CHECK_NOTNULL((new gtl::FlatSet<string>{
"Abs",
"Acos",
"Acosh",
@@ -642,7 +692,7 @@ bool IsUnaryElementWise(const NodeDef& node) {
"Tan"
"Tanh",
}));
- return element_wise_ops->count(node.op()) > 0 ||
+ return kElementWiseOps->count(node.op()) > 0 ||
IsValueAndOrderAndShapePreserving(node);
}
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index 25ab6b65ac..d4e0159e81 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -46,6 +46,7 @@ bool IsConjugateTranspose(const NodeDef& node);
bool IsConcat(const NodeDef& node);
bool IsConcatOffset(const NodeDef& node);
bool IsConstant(const NodeDef& node);
+bool IsControlFlow(const NodeDef& node);
bool IsConv2D(const NodeDef& node);
bool IsConv2DBackpropFilter(const NodeDef& node);
bool IsConv2DBackpropInput(const NodeDef& node);
@@ -149,7 +150,9 @@ bool IsStridedSliceGrad(const NodeDef& node);
bool IsSub(const NodeDef& node);
bool IsSum(const NodeDef& node);
bool IsSwitch(const NodeDef& node);
+bool IsSymbolicGradient(const NodeDef& node);
bool IsTanhGrad(const NodeDef& node);
+bool IsTensorArray(const NodeDef& node);
bool IsTile(const NodeDef& node);
bool IsTranspose(const NodeDef& node);
bool IsTruncateDiv(const NodeDef& node);
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 960d1addb3..c708f84948 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -525,6 +525,7 @@ cc_library(
"//tensorflow/core:core_cpu_base",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler/utils:colocation",
@@ -541,6 +542,7 @@ tf_cuda_cc_test(
":custom_graph_optimizer_registry",
":meta_optimizer",
"//tensorflow/cc:cc_ops",
+ "//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow",
"//tensorflow/core:test",
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 75ed12635e..7d5014ee0a 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -276,7 +276,7 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage<string> {
for (const NodeDef* output : ctx().node_map->GetOutputs(node.name())) {
for (int i = 0; i < output->input_size(); ++i) {
auto input = output->input(i);
- string name = ParseNodeName(input, &position);
+ StringPiece name = ParseNodeNameAsStringPiece(input, &position);
if (name == node.name() && /*control input*/ position < 0) {
return true;
}
@@ -1568,7 +1568,8 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage {
for (NodeDef* output : outputs) {
if (IsControlInput(output->input(0))) continue;
int port;
- const string node_name = ParseNodeName(output->input(0), &port);
+ const StringPiece node_name =
+ ParseNodeNameAsStringPiece(output->input(0), &port);
if (node_name == node.name()) {
tails->insert(ChainLink(output, port));
} else {
@@ -1618,7 +1619,8 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage {
} else {
for (NodeDef* new_tail : ctx().node_map->GetOutputs(tail->name())) {
int port;
- const string node_name = ParseNodeName(new_tail->input(0), &port);
+ const StringPiece node_name =
+ ParseNodeNameAsStringPiece(new_tail->input(0), &port);
if (node_name != tail->name()) {
return Status::OK();
}
@@ -2929,8 +2931,8 @@ uint64 UniqueNodes::ComputeSignature(const NodeDef& node) const {
for (const auto& input : node.input()) {
int pos;
- string node_name = ParseNodeName(input, &pos);
- h = Hash64CombineUnordered(Hash64(node_name), h);
+ const StringPiece node_name = ParseNodeNameAsStringPiece(input, &pos);
+ h = Hash64CombineUnordered(Hash64(node_name.data(), node_name.size()), h);
h = Hash64CombineUnordered(std::hash<int>()(pos), h);
}
for (const auto& attr : node.attr()) {
@@ -3247,6 +3249,10 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/,
optimized_graph_ = &optimized_item.graph;
node_map_.reset(new NodeMap(optimized_graph_));
+ // Disable restricted graph rewrites.
+ options_.unary_ops_composition &=
+ item.allowed_optimizations.non_differentiable_rewrites;
+
if (options_.dedup_computations) {
DedupComputations();
}
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index ca5d3a6dfd..3d0d95bba7 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -616,28 +616,37 @@ Status ConstantFolding::MaterializeReductionIndices(
// We can't do anything if we don't know the rank of the input.
return Status::OK();
}
- const int rank = input_prop.shape().dim_size();
- if (rank == 0) {
+ const int input_rank = input_prop.shape().dim_size();
+ if (input_rank < 1) {
// Unexpected graph, don't try to change it.
return Status::OK();
}
+ const OpInfo::TensorProperties& reduction_indices_prop = input_props[1];
+ DataType dtype = reduction_indices_prop.dtype();
+ if (dtype != DT_INT32 && dtype != DT_INT64) {
+ return Status::OK();
+ }
+ PartialTensorShape reduction_indices_shape(reduction_indices_prop.shape());
+ const int num_reduction_indices = reduction_indices_shape.num_elements();
+
const std::vector<OpInfo::TensorProperties>& output_props =
properties.GetOutputProperties(node->name());
if (output_props.size() != 1) {
return Status::OK();
}
- const bool keep_dims =
- node->attr().count("keep_dims") && node->attr().at("keep_dims").b();
const OpInfo::TensorProperties& output_prop = output_props[0];
- PartialTensorShape output_shape(output_prop.shape());
- if (output_shape.num_elements() != 1) {
- bool full_reduction = false;
+ const int output_rank =
+ output_prop.shape().unknown_rank() ? -1 : output_prop.shape().dim_size();
+
+ bool full_reduction = output_rank == 0 || num_reduction_indices == input_rank;
+ if (!full_reduction) {
+ // A full reduction will generate a tensor of one of the shapes
+ // [], [1], [1, 1], [1, 1, ...]. Even if we do not know the number of
+ // elements in the output of the reduction, we may deduce it from reshape
+ // nodes following it.
for (const NodeDef* fanout : node_map_->GetOutputs(node->name())) {
- if (!IsReshape(*fanout) && !keep_dims) {
- // Depending on how it's setup, a full reduction will generate a tensor
- // of shape [], [1], [1, 1], [1, 1, ...]. If keep_dims isn't true, we
- // rely on the existence of a reshape node following the reduction to
- // ensure that the fanout is fed a scalar of the right shape.
+ full_reduction = false;
+ if (!IsReshape(*fanout)) {
return Status::OK();
}
const std::vector<OpInfo::TensorProperties>& reshape_props =
@@ -658,20 +667,15 @@ Status ConstantFolding::MaterializeReductionIndices(
}
}
- const OpInfo::TensorProperties& reduction_prop = input_props[1];
- DataType dtype = reduction_prop.dtype();
- if (dtype != DT_INT32 && dtype != DT_INT64) {
- return Status::OK();
- }
- // We know it's a full reduction. We can generate the set of indices to
- // reduce.
+ // We know it's a full reduction. We can generate the full set of indices to
+ // reduce as a constant node.
string const_name = OptimizedNodeName(*node, "-reduction_indices");
if (node_map_->GetNode(const_name)) {
return Status::OK();
}
NodeDef* reduction_indices = graph_->add_node();
- Tensor value(dtype, TensorShape({rank}));
- for (int i = 0; i < rank; ++i) {
+ Tensor value(dtype, TensorShape({input_rank}));
+ for (int i = 0; i < input_rank; ++i) {
if (dtype == DT_INT32) {
value.vec<int32>()(i) = i;
} else {
@@ -680,6 +684,7 @@ Status ConstantFolding::MaterializeReductionIndices(
}
TF_RETURN_IF_ERROR(
CreateNodeDef(const_name, TensorValue(&value), reduction_indices));
+
reduction_indices->set_device(node->device());
string ctrl_dep =
AddControlDependency(node->input(1), graph_, node_map_.get());
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index b09360a2c2..fab01edfed 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -2591,58 +2591,100 @@ TEST_F(ConstantFoldingTest, MaterializeBroadcastGradientArgs_InfiniteLoop) {
}
TEST_F(ConstantFoldingTest, MaterializeReductionIndices) {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- Output input =
- ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
- ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
- Output indices = ops::Placeholder(s.WithOpName("indices"), DT_INT32);
- Output sum = ops::Sum(s.WithOpName("sum"), input, indices);
- Output size = ops::Const(s.WithOpName("size"), 1, {1});
- Output reshape = ops::Reshape(s.WithOpName("reshape"), sum, size);
+ for (bool use_reshape : {true, false}) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output input =
+ ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
+ ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
+    // If use_reshape is false, we need to know the number of indices to apply
+ // the rewrite.
+ Output indices = ops::Placeholder(
+ s.WithOpName("indices"), DT_INT32,
+ ops::Placeholder::Shape(PartialTensorShape({use_reshape ? -1 : 2})));
+ Output sum = ops::Sum(s.WithOpName("sum"), input, indices);
+ if (use_reshape) {
+ Output size = ops::Const(s.WithOpName("size"), 1, {1});
+ Output reshape = ops::Reshape(s.WithOpName("reshape"), sum, size);
+ }
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- item.fetch.push_back("reshape");
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ item.fetch.push_back(use_reshape ? "reshape" : "sum");
- auto input_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 4}));
- Tensor indices_t(DT_INT32, TensorShape({2}));
- indices_t.flat<int>()(0) = 0;
- indices_t.flat<int>()(1) = 1;
- auto tensors_expected = EvaluateNodes(
- item.graph, item.fetch, {{"input", input_t}, {"indices", indices_t}});
- EXPECT_EQ(1, tensors_expected.size());
+ auto input_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 4}));
+ Tensor indices_t(DT_INT32, TensorShape({2}));
+ indices_t.flat<int>()(0) = 0;
+ indices_t.flat<int>()(1) = 1;
+ auto tensors_expected = EvaluateNodes(
+ item.graph, item.fetch, {{"input", input_t}, {"indices", indices_t}});
+ EXPECT_EQ(1, tensors_expected.size());
- ConstantFolding optimizer(nullptr /* cpu_device */);
- GraphDef output;
- Status status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
+ // Use aggressive mode to force the shape inference to propagate placeholder
+ // shapes.
+ ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
+ nullptr /* cpu_device */);
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
- // Run a second time to make sure the optimization is idempotent.
- item.graph.Swap(&output);
- status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
+ // Run a second time to make sure the optimization is idempotent.
+ item.graph.Swap(&output);
+ status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
- int found = 0;
- for (const auto& node : output.node()) {
- if (node.name() == "ConstantFolding/sum-reduction_indices") {
- ++found;
- EXPECT_EQ("Const", node.op());
- EXPECT_EQ("^indices", node.input(0));
- EXPECT_EQ(2, TensorShape(node.attr().at("value").tensor().tensor_shape())
- .num_elements());
- } else if (node.name() == "sum") {
- ++found;
- EXPECT_EQ("ConstantFolding/sum-reduction_indices", node.input(1));
- } else if (node.name() == "indices") {
- ++found;
+ int found = 0;
+ for (const auto& node : output.node()) {
+ if (node.name() == "ConstantFolding/sum-reduction_indices") {
+ ++found;
+ EXPECT_EQ("Const", node.op());
+ EXPECT_EQ("^indices", node.input(0));
+ EXPECT_EQ(2,
+ TensorShape(node.attr().at("value").tensor().tensor_shape())
+ .num_elements());
+ } else if (node.name() == "sum") {
+ ++found;
+ EXPECT_EQ("ConstantFolding/sum-reduction_indices", node.input(1));
+ } else if (node.name() == "indices") {
+ ++found;
+ }
}
+ EXPECT_EQ(3, found);
+
+ auto tensors = EvaluateNodes(output, item.fetch,
+ {{"input", input_t}, {"indices", indices_t}});
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
}
- EXPECT_EQ(3, found);
+}
- auto tensors = EvaluateNodes(output, item.fetch,
- {{"input", input_t}, {"indices", indices_t}});
- EXPECT_EQ(1, tensors.size());
- test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
+TEST_F(ConstantFoldingTest, MaterializeReductionIndices_NotFullReduction) {
+ for (bool input_rank_known : {true, false}) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output input =
+ (input_rank_known ? ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
+ ops::Placeholder::Shape(
+ PartialTensorShape({-1, -1})))
+ : ops::Placeholder(s.WithOpName("input"), DT_FLOAT));
+ Output indices =
+ ops::Placeholder(s.WithOpName("indices"), DT_INT32,
+ ops::Placeholder::Shape(
+ PartialTensorShape({input_rank_known ? 1 : 2})));
+ Output sum = ops::Sum(s.WithOpName("sum"), input, indices);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ item.fetch.push_back("sum");
+
+ // Use aggressive mode to force the shape inference to propagate placeholder
+ // shapes.
+ ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
+ nullptr /* cpu_device */);
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+
+ CompareGraphs(item.graph, output);
+ }
}
TEST_F(ConstantFoldingTest, LargeConstant) {
diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD
index cf305cebe1..ee7c14e3ab 100644
--- a/tensorflow/core/grappler/optimizers/data/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/BUILD
@@ -22,6 +22,7 @@ cc_library(
"//tensorflow/core/grappler/utils:topological_sort",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ "//tensorflow/core:lib_internal",
] + tf_protos_all(),
)
@@ -31,6 +32,7 @@ tf_cc_test(
visibility = ["//visibility:public"],
deps = [
":filter_fusion",
+ ":graph_test_utils",
":graph_utils",
"//tensorflow/core:framework",
"//tensorflow/core:test",
@@ -87,11 +89,12 @@ cc_library(
],
visibility = ["//visibility:public"],
deps = [
+ ":graph_utils",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/grappler:mutable_graph_view",
- "//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils",
+ "//tensorflow/core:lib_internal",
] + tf_protos_all(),
)
@@ -121,11 +124,12 @@ cc_library(
],
visibility = ["//visibility:public"],
deps = [
+ "//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/grappler:mutable_graph_view",
- "//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils",
+ "//tensorflow/core:lib_internal",
] + tf_protos_all(),
)
@@ -135,6 +139,7 @@ tf_cc_test(
visibility = ["//visibility:public"],
deps = [
":graph_utils",
+ "//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
@@ -146,6 +151,62 @@ tf_cc_test(
)
cc_library(
+ name = "graph_test_utils",
+ testonly = 1,
+ srcs = ["graph_test_utils.cc"],
+ hdrs = [
+ "graph_test_utils.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/grappler:mutable_graph_view",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core:testlib",
+ ] + tf_protos_all(),
+)
+
+cc_library(
+ name = "hoist_random_uniform",
+ srcs = ["hoist_random_uniform.cc"],
+ hdrs = [
+ "hoist_random_uniform.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":function_utils",
+ ":graph_utils",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/grappler:mutable_graph_view",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:op_types",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core/grappler/clusters:cluster",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ "//tensorflow/core:lib_internal",
+ ] + tf_protos_all(),
+)
+
+tf_cc_test(
+ name = "hoist_random_uniform_test",
+ srcs = ["hoist_random_uniform_test.cc"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_test_utils",
+ ":graph_utils",
+ ":hoist_random_uniform",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/grappler:grappler_item",
+ ] + tf_protos_all(),
+)
+
+cc_library(
name = "latency_all_edges",
srcs = ["latency_all_edges.cc"],
hdrs = [
@@ -256,7 +317,7 @@ cc_library(
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/utils:topological_sort",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
- "//tensorflow/core:ptr_util",
+ "//tensorflow/core:lib_internal",
] + tf_protos_all(),
)
@@ -265,6 +326,7 @@ tf_cc_test(
srcs = ["map_and_filter_fusion_test.cc"],
visibility = ["//visibility:public"],
deps = [
+ ":graph_test_utils",
":graph_utils",
":map_and_filter_fusion",
"//tensorflow/core:framework",
@@ -294,6 +356,7 @@ cc_library(
"//tensorflow/core/grappler/utils:topological_sort",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ "//tensorflow/core:lib_internal",
] + tf_protos_all(),
)
@@ -302,6 +365,7 @@ tf_cc_test(
srcs = ["map_fusion_test.cc"],
visibility = ["//visibility:public"],
deps = [
+ ":graph_test_utils",
":graph_utils",
":map_fusion",
"//tensorflow/core:framework",
@@ -339,6 +403,7 @@ tf_cc_test(
srcs = ["map_parallelization_test.cc"],
visibility = ["//visibility:public"],
deps = [
+ ":graph_test_utils",
":graph_utils",
":map_parallelization",
"//tensorflow/core:framework",
@@ -422,6 +487,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":filter_fusion",
+ ":hoist_random_uniform",
":latency_all_edges",
":map_and_batch_fusion",
":map_and_filter_fusion",
@@ -458,7 +524,9 @@ cc_library(
deps = [
":function_utils",
":graph_utils",
+ "//tensorflow/cc:ops",
"@com_google_absl//absl/strings",
+ "//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
@@ -474,6 +542,7 @@ tf_cc_test(
srcs = ["vectorization_utils_test.cc"],
visibility = ["//visibility:public"],
deps = [
+ ":graph_utils",
":function_utils",
":vectorization_utils",
"//tensorflow/core:framework",
@@ -483,7 +552,10 @@ tf_cc_test(
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ # For ops we need registered
+ "//tensorflow/core/kernels/data:dataset_ops",
"//tensorflow/core/kernels:cast_op",
+ "//tensorflow/core/kernels:logging_ops",
"//tensorflow/tools/graph_transforms:transform_utils",
] + tf_protos_all(),
)
diff --git a/tensorflow/core/grappler/optimizers/data/filter_fusion.cc b/tensorflow/core/grappler/optimizers/data/filter_fusion.cc
index c71aa6e804..1ad495bbad 100644
--- a/tensorflow/core/grappler/optimizers/data/filter_fusion.cc
+++ b/tensorflow/core/grappler/optimizers/data/filter_fusion.cc
@@ -43,19 +43,14 @@ NodeDef MakeFusedFilterNode(const NodeDef& first_filter_node,
fused_node.set_op("FilterDataset");
fused_node.add_input(first_filter_node.input(0));
- auto copy_attribute = [](const string& attribute_name, const NodeDef& from,
- NodeDef* to) {
- (*to->mutable_attr())[attribute_name] = from.attr().at(attribute_name);
- };
-
auto attr = first_filter_node.attr().at("predicate");
*attr.mutable_func()->mutable_name() = fused_function.signature().name();
(*fused_node.mutable_attr())["predicate"] = std::move(attr);
- copy_attribute("Targuments", first_filter_node, &fused_node);
+ graph_utils::CopyAttribute("Targuments", first_filter_node, &fused_node);
for (auto key : {"output_shapes", "output_types"})
- copy_attribute(key, second_filter_node, &fused_node);
+ graph_utils::CopyAttribute(key, second_filter_node, &fused_node);
return fused_node;
}
@@ -120,8 +115,8 @@ Status FilterFusion::Optimize(Cluster* cluster, const GrapplerItem& item,
// functions, or make sure that optimization passes run after filter
// fusion.
TF_RETURN_IF_ERROR(function_library.AddFunctionDef(*fused_predicate));
- // TODO(prazek): we could also remove map functions from library if they
- // are not used anymore.
+ // TODO(b/116285210): we could also remove map functions from library if
+ // they are not used anymore.
nodes_to_delete.insert(first_filter_node->name());
nodes_to_delete.insert(second_filter_node->name());
}
diff --git a/tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc
index 12b1924efd..c8becc5cc0 100644
--- a/tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc
@@ -19,8 +19,8 @@ limitations under the License.
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
-
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
@@ -28,14 +28,7 @@ namespace tensorflow {
namespace grappler {
namespace {
-NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name) {
- return test::function::NDef(
- name, "FilterDataset", {string(input_node_name)},
- {{"predicate", FunctionDefHelper::FunctionRef("IsZero")},
- {"Targuments", {}},
- {"output_shapes", {}},
- {"output_types", {}}});
-}
+using graph_tests_utils::MakeFilterNode;
TEST(FilterFusionTest, FuseTwoFilterIntoOne) {
using test::function::NDef;
diff --git a/tensorflow/core/grappler/optimizers/data/function_utils.cc b/tensorflow/core/grappler/optimizers/data/function_utils.cc
index e95ea1a4c1..311df15bc2 100644
--- a/tensorflow/core/grappler/optimizers/data/function_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/function_utils.cc
@@ -14,31 +14,16 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/lib/strings/scanner.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace grappler {
namespace function_utils {
-namespace {
-
-template <typename Predicate, typename Collection>
-std::vector<int> GetElementIndicesWithPredicate(const Predicate& predicate,
- const Collection& collection) {
- std::vector<int> indices = {};
- unsigned idx = 0;
- for (auto&& element : collection) {
- if (predicate(element)) {
- indices.push_back(idx);
- }
- idx++;
- }
- return indices;
-}
-
-} // namespace
FunctionDefTensorDesc::FunctionDefTensorDesc(const string& node_name,
const string& output, int position)
@@ -152,32 +137,27 @@ bool ContainsFunctionOutputWithName(StringPiece name,
}
int FindFunctionInputWithName(StringPiece name, const FunctionDef& function) {
- std::vector<int> indices = GetElementIndicesWithPredicate(
+ return graph_utils::GetFirstElementIndexWithPredicate(
[&name](const OpDef_ArgDef& arg) { return arg.name() == name; },
function.signature().input_arg());
- return indices.empty() ? -1 : indices.front();
}
int FindFunctionOutputWithName(StringPiece name, const FunctionDef& function) {
- std::vector<int> indices = GetElementIndicesWithPredicate(
+ return graph_utils::GetFirstElementIndexWithPredicate(
[&name](const OpDef_ArgDef& arg) { return arg.name() == name; },
function.signature().output_arg());
- return indices.empty() ? -1 : indices.front();
}
int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function) {
- std::vector<int> indices = GetElementIndicesWithPredicate(
+ return graph_utils::GetFirstElementIndexWithPredicate(
[&name](const NodeDef& node) { return node.name() == name; },
function.node_def());
- return indices.empty() ? -1 : indices.front();
}
int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function) {
- std::vector<int> indices = GetElementIndicesWithPredicate(
+ return graph_utils::GetFirstElementIndexWithPredicate(
[&op](const NodeDef& node) { return node.op() == op; },
function.node_def());
-
- return indices.empty() ? -1 : indices.front();
}
void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function,
diff --git a/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc
new file mode 100644
index 0000000000..b2eec7220e
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc
@@ -0,0 +1,49 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
+
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_tests_utils {
+
+NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name,
+ StringPiece function_name) {
+ return test::function::NDef(
+ name, "MapDataset", {string(input_node_name)},
+ {{"f", FunctionDefHelper::FunctionRef(string(function_name))},
+ {"Targuments", {}},
+ {"output_shapes", gtl::ArraySlice<TensorShape>{}},
+ {"output_types", gtl::ArraySlice<DataType>{}}});
+}
+
+NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name,
+ StringPiece function_name) {
+ return test::function::NDef(
+ name, "FilterDataset", {string(input_node_name)},
+ {{"predicate", FunctionDefHelper::FunctionRef(string(function_name))},
+ {"Targuments", {}},
+ {"output_shapes", gtl::ArraySlice<TensorShape>{}},
+       {"output_types", gtl::ArraySlice<DataType>{}}});
+}
+
+} // end namespace graph_tests_utils
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/graph_test_utils.h b/tensorflow/core/grappler/optimizers/data/graph_test_utils.h
new file mode 100644
index 0000000000..ca0fde997d
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/graph_test_utils.h
@@ -0,0 +1,36 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_TEST_UTILS_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_TEST_UTILS_H_
+
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_tests_utils {
+
+NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name,
+ StringPiece function_name = "XTimesTwo");
+
+NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name,
+ StringPiece function_name = "IsZero");
+
+} // end namespace graph_tests_utils
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_TEST_UTILS_H_
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
index 2dd9ee822e..b863a25dc5 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
@@ -201,25 +202,22 @@ bool ContainsNodeWithOp(StringPiece op, const GraphDef& graph) {
int FindGraphFunctionWithName(StringPiece name,
const FunctionDefLibrary& library) {
- std::vector<int> indices = GetElementIndicesWithPredicate(
+ return GetFirstElementIndexWithPredicate(
[&name](const FunctionDef& function) {
return function.signature().name() == name;
},
library.function());
- return indices.empty() ? -1 : indices.front();
}
int FindGraphNodeWithName(StringPiece name, const GraphDef& graph) {
- std::vector<int> indices = GetElementIndicesWithPredicate(
+ return GetFirstElementIndexWithPredicate(
[&name](const NodeDef& node) { return node.name() == name; },
graph.node());
- return indices.empty() ? -1 : indices.front();
}
int FindGraphNodeWithOp(StringPiece op, const GraphDef& graph) {
- std::vector<int> indices = GetElementIndicesWithPredicate(
+ return GetFirstElementIndexWithPredicate(
[&op](const NodeDef& node) { return node.op() == op; }, graph.node());
- return indices.empty() ? -1 : indices.front();
}
std::vector<int> FindAllGraphNodesWithOp(const string& op,
@@ -260,6 +258,41 @@ void SetUniqueGraphFunctionName(StringPiece prefix, FunctionDefLibrary* library,
}
function->mutable_signature()->set_name(std::move(name));
}
+
+void CopyAttribute(const string& attribute_name, const NodeDef& from,
+ NodeDef* to_node) {
+ (*to_node->mutable_attr())[attribute_name] = from.attr().at(attribute_name);
+}
+
+void ConcatAttributeList(const string& attribute_name, const NodeDef& first,
+ const NodeDef& second, NodeDef* to_node) {
+ CopyAttribute(attribute_name, first, to_node);
+ (*to_node->mutable_attr())
+ .at(attribute_name)
+ .mutable_list()
+ ->MergeFrom(second.attr().at(attribute_name).list());
+}
+
+Status EnsureNodeNamesUnique(Graph* g) {
+ // Modeled after Scope::Impl::GetUniqueName
+ std::unordered_map<string, int> name_map;
+
+ for (auto node : g->op_nodes()) {
+ const string& prefix = node->name();
+ if (auto entry = gtl::FindOrNull(name_map, prefix)) {
+ string unique_name;
+ do {
+ unique_name = strings::StrCat(prefix, "_", ++(*entry));
+ } while (name_map.find(unique_name) != name_map.end());
+ name_map.insert({unique_name, 0});
+ node->set_name(std::move(unique_name));
+ } else {
+ name_map.insert({node->name(), 0});
+ }
+ }
+
+ return Status::OK();
+}
} // end namespace graph_utils
} // end namespace grappler
} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.h b/tensorflow/core/grappler/optimizers/data/graph_utils.h
index b117482db2..d130fee204 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.h
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.h
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -31,6 +32,21 @@ namespace tensorflow {
namespace grappler {
namespace graph_utils {
+// Returns the index of the first element in collection that fulfills predicate.
+// If no such element exists, returns -1.
+template <typename Predicate, typename Collection>
+int GetFirstElementIndexWithPredicate(const Predicate& predicate,
+ const Collection& collection) {
+ unsigned idx = 0;
+ for (auto&& element : collection) {
+ if (predicate(element)) {
+ return idx;
+ }
+ idx++;
+ }
+ return -1;
+}
+
// Adds a node to the graph.
NodeDef* AddNode(StringPiece name, StringPiece op,
const std::vector<string>& inputs,
@@ -101,11 +117,29 @@ std::vector<int> FindAllGraphNodesWithOp(const string& op,
// is unique across the graph.
void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph, NodeDef* node);
-// Sets the node name using the `prefix` name as a prefix while guaranteeing the
-// name is unique across the graph.
+// Sets the function name using the `prefix` name as a prefix while guaranteeing
+// the name is unique across the function library.
void SetUniqueGraphFunctionName(StringPiece prefix, FunctionDefLibrary* library,
FunctionDef* function);
+// Copies attribute having name `attribute_name` from node `from` to node
+// `to_node`.
+void CopyAttribute(const string& attribute_name, const NodeDef& from,
+ NodeDef* to_node);
+
+// Concatenates list attribute having name `attribute_name` from `first` and
+// `second` node, setting it to `to_node`.
+void ConcatAttributeList(const string& attribute_name, const NodeDef& first,
+ const NodeDef& second, NodeDef* to_node);
+
+// Checks that all nodes in the graphs have unique names, and sets their names
+// to be unique if they are not already. This is necessary as Graph does not
+// have the provisions to deduplicate names, and name deduplication elsewhere
+// in tensorflow happens in other layers (for example, in the Scope class of the
+// C++ API). Note that the nodes in the graph are identified by their id,
+// and renaming nodes does not mutate any edges.
+Status EnsureNodeNamesUnique(Graph* g);
+
} // end namespace graph_utils
} // end namespace grappler
} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
index 6877c207c4..4ab6d71532 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
@@ -24,6 +25,18 @@ namespace grappler {
namespace graph_utils {
namespace {
+TEST(GraphUtilsTest, GetFirstElementIndexWithPredicate) {
+ std::vector<int> vec({1, 2, 3, 4, 5, 6});
+ auto result = GetFirstElementIndexWithPredicate(
+ [](int elem) { return elem % 3 == 0; }, vec);
+
+ EXPECT_EQ(result, 2);
+
+ result = GetFirstElementIndexWithPredicate(
+ [](int elem) { return elem % 7 == 0; }, vec);
+ EXPECT_EQ(result, -1);
+}
+
TEST(GraphUtilsTest, AddScalarConstNodeBool) {
GraphDef graph_def;
MutableGraphView graph(&graph_def);
@@ -217,6 +230,33 @@ TEST(GraphUtilsTest, GetInputNode) {
EXPECT_EQ(GetInputNode(*node1, graph), nullptr);
}
+TEST(GraphUtilsTest, EnsureNodeNamesUnique) {
+ Graph g(OpRegistry::Global());
+
+ Node *const_0, *const_1, *const_2;
+
+ // Arbitrary const
+ Tensor tensor(DT_INT32, {});
+ tensor.scalar<int32>()() = 5;
+
+ for (auto node : {&const_0, &const_1}) {
+ TF_EXPECT_OK(NodeBuilder("Const", "Const")
+ .Attr("value", tensor)
+ .Attr("dtype", DT_INT32)
+ .Finalize(&g, node));
+ }
+ // Make sure generated name doesn't clash with existing name either
+ TF_EXPECT_OK(NodeBuilder("Const_1", "Const")
+ .Attr("value", tensor)
+ .Attr("dtype", DT_INT32)
+ .Finalize(&g, &const_2));
+
+ TF_EXPECT_OK(EnsureNodeNamesUnique(&g));
+ EXPECT_NE(const_0->name(), const_1->name());
+ EXPECT_NE(const_1->name(), const_2->name());
+ EXPECT_NE(const_0->name(), const_2->name());
+}
+
} // namespace
} // namespace graph_utils
} // namespace grappler
diff --git a/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.cc b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.cc
new file mode 100644
index 0000000000..ce0b2db039
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.cc
@@ -0,0 +1,289 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h"
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/grappler/clusters/cluster.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/mutable_graph_view.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+NodeDef MakeStatelessMap(const NodeDef& map_node, const NodeDef& zip_node,
+ const FunctionDef& stateless_function,
+ MutableGraphView* graph) {
+ NodeDef stateless_map;
+ graph_utils::SetUniqueGraphNodeName("stateless_map", graph->GetGraph(),
+ &stateless_map);
+
+ stateless_map.set_op("MapDataset");
+ stateless_map.add_input(zip_node.name());
+ // Add placeholders.
+ for (int i = 1; i < map_node.input_size(); i++)
+ stateless_map.add_input(map_node.input(i));
+
+ auto attr = map_node.attr().at("f");
+ *attr.mutable_func()->mutable_name() = stateless_function.signature().name();
+ *attr.mutable_func()->mutable_attr() = stateless_function.attr();
+ (*stateless_map.mutable_attr())["f"] = std::move(attr);
+
+ graph_utils::CopyAttribute("Targuments", map_node, &stateless_map);
+ for (auto key : {"output_shapes", "output_types"})
+ graph_utils::CopyAttribute(key, map_node, &stateless_map);
+
+ if (const auto* attr =
+ gtl::FindOrNull(map_node.attr(), "use_inter_op_parallelism"))
+ (*stateless_map.mutable_attr())["use_inter_op_parallelism"] = *attr;
+
+ return stateless_map;
+}
+
+NodeDef MakeRandomDataset(const NodeDef& random_uniform_node,
+ MutableGraphView* graph) {
+ NodeDef random_dataset;
+ random_dataset.set_op("RandomDataset");
+ graph_utils::SetUniqueGraphNodeName("RandomDataset", graph->GetGraph(),
+ &random_dataset);
+
+ const auto* seed = graph_utils::AddScalarConstNode<int64>(
+ random_uniform_node.attr().at("seed").i(), graph);
+ const auto* seed2 = graph_utils::AddScalarConstNode<int64>(
+ random_uniform_node.attr().at("seed2").i(), graph);
+
+ random_dataset.add_input(seed->name());
+ random_dataset.add_input(seed2->name());
+
+ (*random_dataset.mutable_attr())["output_shapes"].mutable_list()->add_shape();
+ (*random_dataset.mutable_attr())["output_types"].mutable_list()->add_type(
+ DT_INT64);
+
+ return random_dataset;
+}
+
+NodeDef MakeBatchTwo(const NodeDef& random_dataset, MutableGraphView* graph) {
+ NodeDef batch_dataset;
+ batch_dataset.set_op("BatchDatasetV2");
+ graph_utils::SetUniqueGraphNodeName("pair_of_random", graph->GetGraph(),
+ &batch_dataset);
+ const auto* batch_size = graph_utils::AddScalarConstNode<int64>(2, graph);
+ const auto* drop_reminder = graph_utils::AddScalarConstNode(false, graph);
+ batch_dataset.add_input(random_dataset.name());
+ batch_dataset.add_input(batch_size->name());
+ batch_dataset.add_input(drop_reminder->name());
+
+ (*batch_dataset.mutable_attr())["output_shapes"]
+ .mutable_list()
+ ->add_shape()
+ ->mutable_dim()
+ ->Add()
+ ->set_size(-1);
+ (*batch_dataset.mutable_attr())["output_types"].mutable_list()->add_type(
+ DT_INT64);
+
+ return batch_dataset;
+}
+
+NodeDef MakeZipNode(const NodeDef& first_node, const NodeDef& second_node,
+ MutableGraphView* graph) {
+ NodeDef zip_node;
+ graph_utils::SetUniqueGraphNodeName("zip_with_random", graph->GetGraph(),
+ &zip_node);
+
+ zip_node.set_op("ZipDataset");
+ zip_node.add_input(first_node.name());
+ zip_node.add_input(second_node.name());
+
+ for (auto key : {"output_shapes", "output_types"})
+ graph_utils::ConcatAttributeList(key, first_node, second_node, &zip_node);
+
+ (*zip_node.mutable_attr())["N"].set_i(2);
+
+ return zip_node;
+}
+
+// We need to insert our argument before the placeholders, which are the last
+// arguments.
+OpDef_ArgDef* InsertSeedArgument(OpDef* signature, int num_placeholders) {
+ int new_argument_idx = signature->input_arg_size() - num_placeholders;
+ signature->add_input_arg();
+ for (int i = signature->input_arg_size() - 1; i > new_argument_idx; i--) {
+ signature->mutable_input_arg()->SwapElements(i - 1, i);
+ }
+ auto* seed_arg = signature->mutable_input_arg(new_argument_idx);
+ seed_arg->set_name(strings::StrCat("seed_arg", new_argument_idx));
+ seed_arg->set_type(DT_INT64);
+
+ return seed_arg;
+}
+
+// Makes a function that uses `StatelessRandomUniform` instead of `RandomUniform`
+// to make it less stateful. The function can still be stateful, but when the
+// only other stateful ops are e.g. `Assert`, it will be parallelizable.
+const FunctionDef* MakeLessStatefulFunction(const FunctionDef& map_function,
+ bool is_stateful,
+ int num_placeholders,
+ FunctionDefLibrary* library) {
+ FunctionDef* stateless_function = library->add_function();
+ *stateless_function = map_function;
+ if (is_stateful)
+ stateless_function->mutable_signature()->set_is_stateful(is_stateful);
+ graph_utils::SetUniqueGraphFunctionName("stateless_function", library,
+ stateless_function);
+
+ auto* seed_arg = InsertSeedArgument(stateless_function->mutable_signature(),
+ num_placeholders);
+
+ auto* const random_uniform = stateless_function->mutable_node_def(
+ function_utils::FindFunctionNodeWithOp("RandomUniform",
+ *stateless_function));
+
+ // Replace RandomUniform node with StatelessRandomUniform.
+ random_uniform->set_op("StatelessRandomUniform");
+ random_uniform->add_input(seed_arg->name());
+ (*random_uniform->mutable_attr())["Tseed"].set_type(DT_INT64);
+ random_uniform->mutable_attr()->erase("seed");
+ random_uniform->mutable_attr()->erase("seed2");
+
+ return stateless_function;
+}
+// This function returns true if the function is stateful, has a single
+// RandomUniform op, and no other stateful ops except Assert.
+// `is_stateful_after_hoisting` is set to false if RandomUniform is the only
+// stateful op, i.e. the function becomes stateless once it is hoisted out.
+bool CanHoistRandomUniform(const FunctionDef& map_function,
+ const FunctionLibraryDefinition& library,
+ bool* is_stateful_after_hoisting,
+ const NodeDef** random_uniform_op) {
+ if (!map_function.signature().is_stateful()) return false;
+ *is_stateful_after_hoisting = true;
+
+ bool have_other_stateful_ops = false;
+
+ for (const auto& node : map_function.node_def()) {
+ const OpDef* op_def;
+ TF_CHECK_OK(library.LookUpOpDef(node.op(), &op_def));
+    // Skip stateless nodes and `Assert`, which does not actually carry state.
+ if (!op_def->is_stateful()) continue;
+
+ if (op_def->name() == "Assert") {
+ have_other_stateful_ops = true;
+ continue;
+ }
+
+ // TODO(prazek): For now we only handle RandomUniform, we should handle
+ // RandomUniformInt as well.
+ if (op_def->name() != "RandomUniform") return false;
+
+ // TODO(prazek): For now we can only hoist single RandomUniform.
+ if (*random_uniform_op != nullptr) return false;
+
+ *random_uniform_op = &node;
+ }
+
+ if (!have_other_stateful_ops) *is_stateful_after_hoisting = false;
+
+ // Have we found single RandomUniform?
+ return *random_uniform_op != nullptr;
+}
+
+int NumberOfPlaceholders(const NodeDef& map_node) {
+  // The first input of MapDataset is the input dataset, which is not a
+  // placeholder. The remaining inputs are placeholders (captured arguments).
+ return map_node.input_size() - 1;
+}
+
+} // namespace
+
+Status HoistRandomUniform::Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* output) {
+ *output = item.graph;
+
+ MutableGraphView graph(output);
+ std::set<string> nodes_to_delete;
+ FunctionLibraryDefinition function_library(OpRegistry::Global(),
+ item.graph.library());
+
+ auto get_map_node = [](const NodeDef& node) -> const NodeDef* {
+ // TODO(prazek): we could also handle ParallelMapDataset and
+ // MapAndBatchDataset.
+ if (node.op() == "MapDataset") return &node;
+ return nullptr;
+ };
+
+ for (const NodeDef& node : item.graph.node()) {
+ const NodeDef* map_node = get_map_node(node);
+ if (!map_node) continue;
+
+ const auto& fun = map_node->attr().at("f");
+ const FunctionDef* func = function_library.Find(fun.func().name());
+
+ const NodeDef* random_uniform_op = nullptr;
+ bool is_stateful_after_hoisting = true;
+ if (!CanHoistRandomUniform(*func, function_library,
+ &is_stateful_after_hoisting, &random_uniform_op))
+ continue;
+ const auto* random_seed_dataset =
+ graph.AddNode(MakeRandomDataset(*random_uniform_op, &graph));
+
+ const auto* batch_dataset =
+ graph.AddNode(MakeBatchTwo(*random_seed_dataset, &graph));
+
+ const NodeDef& parent_node = *graph_utils::GetInputNode(*map_node, graph);
+
+ const auto* zip_node =
+ graph.AddNode(MakeZipNode(parent_node, *batch_dataset, &graph));
+
+ const auto* stateless_func = MakeLessStatefulFunction(
+ *func, is_stateful_after_hoisting, NumberOfPlaceholders(*map_node),
+ output->mutable_library());
+
+ const auto* stateless_map = graph.AddNode(
+ MakeStatelessMap(*map_node, *zip_node, *stateless_func, &graph));
+
+ graph.ReplaceInput(*map_node, *stateless_map);
+
+ // TODO(b/116285210): we could also remove map functions from library if
+ // they are not used anymore.
+ nodes_to_delete.insert(map_node->name());
+ }
+
+ graph.DeleteNodes(nodes_to_delete);
+ return Status::OK();
+}
+
+void HoistRandomUniform::Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimize_output,
+ double result) {
+ // no-op
+}
+
+REGISTER_GRAPH_OPTIMIZER_AS(HoistRandomUniform, "hoist_random_uniform");
+
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h
new file mode 100644
index 0000000000..d1bcf6782d
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h
@@ -0,0 +1,55 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_HOIST_RANDOM_UNIFORM_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_HOIST_RANDOM_UNIFORM_H_
+
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// This optimization hoists instances of `random_uniform` out of a function
+// with the aim of making it stateless. It creates a new function that takes a
+// random seed as an extra argument and uses `stateless_random_uniform` instead
+// of `random_uniform` to make it stateless.
+// It also creates RandomDataset(seed).batch(2), which is zipped with old input
+// to the map. The batching in RandomDataset is because we need 2 seeds for
+// `stateless_random_uniform`.
+// TODO(prazek): for now only `RandomUniform` is handled, but we could handle
+// `RandomUniformInt` similarly.
+class HoistRandomUniform : public CustomGraphOptimizer {
+ public:
+ HoistRandomUniform() = default;
+ ~HoistRandomUniform() override = default;
+
+ string name() const override { return "hoist_random_uniform"; };
+
+ Status Init(
+ const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
+ return Status::OK();
+ }
+
+ Status Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* output) override;
+
+ void Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimize_output, double result) override;
+};
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_HOIST_RANDOM_UNIFORM_H_
diff --git a/tensorflow/core/grappler/optimizers/data/hoist_random_uniform_test.cc b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform_test.cc
new file mode 100644
index 0000000000..455459e3f6
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform_test.cc
@@ -0,0 +1,84 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h"
+
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+TEST(HoistRandomUniform, SimpleHoisting) {
+ using test::function::NDef;
+ GrapplerItem item;
+ item.graph = test::function::GDef(
+ {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
+ NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
+ NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+ NDef("filename", "Const", {}, {{"value", ""}, {"dtype", DT_STRING}}),
+ NDef("range", "RangeDataset", {"start", "stop", "step"},
+ {{"output_shapes", gtl::ArraySlice<TensorShape>{}},
+ {"output_types", gtl::ArraySlice<DataType>{}}}),
+ graph_tests_utils::MakeMapNode("map1", "range", "RandomUniform"),
+ NDef("cache", "CacheDataset", {"map1", "filename"}, {})},
+ // FunctionLib
+ {
+ test::function::RandomUniform(),
+ });
+
+ HoistRandomUniform optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map1", output));
+ const int new_map_id = graph_utils::FindGraphNodeWithOp("MapDataset", output);
+ const int zip_dataset_id =
+ graph_utils::FindGraphNodeWithOp("ZipDataset", output);
+ const int random_dataset_id =
+ graph_utils::FindGraphNodeWithOp("RandomDataset", output);
+ const int batch_random_id =
+ graph_utils::FindGraphNodeWithOp("BatchDatasetV2", output);
+ ASSERT_NE(random_dataset_id, -1);
+ ASSERT_NE(zip_dataset_id, -1);
+ ASSERT_NE(new_map_id, -1);
+ ASSERT_NE(batch_random_id, -1);
+
+ const auto& new_map = output.node(new_map_id);
+ const auto& zip = output.node(zip_dataset_id);
+ const auto& random = output.node(random_dataset_id);
+ const auto& batch = output.node(batch_random_id);
+
+ ASSERT_EQ(new_map.input_size(), 1);
+ EXPECT_EQ(new_map.input(0), zip.name());
+
+ ASSERT_EQ(zip.input_size(), 2);
+ EXPECT_EQ(zip.input(0), "range");
+ EXPECT_EQ(zip.input(1), batch.name());
+
+ ASSERT_EQ(batch.input_size(), 3);
+ EXPECT_EQ(batch.input(0), random.name());
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc
index 63945b8b9e..e66766eb23 100644
--- a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc
@@ -80,11 +80,12 @@ NodeDef MakeMapAndBatchNode(const NodeDef& map_node, const NodeDef& batch_node,
// Set `f` and `Targuments` attributes.
for (auto key : {"f", "Targuments"}) {
- (*new_node.mutable_attr())[key] = map_node.attr().at(key);
+ graph_utils::CopyAttribute(key, map_node, &new_node);
}
+
// Set `output_types` and `output_shapes` attributes.
for (auto key : {"output_shapes", "output_types"}) {
- (*new_node.mutable_attr())[key] = batch_node.attr().at(key);
+ graph_utils::CopyAttribute(key, batch_node, &new_node);
}
return new_node;
}
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc
index f1844a141c..c4868eacbb 100644
--- a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/protobuf.h"
namespace tensorflow {
@@ -41,19 +42,18 @@ NodeDef MakeFusedNode(const NodeDef& map_node,
fused_node.set_op("MapDataset");
fused_node.add_input(map_node.input(0));
- auto copy_attribute = [](const string& attribute_name, const NodeDef& from,
- NodeDef* to) {
- (*to->mutable_attr())[attribute_name] = from.attr().at(attribute_name);
- };
-
auto attr = map_node.attr().at("f");
attr.mutable_func()->set_name(fused_function.signature().name());
(*fused_node.mutable_attr())["f"] = std::move(attr);
- copy_attribute("Targuments", map_node, &fused_node);
+ graph_utils::CopyAttribute("Targuments", map_node, &fused_node);
for (auto key : {"output_shapes", "output_types"})
- copy_attribute(key, map_node, &fused_node);
+ graph_utils::CopyAttribute(key, map_node, &fused_node);
+
+ if (const auto* attr =
+ gtl::FindOrNull(map_node.attr(), "use_inter_op_parallelism"))
+ (*fused_node.mutable_attr())["use_inter_op_parallelism"] = *attr;
// Add the predicate output attributes.
(*fused_node.mutable_attr())["output_types"]
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc
index f029a093fa..6e6da37d7c 100644
--- a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -27,24 +28,8 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
namespace {
-
-NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name) {
- return test::function::NDef(
- name, "MapDataset", {string(input_node_name)},
- {{"f", FunctionDefHelper::FunctionRef("XTimesTwo")},
- {"Targuments", {}},
- {"output_shapes", {}},
- {"output_types", {}}});
-}
-
-NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name) {
- return test::function::NDef(
- name, "FilterDataset", {string(input_node_name)},
- {{"predicate", FunctionDefHelper::FunctionRef("IsZero")},
- {"Targuments", {}},
- {"output_shapes", {}},
- {"output_types", {}}});
-}
+using graph_tests_utils::MakeFilterNode;
+using graph_tests_utils::MakeMapNode;
TEST(MapAndFilterFusionTest, FuseMapAndFilter) {
using test::function::NDef;
diff --git a/tensorflow/core/grappler/optimizers/data/map_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_fusion.cc
index a78ecb09f7..bd943342e8 100644
--- a/tensorflow/core/grappler/optimizers/data/map_fusion.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_fusion.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/protobuf.h"
namespace tensorflow {
@@ -40,24 +41,31 @@ NodeDef MakeFusedNode(const NodeDef& parent_map_node, const NodeDef& map_node,
NodeDef fused_node;
graph_utils::SetUniqueGraphNodeName("fused_map", graph->GetGraph(),
&fused_node);
-
fused_node.set_op("MapDataset");
fused_node.add_input(parent_map_node.input(0));
- auto copy_attribute = [](const string& attribute_name, const NodeDef& from,
- NodeDef* to) {
- (*to->mutable_attr())[attribute_name] = from.attr().at(attribute_name);
- };
-
auto attr = parent_map_node.attr().at("f");
*attr.mutable_func()->mutable_name() = fused_function.signature().name();
(*fused_node.mutable_attr())["f"] = std::move(attr);
- copy_attribute("Targuments", parent_map_node, &fused_node);
-
+ graph_utils::CopyAttribute("Targuments", parent_map_node, &fused_node);
for (auto key : {"output_shapes", "output_types"})
- copy_attribute(key, map_node, &fused_node);
+ graph_utils::CopyAttribute(key, map_node, &fused_node);
+ auto value_or_false = [](const AttrValue* attr) {
+ if (!attr) return false;
+ return attr->b();
+ };
+
+ const auto* first_parallelism =
+ gtl::FindOrNull(parent_map_node.attr(), "use_inter_op_parallelism");
+ const auto* second_parallelism =
+ gtl::FindOrNull(map_node.attr(), "use_inter_op_parallelism");
+ // Some graphs cannot execute with use_inter_op_parallelism=False, so we need
+ // to set it to true if one of the ops have it set to true.
+ if (value_or_false(first_parallelism) || value_or_false(second_parallelism)) {
+ (*fused_node.mutable_attr())["use_inter_op_parallelism"].set_b(true);
+ }
return fused_node;
}
@@ -123,8 +131,8 @@ Status MapFusion::Optimize(Cluster* cluster, const GrapplerItem& item,
// fusion.
TF_RETURN_IF_ERROR(function_library.AddFunctionDef(*fused_function));
- // TODO(prazek): we could also remove map functions from library if they
- // are not used anymore.
+ // TODO(b/116285210): we could also remove map functions from library if
+ // they are not used anymore.
nodes_to_delete.insert(parent_map_node->name());
nodes_to_delete.insert(map_node->name());
}
diff --git a/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc
index b25dfbd0b8..8889f9dddd 100644
--- a/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -28,14 +29,7 @@ namespace tensorflow {
namespace grappler {
namespace {
-NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name) {
- return test::function::NDef(
- name, "MapDataset", {string(input_node_name)},
- {{"f", FunctionDefHelper::FunctionRef("XTimesTwo")},
- {"Targuments", {}},
- {"output_shapes", {}},
- {"output_types", {}}});
-}
+using graph_tests_utils::MakeMapNode;
TEST(MapFusionTest, FuseTwoMapNodesIntoOne) {
using test::function::NDef;
diff --git a/tensorflow/core/grappler/optimizers/data/map_parallelization.cc b/tensorflow/core/grappler/optimizers/data/map_parallelization.cc
index 305325e434..782c9f48b7 100644
--- a/tensorflow/core/grappler/optimizers/data/map_parallelization.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_parallelization.cc
@@ -84,9 +84,6 @@ Status MapParallelization::Optimize(Cluster* cluster, const GrapplerItem& item,
auto* parallel_map = graph.AddNode(MakeParallelMap(*map_node, &graph));
graph.ReplaceInput(*map_node, *parallel_map);
-
- // TODO(prazek): we could also remove map functions from library if they
- // are not used anymore.
nodes_to_delete.insert(map_node->name());
}
diff --git a/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc b/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc
index b2a5d9b6af..9fdfe8af30 100644
--- a/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc
@@ -19,8 +19,8 @@ limitations under the License.
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
-
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
@@ -28,16 +28,7 @@ namespace tensorflow {
namespace grappler {
namespace {
-NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name,
- StringPiece function_name) {
- return test::function::NDef(
- name, "MapDataset", {string(input_node_name)},
- {{"f", FunctionDefHelper::FunctionRef(string(function_name))},
- {"Targuments", {}},
- {"output_shapes", {}},
- {"output_types", {}}});
-}
-
+using graph_tests_utils::MakeMapNode;
const char stateless_fun_name[] = "XTimesTwo";
const char stateful_fun_name[] = "RandomUniform";
diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
index 7a2f1910da..a9254ed58b 100644
--- a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
@@ -35,10 +35,6 @@ namespace tensorflow {
namespace grappler {
namespace {
-void CopyAttribute(const string& attr_name, const NodeDef& from, NodeDef* to) {
- (*to->mutable_attr())[attr_name] = from.attr().at(attr_name);
-}
-
// Returns a FunctionDef containing a MapDefun op that wraps the original
// function.
FunctionDef* CreateMapDefunWrapper(const NodeDef& map_node,
@@ -48,7 +44,7 @@ FunctionDef* CreateMapDefunWrapper(const NodeDef& map_node,
// Function inputs and outputs are the same as original, just
// with different shapes.
*vectorized_func->mutable_signature() = orig_func.signature();
- graph_utils::SetUniqueGraphFunctionName("vectorized_function", library,
+ graph_utils::SetUniqueGraphFunctionName("naively_vectorized_fn", library,
vectorized_func);
// Add MapDefun node
@@ -61,7 +57,7 @@ FunctionDef* CreateMapDefunWrapper(const NodeDef& map_node,
for (const string& k : {"f", "output_types", "output_shapes"}) {
// Function, output types and (unbatched) shapes are the same as the
// original map node.
- CopyAttribute(k, map_node, map_defun_node);
+ graph_utils::CopyAttribute(k, map_node, map_defun_node);
}
// Get types of input arguments from original map function
@@ -71,6 +67,7 @@ FunctionDef* CreateMapDefunWrapper(const NodeDef& map_node,
map_defun_node->add_input(input.name());
}
(*map_defun_node->mutable_attr())["Targuments"] = t_args;
+ AddNodeAttr("Tcaptured", DataTypeVector(), map_defun_node);
// Set return values to match output names
string output_prefix = strings::StrCat(map_defun_node->name(), ":output:");
@@ -90,21 +87,19 @@ FunctionDef* AddVectorizedFunction(const NodeDef& map_node,
// efficient vectorization with VectorizeMapDefun.
FunctionDef* vectorized_func =
CreateMapDefunWrapper(map_node, orig_func, library);
- NodeDef* map_defun_node = vectorized_func->mutable_node_def()->Mutable(0);
- DCHECK_EQ(map_defun_node->op(), "MapDefun");
-
- // Create a copy of the original function so that we can mutate it, and
- // attach that to the map defun node.
- FunctionDef* map_defun_fn = library->add_function();
- *map_defun_fn = orig_func;
- graph_utils::SetUniqueGraphFunctionName(orig_func.signature().name(), library,
- map_defun_fn);
- (*map_defun_node->mutable_attr())["f"].mutable_func()->set_name(
- map_defun_fn->signature().name());
-
- vectorization_utils::VectorizeMapDefun(vectorized_func, map_defun_fn,
- map_defun_node);
- return vectorized_func;
+ const NodeDef& map_defun_node = vectorized_func->node_def(0);
+ DCHECK_EQ(map_defun_node.op(), "MapDefun");
+
+ // TODO(b/116285210): Unreferenced functions should get cleaned up later
+ FunctionDef* result;
+ Status s = vectorization_utils::VectorizeMapDefun(
+ *vectorized_func, map_defun_node, library, &result);
+
+ if (!s.ok()) {
+ LOG(ERROR) << "VectorizeMapDefun failed: " << s;
+ return vectorized_func;
+ }
+ return result;
}
bool IsOutputShapesFullyDefined(const NodeDef& node) {
@@ -195,13 +190,16 @@ NodeDef MakeNewMapNode(const NodeDef& old_map_node,
}
// Set attrs
- CopyAttribute("Targuments", old_map_node, &map_node);
+ graph_utils::CopyAttribute("Targuments", old_map_node, &map_node);
auto& func_attr = (*map_node.mutable_attr())["f"];
func_attr.mutable_func()->set_name(vectorized_func.signature().name());
for (auto key : {"output_shapes", "output_types"}) {
- CopyAttribute(key, old_batch_node, &map_node);
+ graph_utils::CopyAttribute(key, old_batch_node, &map_node);
}
+
+ (*map_node.mutable_attr())["use_inter_op_parallelism"].set_b(true);
+
return map_node;
}
diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc
index ed1bd6bc97..f4faf41549 100644
--- a/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc
@@ -30,72 +30,51 @@ namespace {
using test::function::GDef;
using test::function::NDef;
-void MakeTensorShapeProtoHelper(const gtl::ArraySlice<int> dims,
- TensorShapeProto* t) {
- for (size_t i = 0; i < dims.size(); ++i) {
- auto* d = t->add_dim();
- d->set_size(dims[i]);
- }
-}
-
-AttrValue MakeShapeListAttr(
- const gtl::ArraySlice<const gtl::ArraySlice<int>>& shapes) {
- AttrValue shapes_attr;
- for (size_t i = 0; i < shapes.size(); ++i) {
- MakeTensorShapeProtoHelper(shapes[i],
- shapes_attr.mutable_list()->add_shape());
- }
-
- return shapes_attr;
-}
-
-NodeDef MakeMapNodeHelper(
- StringPiece name, StringPiece input_node_name, StringPiece function_name,
- StringPiece map_op_name,
- const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes,
- const gtl::ArraySlice<DataType>& output_types) {
+NodeDef MakeMapNodeHelper(StringPiece name, StringPiece input_node_name,
+ StringPiece function_name, StringPiece map_op_name,
+ gtl::ArraySlice<PartialTensorShape> output_shapes,
+ gtl::ArraySlice<DataType> output_types) {
return test::function::NDef(
name, map_op_name, {string(input_node_name)},
{{"f", FunctionDefHelper::FunctionRef(string(function_name))},
{"Targuments", {}},
- {"output_shapes", MakeShapeListAttr(output_shapes)},
+ {"output_shapes", output_shapes},
{"output_types", output_types}});
}
-NodeDef MakeMapNode(
- StringPiece name, StringPiece input_node_name, StringPiece function_name,
- const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes,
- const gtl::ArraySlice<DataType>& output_types) {
+NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name,
+ StringPiece function_name,
+ gtl::ArraySlice<PartialTensorShape> output_shapes,
+ gtl::ArraySlice<DataType> output_types) {
return MakeMapNodeHelper(name, input_node_name, function_name, "MapDataset",
output_shapes, output_types);
}
-NodeDef MakeBatchNode(
- StringPiece name, StringPiece input_node_name,
- StringPiece input_batch_size_name,
- const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes,
- const gtl::ArraySlice<DataType>& output_types) {
- return NDef(name, "BatchDataset",
- {string(input_node_name), string(input_batch_size_name)},
- {{"output_types", output_types},
- {"output_shapes", MakeShapeListAttr(output_shapes)}});
+NodeDef MakeBatchNode(StringPiece name, StringPiece input_node_name,
+ StringPiece input_batch_size_name,
+ gtl::ArraySlice<PartialTensorShape> output_shapes,
+ gtl::ArraySlice<DataType> output_types) {
+ return NDef(
+ name, "BatchDataset",
+ {string(input_node_name), string(input_batch_size_name)},
+ {{"output_types", output_types}, {"output_shapes", output_shapes}});
}
-NodeDef MakeBatchV2Node(
- StringPiece name, StringPiece input_node_name,
- StringPiece input_batch_size_name, StringPiece input_drop_remainder_name,
- const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes,
- const gtl::ArraySlice<DataType>& output_types) {
- return NDef(name, "BatchDatasetV2",
- {string(input_node_name), string(input_batch_size_name),
- string(input_drop_remainder_name)},
- {{"output_types", output_types},
- {"output_shapes", MakeShapeListAttr(output_shapes)}});
+NodeDef MakeBatchV2Node(StringPiece name, StringPiece input_node_name,
+ StringPiece input_batch_size_name,
+ StringPiece input_drop_remainder_name,
+ gtl::ArraySlice<PartialTensorShape> output_shapes,
+ gtl::ArraySlice<DataType> output_types) {
+ return NDef(
+ name, "BatchDatasetV2",
+ {string(input_node_name), string(input_batch_size_name),
+ string(input_drop_remainder_name)},
+ {{"output_types", output_types}, {"output_shapes", output_shapes}});
}
-NodeDef MakeRangeNode(StringPiece name, const gtl::ArraySlice<string>& inputs) {
+NodeDef MakeRangeNode(StringPiece name, gtl::ArraySlice<string> inputs) {
return NDef(name, "RangeDataset", inputs,
- {{"output_shapes", MakeShapeListAttr({{}})},
+ {{"output_shapes", gtl::ArraySlice<TensorShape>({{}})},
{"output_types", gtl::ArraySlice<DataType>({DT_INT64})}});
}
@@ -184,7 +163,7 @@ TEST(MapVectorizationTest, VectorizeWithUndefinedOutputTypes) {
item.graph = GDef(
{NDef("batch_size", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
NDef("input", "InputDataset", {},
- {{"output_shapes", MakeShapeListAttr({{}})}}),
+ {{"output_shapes", gtl::ArraySlice<TensorShape>({{}})}}),
MakeMapNode("map", "input", "XTimesTwo", {{}}, {DT_INT32}),
MakeBatchNode("batch", "map", "batch_size", {{-1}}, {DT_INT32})},
// FunctionLib
@@ -196,6 +175,37 @@ TEST(MapVectorizationTest, VectorizeWithUndefinedOutputTypes) {
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
}
+TEST(MapVectorizationTest, VectorizeWithFullyDefinedFunction) {
+ GrapplerItem item;
+ item.graph = GDef(
+ {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
+ NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
+ NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+ NDef("batch_size", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+ MakeRangeNode("range", {"start", "stop", "step"}),
+ MakeMapNode("map", "range", "Func", {{}}, {DT_INT32}),
+ MakeBatchNode("batch", "map", "batch_size", {{-1}}, {DT_INT32})},
+ // FunctionLib
+ {FunctionDefHelper::Create(
+ "Func", {"x: int64", "y: int64"}, {"res: int64", "res2: int64"}, {},
+ {{{"o"}, "Mul", {"x", "x"}, {{"T", DT_INT64}}}},
+ {{"res", "o:z"}, {"res2", "o:z"}})});
+ MapVectorization optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ EXPECT_EQ(graph_utils::FindAllGraphNodesWithOp("MapDataset", output).size(),
+ 1);
+ EXPECT_EQ(graph_utils::FindAllGraphNodesWithOp("BatchDataset", output).size(),
+ 1);
+ const NodeDef& map_node =
+ output.node(graph_utils::FindGraphNodeWithOp("MapDataset", output));
+ const NodeDef& batch_node =
+ output.node(graph_utils::FindGraphNodeWithOp("BatchDataset", output));
+ EXPECT_EQ(map_node.input(0), batch_node.name());
+ EXPECT_EQ(batch_node.input(0), "range");
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc
index cb0ff670e8..99c4afa634 100644
--- a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc
+++ b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc
@@ -64,7 +64,7 @@ Status ShuffleAndRepeatFusion::Optimize(Cluster* cluster,
// Set `output_types` and `output_shapes` attributes.
for (auto key : {"output_shapes", "output_types"}) {
- (*new_node.mutable_attr())[key] = repeat_node.attr().at(key);
+ graph_utils::CopyAttribute(key, repeat_node, &new_node);
}
return new_node;
};
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/BUILD b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD
index 1462cb234d..985d6c6c3a 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD
@@ -9,13 +9,24 @@ load("//tensorflow/core:platform/default/build_config.bzl", "tf_protos_all")
VECTORIZER_DEPS = [
":vectorizer_registry",
- "//tensorflow/core/grappler/optimizers/data:function_utils",
+ "//tensorflow/core/grappler/optimizers/data:graph_utils",
] + tf_protos_all()
cc_library(
+ name = "wrapped_tensor",
+ hdrs = ["wrapped_tensor.h"],
+ deps = [
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
name = "vectorizer",
hdrs = ["vectorizer.h"],
deps = [
+ ":wrapped_tensor",
+ "//tensorflow/core:core_cpu",
"//tensorflow/core:lib",
] + tf_protos_all(),
)
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc b/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc
index c1739737a0..f445157531 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc
@@ -14,41 +14,38 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/node_def.pb.h"
-#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
namespace tensorflow {
namespace grappler {
-namespace vectorization_utils {
+namespace {
class CastVectorizer : public Vectorizer {
public:
- Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs,
- FunctionDef* outer_scope,
- std::map<string, string>* conversion_map) override {
- if (inputs.size() != 1) {
+ Status Vectorize(const Node& node, Graph* outer_scope,
+ std::vector<WrappedTensor>&& inputs,
+ std::vector<WrappedTensor>* outputs) override {
+ Status s;
+ if (node.num_inputs() != 1) {
return errors::Internal("Cast op should only have one input.");
}
- // Add new Cast node
- NodeDef* new_cast_node = outer_scope->add_node_def();
- *new_cast_node = node;
- new_cast_node->clear_name();
- function_utils::SetUniqueFunctionNodeName(
- strings::StrCat("vectorized/", node.name()), outer_scope,
- new_cast_node);
- new_cast_node->set_input(0, inputs[0]);
+ // Add new Cast node with the same op and attrs as the original node
+ auto new_cast_node = outer_scope->AddNode(node.def(), &s);
+ TF_RETURN_IF_ERROR(s);
- // Add the output mapping to conversion map
- (*conversion_map)[strings::StrCat(node.name(), ":y:0")] =
- strings::StrCat(new_cast_node->name(), ":y:0");
+ outer_scope->AddEdge(inputs[0].node, inputs[0].output_index, new_cast_node,
+ 0);
+ // Add output mappings
+ outputs->push_back({new_cast_node, 0, true});
return Status::OK();
}
};
REGISTER_VECTORIZER("Cast", CastVectorizer);
-} // namespace vectorization_utils
+} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc b/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc
index 776d3179c5..f1ba741821 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc
@@ -14,40 +14,38 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/node_def.pb.h"
-#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
namespace tensorflow {
namespace grappler {
-namespace vectorization_utils {
+namespace {
class UnpackVectorizer : public Vectorizer {
public:
- Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs,
- FunctionDef* outer_scope,
- std::map<string, string>* conversion_map) override {
- if (inputs.size() != 1) {
+ Status Vectorize(const Node& node, Graph* outer_scope,
+ std::vector<WrappedTensor>&& inputs,
+ std::vector<WrappedTensor>* outputs) override {
+ Status s;
+ if (node.num_inputs() != 1 || inputs.size() != 1) {
return errors::Internal("Unpack op should only have one input.");
}
- // Add new Unpack node
- NodeDef* new_unpack_node = outer_scope->add_node_def();
- *new_unpack_node = node;
- new_unpack_node->clear_name();
- function_utils::SetUniqueFunctionNodeName(
- strings::StrCat("vectorized/", node.name()), outer_scope,
- new_unpack_node);
+ // Add new Unpack node with the same op and attrs as the original node
+ auto new_unpack_node = outer_scope->AddNode(node.def(), &s);
+ TF_RETURN_IF_ERROR(s);
// Increment "axis" attr by 1:
- (*new_unpack_node->mutable_attr())["axis"].set_i(
- node.attr().at("axis").i() + 1);
- new_unpack_node->set_input(0, inputs[0]);
+ int new_axis = node.def().attr().at("axis").i() + 1;
+ new_unpack_node->AddAttr("axis", new_axis);
- // Add the output mappings to conversion map
- int num = new_unpack_node->attr().at("num").i();
+ outer_scope->AddEdge(inputs[0].node, inputs[0].output_index,
+ new_unpack_node, 0);
+
+ // Add the output mappings
+ int num = node.def().attr().at("num").i();
for (int i = 0; i < num; ++i) {
- (*conversion_map)[strings::StrCat(node.name(), ":output:", i)] =
- strings::StrCat(new_unpack_node->name(), ":output:", i);
+ outputs->push_back({new_unpack_node, i, true});
}
return Status::OK();
@@ -56,6 +54,6 @@ class UnpackVectorizer : public Vectorizer {
REGISTER_VECTORIZER("Unpack", UnpackVectorizer);
-} // namespace vectorization_utils
+} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h
index d341dbba7d..8d4676aae0 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h
@@ -17,13 +17,13 @@ limitations under the License.
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_H_
#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace tensorflow {
namespace grappler {
-namespace vectorization_utils {
// Interface for vectorization of TensorFlow operations. See `CastVectorizer`
// for an example.
@@ -31,19 +31,19 @@ class Vectorizer {
public:
virtual ~Vectorizer() {}
- // Vectorizes an operation, `node`, by adding operation(s) to `outer_scope`
+ // Vectorizes an operation, `node`, by adding Node(s) to `outer_scope`
// that produce the same vector output(s) as executing `node`'s op
- // on elements of the vector inputs, and adding mappings to `conversion_map`
- // from old output tensor names to new (vectorized) output tensor names.
- // The new node(s) collectively have the same number of inputs and outputs as
- // the node being converted, and use the tensor names in `inputs` as their
- // inputs.
- virtual Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs,
- FunctionDef* outer_scope,
- std::map<string, string>* conversion_map) = 0;
+ // on elements of `inputs`. The new Node(s) collectively have the
+ // same number of input and output ports as the node being converted.
+ // Adds edges between the newly created nodes and nodes in `inputs`, and adds
+ // mappings from the new nodes' output ports to `outputs`, where the i'th
+ // value in `outputs` corresponds to the i'th output port of the node
+ // to be converted.
+ virtual Status Vectorize(const Node& node, Graph* outer_scope,
+ std::vector<WrappedTensor>&& inputs,
+ std::vector<WrappedTensor>* outputs) = 0;
};
-} // namespace vectorization_utils
} // namespace grappler
} // namespace tensorflow
#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_H_
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc
index a6551e36ac..e1cf77a7d5 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc
@@ -19,7 +19,6 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
-namespace vectorization_utils {
VectorizerRegistry* VectorizerRegistry::Global() {
static VectorizerRegistry* registry = new VectorizerRegistry;
@@ -42,6 +41,5 @@ void VectorizerRegistry::Register(const string& op_type,
vectorizers_.insert(std::pair<const string&, std::unique_ptr<Vectorizer>>(
op_type, std::move(vectorizer)));
}
-} // namespace vectorization_utils
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h
index 16159d47ca..ad54c74933 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h
@@ -23,7 +23,6 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
-namespace vectorization_utils {
// A global VectorizerRegistry is used to hold all the vectorizers.
class VectorizerRegistry {
@@ -59,16 +58,12 @@ class VectorizerRegistration {
#define REGISTER_VECTORIZER_UNIQ_HELPER(ctr, op_type, vectorizer) \
REGISTER_VECTORIZER_UNIQ(ctr, op_type, vectorizer)
-#define REGISTER_VECTORIZER_UNIQ(ctr, op_type, vectorizer) \
- static ::tensorflow::grappler::vectorization_utils:: \
- vectorizer_registration::VectorizerRegistration \
- vectorizer_registration_##ctr( \
- op_type, \
- ::std::unique_ptr< \
- ::tensorflow::grappler::vectorization_utils::Vectorizer>( \
- new vectorizer()))
+#define REGISTER_VECTORIZER_UNIQ(ctr, op_type, vectorizer) \
+ static ::tensorflow::grappler::vectorizer_registration:: \
+ VectorizerRegistration vectorizer_registration_##ctr( \
+ op_type, ::std::unique_ptr<::tensorflow::grappler::Vectorizer>( \
+ new vectorizer()))
-} // namespace vectorization_utils
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc
index 86e303564b..054aeb9a8f 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc
@@ -20,13 +20,12 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
-namespace vectorization_utils {
class TestVectorizer : public Vectorizer {
public:
- Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs,
- FunctionDef* outer_scope,
- std::map<string, string>* conversion_map) override {
+ Status Vectorize(const Node& node, Graph* outer_scope,
+ std::vector<WrappedTensor>&& inputs,
+ std::vector<WrappedTensor>* outputs) override {
return Status::OK();
}
};
@@ -39,12 +38,14 @@ TEST(TestVectorizer, TestTestVectorizer) {
auto vectorizer = VectorizerRegistry::Global()->Get("test_op");
EXPECT_NE(vectorizer, nullptr);
- FunctionDef function;
- NodeDef node;
- std::map<string, string> conversion_map;
- EXPECT_TRUE(vectorizer->Vectorize(node, {}, &function, &conversion_map).ok());
+ Graph g(OpRegistry::Global());
+ NodeDef node_def;
+ Status s;
+ Node* node = g.AddNode(node_def, &s);
+ std::vector<WrappedTensor> inputs, outputs;
+ EXPECT_TRUE(
+ vectorizer->Vectorize(*node, &g, std::move(inputs), &outputs).ok());
}
-} // namespace vectorization_utils
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h b/tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h
new file mode 100644
index 0000000000..4439b4ab4e
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h
@@ -0,0 +1,44 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_WRAPPED_TENSOR_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_WRAPPED_TENSOR_H_
+
+#include "tensorflow/core/graph/graph.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// Represents a tensor that has been vectorized.
+struct WrappedTensor {
+ Node* const node;
+ const int output_index;
+
+ // Whether the tensor is stacked, i.e. represents the results of applying
+ // the operation on all slices of the input, where each row i of the
+ // tensor corresponds to the op's output on slice i of the input. False
+ // if the tensor is not stacked, i.e. represents the result of the op on
+ // a single slice of the input, where the result does not vary between
+ // slices.
+ bool stacked;
+
+ WrappedTensor(Node* node, int output_index, bool stacked)
+ : node(node), output_index(output_index), stacked(stacked) {}
+};
+
+} // namespace grappler
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_WRAPPED_TENSOR_H_
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
index cb56b65985..ba857ab5d9 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
@@ -17,274 +17,588 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
#include "absl/strings/str_join.h"
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/functions.h"
#include "tensorflow/core/lib/gtl/map_util.h"
-#include "tensorflow/core/lib/strings/scanner.h"
namespace tensorflow {
namespace grappler {
namespace vectorization_utils {
-using function_utils::FunctionDefTensorDesc;
-
namespace {
-void AddMapDefunOutput(FunctionDef* map_defun_fn, NodeDef* map_defun_node,
- const string& output_retval, const DataType t) {
- // Set to unknown shape
- TensorShapeProto tensor_shape_proto;
- PartialTensorShape().AsProto(&tensor_shape_proto);
+// Describes a tensor with its operation Node and output position
+typedef std::pair<Node*, int> TensorDesc;
- function_utils::AddFunctionOutputWithUniqueName(
- "vectorized_out", output_retval, map_defun_fn, t);
+const char* const kRetValOp = "_Retval";
- *(*map_defun_node->mutable_attr())["output_shapes"]
- .mutable_list()
- ->add_shape() = tensor_shape_proto;
- (*map_defun_node->mutable_attr())["output_types"].mutable_list()->add_type(t);
+void ReplaceEdgeSources(const TensorDesc& old_src, const TensorDesc& new_src,
+ Graph* graph) {
+ // NOTE: We need two for loops here because we can't mutate the set of output
+ // edges as we iterate over them.
+ std::vector<const Edge*> edges_to_replace;
+ for (auto edge : old_src.first->out_edges()) {
+ if (edge->src_output() == old_src.second) {
+ edges_to_replace.push_back(edge);
+ }
+ }
+ for (auto edge : edges_to_replace) {
+ graph->AddEdge(new_src.first, new_src.second, edge->dst(),
+ edge->dst_input());
+ graph->RemoveEdge(edge);
+ }
}
-void RemoveMapDefunOutput(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
- NodeDef* map_defun_node, int output_position) {
- DCHECK_LT(output_position, map_defun_fn->signature().output_arg_size())
- << "Trying to remove output that doesn't exist. Output number: "
- << output_position;
+Status AddMapDefunOutput(FunctionBody* map_defun_fn, Node* map_defun_node,
+ const TensorDesc& output) {
+ // Note that we don't update MapDefun attrs as we go, only when we are done
+ DataType type = output.first->output_type(output.second);
+ int index = map_defun_fn->ret_nodes.size();
- int num_later_outputs =
- map_defun_fn->signature().output_arg_size() - output_position - 1;
+ NodeDef ret_node_def;
+ ret_node_def.set_name("map_out");
+ ret_node_def.set_op(kRetValOp);
+ AddNodeAttr("T", type, &ret_node_def);
+ AddNodeAttr("index", index, &ret_node_def);
- // Remove from map_defun_fn's ret dict and output args
- map_defun_fn->mutable_ret()->erase(
- map_defun_fn->signature().output_arg(output_position).name());
- map_defun_fn->mutable_signature()->mutable_output_arg()->DeleteSubrange(
- output_position, 1);
+ Status s;
+ Node* ret_node = map_defun_fn->graph->AddNode(ret_node_def, &s);
+ TF_RETURN_IF_ERROR(s);
- // Renumber outputs that come after
- for (int i = 0; i < num_later_outputs; ++i) {
- function_utils::ReplaceReferences(
- strings::StrCat(map_defun_node->name(),
- ":output:", output_position + i + 1),
- strings::StrCat(map_defun_node->name(),
- ":output:", output_position + i),
- outer_scope);
- }
- map_defun_node->mutable_attr()
- ->at("output_shapes")
- .mutable_list()
- ->mutable_shape()
- ->DeleteSubrange(output_position, 1);
- map_defun_node->mutable_attr()
- ->at("output_types")
- .mutable_list()
- ->mutable_type()
- ->ExtractSubrange(output_position, 1, nullptr);
+ map_defun_fn->graph->AddEdge(output.first, output.second, ret_node, 0);
+ map_defun_fn->ret_nodes.push_back(ret_node);
+ map_defun_fn->ret_types.push_back(type);
+
+ return s;
}
-int FindOutputToConvert(const FunctionDef& function,
- const std::set<string>& unconvertible,
- FunctionDefTensorDesc* f) {
- for (int i = function.signature().output_arg_size() - 1; i >= 0; --i) {
- const string& ret_key = function.signature().output_arg(i).name();
- *f = FunctionDefTensorDesc(function.ret().at(ret_key));
+void RemoveMapDefunOutput(int output_position, Graph* outer_scope,
+ FunctionBody* map_defun_fn, Node* map_defun_node) {
+ // Note that we don't update MapDefun attrs as we go, only when we are done
+ DCHECK_LT(output_position, map_defun_fn->ret_nodes.size())
+ << "Trying to remove output that doesn't exist. Output number: "
+ << output_position;
- if (unconvertible.find(f->node_name) == unconvertible.end()) {
- return i;
- }
+ int num_later_outputs = map_defun_fn->ret_nodes.size() - output_position - 1;
+
+ // Modify map_defun_fn's signature and remove the output node from its graph
+ map_defun_fn->graph->RemoveNode(map_defun_fn->ret_nodes[output_position]);
+ map_defun_fn->ret_nodes.erase(map_defun_fn->ret_nodes.begin() +
+ output_position);
+ map_defun_fn->ret_types.erase(map_defun_fn->ret_types.begin() +
+ output_position);
+
+ // Renumber the nodes and edges that come after
+ for (int i = 0; i < num_later_outputs; ++i) {
+ ReplaceEdgeSources({map_defun_node, output_position + i + 1},
+ {map_defun_node, output_position + i}, outer_scope);
+ // Each ret node has an "index" attr that has to be updated
+ map_defun_fn->ret_nodes[output_position + i]->AddAttr("index",
+ output_position + i);
}
- return -1;
}
// Helper class that vectorizes the body of a MapDefun node, adding new
// operations to the graph that collectively compute the same value as what
// running the MapDefun function on slices of the input would produce.
-// Each instance of the class encapsulates all the data necessary to vectorize a
-// MapDefun op in place.
+// This class transforms the input FunctionDefs into their corresponding
+// Graph objects and works on the graphs directly, then converts them back
+// to FunctionDefs when GetResult is called.
class Vectorization {
public:
- Vectorization(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
- NodeDef* map_defun_node)
- : outer_scope_(outer_scope),
- map_defun_fn_(map_defun_fn),
- map_defun_node_(map_defun_node) {}
+ explicit Vectorization(FunctionDefLibrary* lib)
+ : lib_(lib), lib_def_(OpRegistry::Global(), *lib) {}
- // Repeatedly tries to convert outputs of map_defun_fn_ into new nodes in
- // the outer_scope_, until there are no convertible outputs remaining.
- // This method is idempotent.
- void Vectorize();
+ // Adds the vectorized function and new map_defun_fn to lib, and points
+ // `*result` to the former. Returns an error status if
+ // the conversion between FunctionDef -> Graph -> FunctionDef failed anywhere
+ // along the way.
+ Status Vectorize(const FunctionDef& outer_scope,
+ const NodeDef& map_defun_node, FunctionDef** result);
private:
- // Vectorizes the map defun function's output at output_position
- Status ConvertOutput(int output_position, const FunctionDefTensorDesc& desc);
- // Given a descriptor of the original output tensor, gets a string
- // corresponding to the converted output tensor.
- Status ConvertOutputHelper(const FunctionDefTensorDesc& output_desc,
- string* converted);
- Status AddConversionMappingFromInput(
- const FunctionDefTensorDesc& output_desc);
+ // Converts FunctionDefs to Graphs and adds mappings from
+ // arg nodes and unstacked nodes to the corresponding nodes in outer_scope_.
+ Status Initialize(const FunctionDef& outer_scope,
+ const NodeDef& map_defun_node);
+
+ // Converts Graphs back to FunctionDefs and adds them to `lib_`.
+ Status GetResult(FunctionDef** vectorized_function);
+
+ // Repeatedly tries to convert outputs of `map_defun_fn_` into new nodes in
+ // `outer_scope_`, until there are no convertible outputs remaining.
+ void VectorizeHelper();
+
+ // Vectorizes map_defun_fn's output at output_position.
+ Status ConvertOutput(int output_position);
// Adds mappings from node's output tensors to converted output tensors,
// creating the necessary new node(s). Generally, the steps to convert an op
// are:
- // 1) Promote the inputs of the op inputs to outputs of the map_defun_fn_,
- // and modify map_defun_node_ attrs accordingly
- // 2) Create new node(s) in outer_scope_ that act on batched input tensors.
+ // 1) Create new node(s) in `outer_scope_` that act on batched input tensors.
// These operations collectively compute the same value as what running
// the original operation on slices of the input tensors would produce.
// For example, a Cast op in MapDefun translates to a Cast op in
- // outer_scope_, since the vectorized version of Cast is itself.
- // 3) Set inputs of new node(s) to the corresponding converted inputs (that
- // are now outputs of map_defun_node_)
- // 4) For each output of the old node, add the mapping of output strings to
- // the conversion map (eg "Cast:y:0" -> "Vectorize/Cast:y:0")
- Status AddConversionMappingFromOp(const NodeDef& node,
- const FunctionDefTensorDesc& output_desc);
-
- // Maps a tensor name to the name of the corresponding vectorized tensor. For
- // example, "Cast:y:0" -> "Vectorize/Cast:y:0"
- std::map<string, string> conversion_map_;
- // Unconvertible node names
- std::set<string> unconvertible_;
-
- FunctionDef* outer_scope_;
- FunctionDef* map_defun_fn_;
- NodeDef* map_defun_node_;
+ // `outer_scope_`, since the vectorized version of Cast is itself.
+  //      2) Promote the inputs of the op to outputs of the
+  //      `map_defun_node_` and `map_defun_fn_`.
+  //      3) Add edges between the promoted inputs (that are now outputs of
+  //      `map_defun_node_`) and the input ports of the new node(s).
+ // 4) For each output of the old node, add the mapping of output tensors to
+ // the conversion map.
+ Status AddConversionMapping(Node* op_node);
+
+ // Given a tensor t in `unstacked`, stacks it by doing the equivalent of
+ // tf.tile(tf.expand_dims(t, 0), [n, 1, 1, ...]) where n is dimension 0 of
+ // inputs to `map_defun_node_`. This stacked tensor will be compatible with
+ // the expected output shape of `map_defun_node_`.
+ // This is equivalent to the _stack function in python Pfor.
+ Status StackTensor(WrappedTensor* unstacked, TensorDesc* result);
+
+ // Recursively looks for unstacked nodes in the `map_defun_fn_` graph by
+ // doing a depth-first search from the ret nodes. Lifts nodes that are
+ // unstacked (i.e. don't derive from arg nodes) into `outer_scope_` directly
+  // and adds mappings to `conversion_map_`.
+ Status AddUnstackedNodeMappings();
+
+ // Recursive helper for `AddUnstackedNodeMappings`, returns true if tensor
+ // is unstacked.
+ bool AddUnstackedNodeMappingsHelper(TensorDesc&& tensor, Status* status);
+
+  // Adds mappings from `map_defun_fn_` arg nodes to the corresponding
+  // `map_defun_node_` input nodes in `conversion_map_`.
+ Status AddArgNodeMappings();
+
+ // Maps a tensor to the corresponding WrappedTensor. For example,
+  // {"Cast" Node*, 0} -> WrappedTensor("Vectorize/Cast" Node*, 0, true)
+ std::map<TensorDesc, WrappedTensor> conversion_map_;
+
+ // Unconvertible ret nodes
+ std::set<Node*> unconvertible_;
+
+ FunctionDefLibrary* lib_; // Not owned
+ FunctionLibraryDefinition lib_def_;
+ // Note that FunctionBody has a pointer to a Graph object that corresponds
+ // to the function's subgraph, with additional kArgOp and kRetValOp nodes
+  // that denote the function arguments and return values. These nodes have the
+ // attrs "T" for the type, and "index" for the argument / retval index
+ // respectively. FunctionBody also keeps track of arg/ret_nodes and
+ // arg/ret_types, that should be ordered according to argument/output indices.
+ std::unique_ptr<Graph> outer_scope_;
+ std::unique_ptr<FunctionBody> map_defun_fn_;
+ Node* map_defun_node_ = nullptr; // Owned by `outer_scope`
+
+ // Caches the loop_len_node_ needed for tiling unstacked output. This
+ // corresponds to a vector with one element.
+ Node* loop_len_node_ = nullptr; // Owned by `outer_scope`
+ Status status_;
};
-Status Vectorization::AddConversionMappingFromOp(
- const NodeDef& node, const FunctionDefTensorDesc& output_desc) {
- for (const string& input_name : node.input()) {
- if (IsControlInput(input_name)) {
+Status Vectorization::AddConversionMapping(Node* op_node) {
+ for (auto edge : op_node->in_edges()) {
+ if (edge->IsControlEdge()) {
return errors::InvalidArgument(
"Vectorizing outputs with control inputs is currently not "
"supported.");
}
}
- // TODO(rachelim): Have some mechanism for registering converters and some
- // uniform, simpler way to represent them.
+ auto vectorizer = VectorizerRegistry::Global()->Get(op_node->type_string());
+ if (vectorizer == nullptr) {
+ return errors::Unimplemented("No vectorizer registered for op: ",
+ op_node->type_string());
+ }
+ std::vector<WrappedTensor> inputs, outputs;
+ inputs.reserve(op_node->num_inputs());
+ outputs.reserve(op_node->num_outputs());
+
+ std::vector<const Edge*> input_edges;
+ TF_RETURN_IF_ERROR(op_node->input_edges(&input_edges));
+
+ // The inputs for the node to be converted may already have been converted
+ // themselves. For those that are not, we promote them to MapDefun outputs.
+ for (size_t i = 0; i < op_node->num_inputs(); ++i) {
+ auto edge = input_edges[i];
+ if (auto found = gtl::FindOrNull(conversion_map_,
+ {edge->src(), edge->src_output()})) {
+ inputs.push_back(*found);
+ } else {
+ // TODO(rachelim): Handle the case where unconverted inputs are unstacked.
+ // We assume that all unconverted inputs will be stacked, since we
+ // converted all unstacked nodes in `Initialize`. However, it's actually
+ // possible that yet-unconverted nodes may produce unstacked outputs after
+ // they are vectorized. (For example, see the "Shape" converter in
+ // tensorflow/python/ops/parallel_for/pfor.py). If a vectorizer expects
+ // an unstacked input but receives a stacked one, vectorizer->Vectorize
+ // will return an error.
+ TF_RETURN_IF_ERROR(AddMapDefunOutput(map_defun_fn_.get(), map_defun_node_,
+ {edge->src(), edge->src_output()}));
+ int output_index = map_defun_fn_->ret_nodes.size() - 1;
+ inputs.push_back({map_defun_node_, output_index, true});
+ }
+ }
- DataTypeVector types;
- const OpDef* op_def = nullptr;
- TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node.op(), &op_def));
- TF_RETURN_IF_ERROR(InputTypesForNode(node, *op_def, &types));
+ TF_RETURN_IF_ERROR(vectorizer->Vectorize(*op_node, outer_scope_.get(),
+ std::move(inputs), &outputs));
- std::vector<string> promoted_inputs;
- promoted_inputs.reserve(node.input_size());
- for (int i = 0; i < node.input_size(); ++i) {
- promoted_inputs.push_back(strings::StrCat(
- map_defun_node_->name(),
- ":output:", map_defun_fn_->signature().output_arg_size() + i));
+ if (op_node->num_outputs() != outputs.size()) {
+ return errors::Internal(
+ "Number of vectorizer outputs does not match. Expected: ",
+ op_node->num_outputs(), " Actual: ", outputs.size());
}
- auto vectorizer = VectorizerRegistry::Global()->Get(node.op());
- if (vectorizer == nullptr) {
- return errors::Unimplemented("No vectorizer registered for op: ",
- node.op());
+ // Add output mappings.
+ for (size_t i = 0; i < op_node->num_outputs(); ++i) {
+ conversion_map_.insert({{op_node, i}, outputs[i]});
}
- TF_RETURN_IF_ERROR(vectorizer->Vectorize(node, promoted_inputs, outer_scope_,
- &conversion_map_));
+ return Status::OK();
+}
+
+Status Vectorization::ConvertOutput(int output_position) {
+ // ret_edge->src() is the actual op that generated the retval, and
+ // ret_edge->dst() is the retval node whose op is "_Retval"
+ const Edge* ret_edge;
+ TF_RETURN_IF_ERROR(
+ map_defun_fn_->ret_nodes[output_position]->input_edge(0, &ret_edge));
+
+ TensorDesc output({ret_edge->src(), ret_edge->src_output()});
+ TensorDesc converted_output;
+
+ // It's possible the output already has a mapping, if it comes from a node
+ // that has already been converted.
+ auto found = gtl::FindOrNull(conversion_map_, output);
+ if (!found) {
+ TF_RETURN_IF_ERROR(AddConversionMapping(output.first));
+ found = &conversion_map_.at(output);
+ }
- // If we get here, the conversion was successful, so we promote the inputs
- // of the ops to MapDefun outputs.
- for (int i = 0; i < types.size(); ++i) {
- AddMapDefunOutput(map_defun_fn_, map_defun_node_, node.input(i), types[i]);
+ if (found->stacked) {
+ converted_output = {found->node, found->output_index};
+ } else {
+ // Some outputs may be unstacked if they don't derive from arg nodes
+ // (for example, if a function returns a constant). For these, we
+ // have to add extra nodes to tile it in the 0th dimension.
+ TF_RETURN_IF_ERROR(StackTensor(found, &converted_output));
}
+ ReplaceEdgeSources({map_defun_node_, output_position}, converted_output,
+ outer_scope_.get());
+ RemoveMapDefunOutput(output_position, outer_scope_.get(), map_defun_fn_.get(),
+ map_defun_node_);
+
return Status::OK();
}
-Status Vectorization::AddConversionMappingFromInput(
- const FunctionDefTensorDesc& output_desc) {
- int input_index = function_utils::FindFunctionInputWithName(
- output_desc.node_name, *map_defun_fn_);
- if (input_index == -1) {
- return errors::Internal("Cannot convert non-existent input.");
+Status Vectorization::Vectorize(const FunctionDef& outer_scope,
+ const NodeDef& map_defun_node,
+ FunctionDef** result) {
+ TF_RETURN_IF_ERROR(Initialize(outer_scope, map_defun_node));
+ VectorizeHelper();
+ return GetResult(result);
+}
+
+void Vectorization::VectorizeHelper() {
+ while (true) {
+ int output_position = graph_utils::GetFirstElementIndexWithPredicate(
+ [this](Node* n) {
+ return this->unconvertible_.find(n) == this->unconvertible_.end();
+ },
+ map_defun_fn_->ret_nodes);
+
+ // No outputs left to convert
+ if (output_position == -1) break;
+
+ Status s = ConvertOutput(output_position);
+ if (!s.ok()) {
+ Node* output_node = map_defun_fn_->ret_nodes.at(output_position);
+ VLOG(2) << "Could not convert the output at node: "
+ << output_node->DebugString() << "\nError: " << s;
+ unconvertible_.insert(output_node);
+ }
+ }
+
+ // If we've converted all the outputs of the MapDefun function, we no longer
+ // need the MapDefun node and can delete it.
+ if (map_defun_fn_->ret_nodes.empty()) {
+ outer_scope_->RemoveNode(map_defun_node_);
+ } else {
+ // Update MapDefun node attrs accordingly
+ DCHECK_EQ(map_defun_fn_->ret_types.size(), map_defun_fn_->ret_nodes.size());
+ map_defun_node_->AddAttr(
+ "output_shapes",
+ std::vector<PartialTensorShape>(map_defun_fn_->ret_types.size()));
+ map_defun_node_->AddAttr("output_types", map_defun_fn_->ret_types);
+ }
+}
+
+Status Vectorization::Initialize(const FunctionDef& outer_scope,
+ const NodeDef& map_defun_node) {
+ // Convert outer_scope and map_defun_fn to FunctionBodys so we can
+ // work on Graphs directly.
+ const FunctionDef* map_defun_fn =
+ lib_def_.Find(map_defun_node.attr().at("f").func().name());
+
+ if (map_defun_fn == nullptr) {
+ return errors::NotFound("Could not find function with name ",
+ map_defun_node.attr().at("f").func().name(),
+ " in function library.");
+ }
+
+ auto get_func_sig = [this](const string& op, const OpDef** sig) {
+ return this->lib_def_.LookUpOpDef(op, sig);
+ };
+
+ FunctionBody* outer_fn;
+ TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(outer_scope, {}, &lib_def_,
+ get_func_sig, &outer_fn));
+ // We don't need outer_fn, just the graph
+ outer_scope_.reset(outer_fn->graph);
+ outer_fn->graph = nullptr;
+ delete outer_fn;
+
+ FunctionBody* tmp;
+ TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*map_defun_fn, {}, &lib_def_,
+ get_func_sig, &tmp));
+ map_defun_fn_.reset(tmp);
+
+ // Find the MapDefun node in outer_scope_
+ int node_id = graph_utils::GetFirstElementIndexWithPredicate(
+ [&map_defun_node](Node* n) { return n->name() == map_defun_node.name(); },
+ outer_scope_->nodes());
+ if (node_id == -1) {
+ return errors::NotFound("Could not find node with name ",
+ map_defun_node.name(), " in outer_scope.");
}
+ map_defun_node_ = outer_scope_->FindNodeId(node_id);
+
+ TF_RETURN_IF_ERROR(AddArgNodeMappings());
+
+ TF_RETURN_IF_ERROR(AddUnstackedNodeMappings());
+ loop_len_node_ = nullptr;
- conversion_map_[output_desc.full_str] = map_defun_node_->input(input_index);
return Status::OK();
}
-Status Vectorization::ConvertOutputHelper(
- const FunctionDefTensorDesc& output_desc, string* converted) {
- // It's possible the output already has a mapping, if it comes from a node
- // that has already been converted.
- if (auto found = gtl::FindOrNull(conversion_map_, output_desc.full_str)) {
- *converted = *found;
- return Status::OK();
+// TODO(rachelim): It might be profitable to use the C++ API for this instead of
+// NodeBuilder
+Status Vectorization::StackTensor(WrappedTensor* unstacked,
+ TensorDesc* result) {
+ // Note that all these nodes are necessary as the size of the batch may not be
+ // constant.
+ if (unstacked->stacked) {
+ return errors::Internal("Can only stack unstacked tensor.");
}
- int index = function_utils::FindFunctionNodeWithName(output_desc.node_name,
- *map_defun_fn_);
- if (index == -1) { // The output comes from an input
- TF_RETURN_IF_ERROR(AddConversionMappingFromInput(output_desc));
- } else {
- TF_RETURN_IF_ERROR(AddConversionMappingFromOp(
- map_defun_fn_->node_def(index), output_desc));
+ Graph* g = outer_scope_.get();
+ auto node_builder = [](StringPiece op) {
+ return NodeBuilder(strings::StrCat("vectorized/stack/", op), op);
+ };
+
+ auto make_const = [&node_builder](const Input::Initializer& val, Graph* graph,
+ Node** result) {
+ TF_RETURN_IF_ERROR(val.status);
+ return node_builder("Const")
+ .Attr("value", val.tensor)
+ .Attr("dtype", val.tensor.dtype())
+ .Finalize(graph, result);
+ };
+
+ // If loop_len_node_ hasn't been created yet, add the node and cache it.
+ if (loop_len_node_ == nullptr) {
+ Node* input_node;
+ TF_RETURN_IF_ERROR(map_defun_node_->input_node(0, &input_node));
+
+ Node* shape_node;
+ TF_RETURN_IF_ERROR(
+ node_builder("Shape").Input(input_node).Finalize(g, &shape_node));
+
+ Node* const_vec_0;
+ TF_RETURN_IF_ERROR(make_const({0}, g, &const_vec_0));
+ Node* const_vec_1;
+ TF_RETURN_IF_ERROR(make_const({1}, g, &const_vec_1));
+
+ Node* strided_slice_node;
+ TF_RETURN_IF_ERROR(node_builder("StridedSlice")
+ .Input(shape_node) // input
+ .Input(const_vec_0) // begin
+ .Input(const_vec_1) // end
+ .Input(const_vec_1) // strides
+ .Finalize(g, &strided_slice_node));
+
+ // Produces a vector of length 1
+ TF_RETURN_IF_ERROR(node_builder("Reshape")
+ .Input(strided_slice_node) // tensor
+ .Input(const_vec_1) // shape
+ .Finalize(g, &loop_len_node_));
}
- *converted = conversion_map_.at(output_desc.full_str);
+
+ Node* ones_shape;
+ TF_RETURN_IF_ERROR(node_builder("Shape")
+ .Input(unstacked->node) // input
+ .Finalize(g, &ones_shape));
+
+ Node* ones;
+ TF_RETURN_IF_ERROR(
+ node_builder("OnesLike").Input(ones_shape).Finalize(g, &ones));
+
+ Node* const_0;
+ TF_RETURN_IF_ERROR(make_const(0, g, &const_0));
+
+ Node* multiples;
+ TF_RETURN_IF_ERROR(node_builder("Concat")
+ .Input(const_0) // concat_dim
+ .Input({{loop_len_node_, 0}, {ones, 0}}) // values
+ .Finalize(g, &multiples));
+
+ Node* expand_dims;
+ TF_RETURN_IF_ERROR(node_builder("ExpandDims")
+ .Input(unstacked->node) // input
+ .Input(const_0) // dim
+ .Finalize(g, &expand_dims));
+
+ TF_RETURN_IF_ERROR(node_builder("Tile")
+ .Input(expand_dims) // input
+ .Input(multiples) // multiples
+ .Finalize(g, &result->first));
+ result->second = 0;
return Status::OK();
}
-Status Vectorization::ConvertOutput(int output_position,
- const FunctionDefTensorDesc& output_desc) {
- string converted_output_name;
- TF_RETURN_IF_ERROR(ConvertOutputHelper(output_desc, &converted_output_name));
+Status Vectorization::AddArgNodeMappings() {
+ for (auto arg_node : map_defun_fn_->arg_nodes) {
+ Node* input_node;
+ TF_RETURN_IF_ERROR(map_defun_node_->input_node(
+ arg_node->attrs().Find("index")->i(), &input_node));
- // Remove the old output and make everything that referenced it point
- // to the new string
- function_utils::ReplaceReferences(
- strings::StrCat(map_defun_node_->name(), ":output:", output_position),
- converted_output_name, outer_scope_);
- RemoveMapDefunOutput(outer_scope_, map_defun_fn_, map_defun_node_,
- output_position);
+ conversion_map_.insert({{arg_node, 0}, {input_node, 0, true}});
+ // Control inputs
+ conversion_map_.insert({{arg_node, Graph::kControlSlot},
+ {input_node, Graph::kControlSlot, true}});
+ }
return Status::OK();
}
-void Vectorization::Vectorize() {
- while (true) {
- FunctionDefTensorDesc desc;
- int output_position =
- FindOutputToConvert(*map_defun_fn_, unconvertible_, &desc);
- if (output_position == -1) break;
+bool Vectorization::AddUnstackedNodeMappingsHelper(TensorDesc&& tensor,
+ Status* status) {
+ if (auto found = gtl::FindOrNull(conversion_map_, tensor)) {
+ return !found->stacked;
+ }
- if (!ConvertOutput(output_position, desc).ok()) {
- unconvertible_.insert(desc.node_name);
+ if (tensor.first->op_def().is_stateful()) {
+ // We don't lift stateful nodes directly out of the MapDefun, since they may
+ // have to be executed N times.
+ return false;
+ }
+
+ bool is_unstacked = true;
+ for (auto edge : tensor.first->in_edges()) {
+ // Ignore Source nodes. Note that these are also ignored in the
+ // GraphToFunctionDef conversion.
+ if (edge->src()->IsSource()) continue;
+
+ // A node is unstacked if all of its inputs are unstacked
+ is_unstacked &= AddUnstackedNodeMappingsHelper(
+ {edge->src(), edge->src_output()}, status);
+ }
+
+ if (!is_unstacked) {
+ return false;
+ }
+
+ // If the node is unstacked, we copy it into outer_scope_ and
+ // add it to the map. Note that we don't clean up the nodes that are copied
+ // in map_defun_fn_, and rely on them being pruned out later.
+ Node* node = outer_scope_->AddNode(tensor.first->def(), status);
+ if (!status->ok()) return true;
+
+ // Add input edges to nodes that should already have been lifted.
+ for (auto edge : tensor.first->in_edges()) {
+ // Ignore Source nodes. Note that these are also ignored in the
+ // GraphToFunctionDef conversion.
+ if (edge->src()->IsSource()) continue;
+
+ if (auto found = gtl::FindOrNull(conversion_map_,
+ {edge->src(), edge->src_output()})) {
+ outer_scope_->AddEdge(found->node, found->output_index, node,
+ edge->dst_input());
+ } else {
+ status->Update(errors::Internal(
+ "Could not find input conversion even though we did depth first "
+ "conversion."));
}
}
- // If we've converted all the outputs of the MapDefun function, we no longer
- // need the MapDefun node and can delete it.
- if (map_defun_fn_->signature().output_arg_size() == 0) {
- outer_scope_->mutable_node_def()->DeleteSubrange(
- function_utils::FindFunctionNodeWithName(map_defun_node_->name(),
- *outer_scope_),
- 1);
+ // Add output mappings
+ for (int i = 0; i < tensor.first->num_outputs(); ++i) {
+ conversion_map_.insert({{tensor.first, i}, WrappedTensor(node, i, false)});
+ }
+ conversion_map_.insert({{tensor.first, Graph::kControlSlot},
+ WrappedTensor(node, Graph::kControlSlot, false)});
+
+ return true;
+}
+
+Status Vectorization::AddUnstackedNodeMappings() {
+ SetVector<Node*> unstacked_nodes;
+ Status s;
+ for (const auto& ret_node : map_defun_fn_->ret_nodes) {
+ const Edge* in_edge = nullptr;
+ TF_RETURN_IF_ERROR(ret_node->input_edge(0, &in_edge));
+ AddUnstackedNodeMappingsHelper({in_edge->src(), in_edge->src_output()}, &s);
+ TF_RETURN_IF_ERROR(s);
}
+ return Status::OK();
+}
- if (!unconvertible_.empty()) {
- VLOG(2) << "The following nodes could not be converted: ["
- << absl::StrJoin(unconvertible_, ", ") << "].";
+Status Vectorization::GetResult(FunctionDef** vectorized_function) {
+ TF_RETURN_IF_ERROR(status_);
+ TF_RETURN_IF_ERROR(graph_utils::EnsureNodeNamesUnique(outer_scope_.get()));
+ TF_RETURN_IF_ERROR(graph_utils::EnsureNodeNamesUnique(map_defun_fn_->graph));
+
+ if (!map_defun_fn_->ret_nodes.empty()) {
+ FunctionDef* map_defun_fn = lib_->add_function();
+ graph_utils::SetUniqueGraphFunctionName("map_defun_fn", lib_, map_defun_fn);
+ TF_RETURN_IF_ERROR(GraphToFunctionDef(
+ *map_defun_fn_->graph, map_defun_fn->signature().name(), map_defun_fn));
+
+ AttrValue func_attr;
+ func_attr.mutable_func()->set_name(map_defun_fn->signature().name());
+ map_defun_node_->AddAttr("f", func_attr);
}
+
+ *vectorized_function = lib_->add_function();
+ graph_utils::SetUniqueGraphFunctionName("vectorized_fn", lib_,
+ *vectorized_function);
+ TF_RETURN_IF_ERROR(GraphToFunctionDef(
+ *outer_scope_, (*vectorized_function)->signature().name(),
+ *vectorized_function));
+ return Status::OK();
}
+
} // namespace
-void VectorizeMapDefun(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
- NodeDef* map_defun_node) {
- Vectorization(outer_scope, map_defun_fn, map_defun_node).Vectorize();
+Status VectorizeMapDefun(const FunctionDef& outer_scope,
+ const NodeDef& map_defun_node, FunctionDefLibrary* lib,
+ FunctionDef** result) {
+ *result = nullptr;
+ return Vectorization(lib).Vectorize(outer_scope, map_defun_node, result);
}
} // end namespace vectorization_utils
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.h b/tensorflow/core/grappler/optimizers/data/vectorization_utils.h
index bb405faa77..bd7d390900 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.h
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.h
@@ -24,22 +24,28 @@ namespace tensorflow {
namespace grappler {
namespace vectorization_utils {
-// Given a function, `map_defun_fn`, that is mapped across some input vector
-// elements via a MapDefun operation, `VectorizeMapDefun` attempts to
-// vectorize the MapDefun by "lifting" operations from the `map_defun_fn` to the
-// `outer_scope`; that is, replacing `map_defun_fn` operations with new
-// `outer_scope` operations that produce the same vector output(s) as executing
-// the `map_defun_fn` operations on elements of vector input(s) would. If all
-// `map_defun_fn` operations are successfully lifted, `map_defun_node` is
-// eliminated from `outer_scope` altogether. However, if some operations cannot
-// be lifted, and this vectorization only succeeds partially, `map_defun_node`
-// remains to be used for operations that were not lifted.
+// Given a MapDefun node (`map_defun_node`) in a FunctionDef (`outer_scope`)
+// that maps a function in lib across some input vector elements,
+// `VectorizeMapDefun` attempts to create a vectorized version of `outer_scope`
+// by "lifting" operations from the MapDefun function to the new function
+// (`result`); that is, replacing operations in the MapDefun function with
+// operations that produce the same vector output(s) as executing the original
+// operations on elements of vector input(s) would. If all operations in the
+// MapDefun function are successfully lifted, `result` has no MapDefun node
+// altogether. However, if some operations cannot be lifted, and this
+// vectorization only succeeds partially, a MapDefun node remains in `result` to
+// be used for operations that were not lifted, and the modified MapDefun
+// function is added to `lib`. The newly vectorized function `result` is also
+// added to `lib`.
+//
+// Returns Status::OK() if the vectorization is completely or partially
+// successful. Otherwise, returns an error, and sets `result` to nullptr.
//
// Example:
// If the input to the `VectorizeMapDefun` function is a MapDefun
// whose `map_defun_fn` performs the Cast operation, the vectorization will
// eliminate the MapDefun. This is because the Cast operation supports
-// any tensor shape and can thus be lifted to the `outer_scope`.
+// any tensor shape and can thus be lifted to `result`.
//
// Before:
//
@@ -68,7 +74,7 @@ namespace vectorization_utils {
//
// After:
//
-// outer_scope +------+
+// result +------+
// +---------------+ Arg0 +---------+
// | +---+--+ |
// | | |
@@ -80,8 +86,9 @@ namespace vectorization_utils {
// +---------------+ Ret0 +---------+
// +------+
//
-void VectorizeMapDefun(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
- NodeDef* map_defun_node);
+Status VectorizeMapDefun(const FunctionDef& outer_scope,
+ const NodeDef& map_defun_node, FunctionDefLibrary* lib,
+ FunctionDef** result);
} // end namespace vectorization_utils
} // end namespace grappler
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
index e129fa9237..a6020e36bb 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
@@ -54,12 +55,18 @@ NodeDef* AddMapDefunNode(const string& name, const std::vector<string>& inputs,
func.set_name(function_name);
NodeDef* node = function_utils::AddNode(name, "MapDefun", inputs, {}, fn);
graph_transforms::SetNodeAttr("Targuments", t_arguments, node);
+ graph_transforms::SetNodeAttr("Tcaptured", DataTypeVector(), node);
graph_transforms::SetNodeAttr("output_types", output_types, node);
graph_transforms::SetNodeAttr("output_shapes", output_shapes, node);
graph_transforms::SetNodeAttr("f", func, node);
return node;
}
+string GetRetval(const FunctionDef& function_def, int index) {
+ return function_def.ret().at(
+ function_def.signature().output_arg(index).name());
+}
+
// TODO(rachelim): Use FunctionDefHelper::Create instead
FunctionDef CreateFunction(
StringPiece name, const std::vector<std::pair<string, DataType>>& inputs,
@@ -85,7 +92,6 @@ FunctionDef CreateFunction(
return func;
}
-TEST(FunctionDefInputDescTest, ConstructedCorrectly) {}
// Before:
//
@@ -133,10 +139,17 @@ TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) {
{{}, {}}, inner.signature().name(), &outer);
CHECK_NOTNULL(map_defun);
- VectorizeMapDefun(&outer, &inner, map_defun);
- EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
- EXPECT_EQ(outer.ret().at("mapdefun"), "ret0");
- EXPECT_EQ(outer.ret().at("mapdefun_0"), "ret1");
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ Status s = VectorizeMapDefun(outer, *map_defun, &lib, &vectorized);
+ LOG(ERROR) << s;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ EXPECT_EQ(GetRetval(*vectorized, 0), "ret0");
+ EXPECT_EQ(GetRetval(*vectorized, 1), "ret1");
}
// Before:
@@ -149,12 +162,12 @@ TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) {
// | +-----------+ Arg0 +---+ Arg1 +----+ |
// | | +---+--+ +---+--+ | |
// | | | | | |
-// | | +------+ | +---v--+ | |
-// | | |Const | | | Op0 | | |
-// | | +---v--+ | +---+--+ | |
+// | | +------+ | | | |
+// | | |Const | | | | |
+// | | +---v--+ | | | |
// | | | | | | |
// | | | +---v--+ +---v--+ | |
-// | | +---| XOp1 | | XOp2 | | |
+// | | +---| XOp1 | | Cast | | |
// | | +---+--+ +---+--+ | |
// | | | | | |
// | | MapDefun +---v--+ +---v--+ | |
@@ -165,23 +178,50 @@ TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) {
// +---------------+ Ret0 +---+ Ret1 +--------+
// +------+ +------+
//
-// where XOp1 and XOp2 are not convertible.
+// where XOp1 is not convertible.
//
// After:
//
-// No change because the ops are not convertible.
+//
+// +------+ +------+
+// +---------------+ Arg0 +---+ Arg1 +--------+
+// | +---+--+ +---+--+ |
+// | | | |
+// | +---v--+ | |
+// | +-----------+ Arg0 +-+ | |
+// | | +---+--+ | | |
+// | | | | | |
+// | | +------+ | | | |
+// | | |Const | | | | |
+// | | +---v--+ | | | |
+// | | | | | | |
+// | | | +---v--+ | +---v--+ |
+// | | +---| XOp1 | | | Cast | |
+// | | +---+--+ | +---+--+ |
+// | | | | | |
+// | | MapDefun +---v--+ | | |
+// | +-----------+ Ret0 +-+ | |
+// | +---+--+ | |
+// | | | |
+// | +---v--+ +---v--+ |
+// +---------------+ Ret0 +---+ Ret1 +--------+
+// +------+ +------+
//
TEST(VectorizeMapDefunTest, VectorizeDefunUnconvertible) {
FunctionDef inner =
CreateFunction("inner_function", {{"arg0", DT_INT32}, {"arg1", DT_INT32}},
{{"ret0", DT_INT32}, {"ret1", DT_INT32}},
- {{"ret0", "XOp1:output:0"}, {"ret1", "XOp2:output:0"}});
+ {{"ret0", "MatMul:product:0"}, {"ret1", "Cast:y:0"}});
+ // TODO(rachelim): If we ever write a converter for MatMul, we have to
+ // change this test.
NodeDef* x_op1 =
- function_utils::AddNode("XOp1", "XOp1", {"const", "arg0"}, {}, &inner);
+ function_utils::AddNode("MatMul", "MatMul", {"arg0", "arg0"}, {}, &inner);
CHECK_NOTNULL(x_op1);
+ graph_transforms::SetNodeAttr("T", DT_INT32, x_op1);
- NodeDef* x_op2 = function_utils::AddNode("XOp2", "XOp2", {"op1"}, {}, &inner);
- CHECK_NOTNULL(x_op2);
+ NodeDef* cast_node =
+ AddCastNode("Cast", {"arg1"}, DT_INT32, DT_INT32, false, &inner);
+ CHECK_NOTNULL(cast_node);
FunctionDef outer = CreateFunction(
"outer_function", {{"x", DT_INT32}, {"y", DT_INT32}},
@@ -193,12 +233,22 @@ TEST(VectorizeMapDefunTest, VectorizeDefunUnconvertible) {
{{}, {}}, inner.signature().name(), &outer);
CHECK_NOTNULL(map_defun);
- FunctionDef outer_copy(outer);
- FunctionDef inner_copy(inner);
- VectorizeMapDefun(&outer, &inner, map_defun);
- // They should be unchanged
- EXPECT_TRUE(FunctionDefsEqual(outer_copy, outer));
- EXPECT_TRUE(FunctionDefsEqual(inner_copy, inner));
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+
+ auto map_defun_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("MapDefun", *vectorized));
+ // The Cast node should be converted just fine.
+ EXPECT_EQ(GetRetval(*vectorized, 1), "Cast:y:0");
+
+ // The inner function should only have one retval.
+ FunctionLibraryDefinition lib_def(OpRegistry::Global(), lib);
+ const FunctionDef* map_defun_fn =
+ lib_def.Find(map_defun_node.attr().at("f").func().name());
+ EXPECT_EQ(map_defun_fn->signature().output_arg_size(), 1);
}
// Before:
@@ -257,14 +307,19 @@ TEST(VectorizeMapDefunTest, VectorizeDefunSimpleCast) {
inner.signature().name(), &outer);
CHECK_NOTNULL(map_defun);
- VectorizeMapDefun(&outer, &inner, map_defun);
- EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
- const NodeDef& cast_node =
- outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer));
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ const NodeDef& cast_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Cast", *vectorized));
EXPECT_EQ(cast_node.input(0), "x");
- EXPECT_EQ(outer.ret().at("mapdefun"),
+ EXPECT_EQ(GetRetval(*vectorized, 0),
strings::StrCat(cast_node.name(), ":y:0"));
- EXPECT_EQ(outer.node_def_size(), 1);
+ EXPECT_EQ(vectorized->node_def_size(), 1);
}
// Before:
@@ -330,16 +385,21 @@ TEST(VectorizeMapDefunTest, VectorizeDefunCastUsedTwice) {
{{}, {}}, inner.signature().name(), &outer);
CHECK_NOTNULL(map_defun);
- VectorizeMapDefun(&outer, &inner, map_defun);
- EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
- const NodeDef& cast_node =
- outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer));
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ const NodeDef& cast_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Cast", *vectorized));
EXPECT_EQ(cast_node.input(0), "x");
- EXPECT_EQ(outer.ret().at("mapdefun"),
+ EXPECT_EQ(GetRetval(*vectorized, 0),
strings::StrCat(cast_node.name(), ":y:0"));
- EXPECT_EQ(outer.ret().at("mapdefun_0"),
+ EXPECT_EQ(GetRetval(*vectorized, 1),
strings::StrCat(cast_node.name(), ":y:0"));
- EXPECT_EQ(outer.node_def_size(), 1);
+ EXPECT_EQ(vectorized->node_def_size(), 1);
}
// Before:
@@ -411,21 +471,26 @@ TEST(VectorizeMapDefunTest, VectorizeDefunOpWithMultipleOutputs) {
{{1}, {1}, {1}}, inner.signature().name(), &outer);
CHECK_NOTNULL(map_defun);
- VectorizeMapDefun(&outer, &inner, map_defun);
- EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
- const NodeDef& unpack_node =
- outer.node_def(function_utils::FindFunctionNodeWithOp("Unpack", outer));
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ const NodeDef& unpack_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Unpack", *vectorized));
EXPECT_EQ(unpack_node.input(0), "x");
EXPECT_EQ(unpack_node.attr().at("axis").i(), 1);
EXPECT_EQ(unpack_node.attr().at("T").type(), DT_INT32);
EXPECT_EQ(unpack_node.attr().at("num").i(), 3);
- EXPECT_EQ(outer.ret().at("mapdefun"),
+ EXPECT_EQ(GetRetval(*vectorized, 0),
strings::StrCat(unpack_node.name(), ":output:0"));
- EXPECT_EQ(outer.ret().at("mapdefun_0"),
+ EXPECT_EQ(GetRetval(*vectorized, 1),
strings::StrCat(unpack_node.name(), ":output:1"));
- EXPECT_EQ(outer.ret().at("mapdefun_1"),
+ EXPECT_EQ(GetRetval(*vectorized, 2),
strings::StrCat(unpack_node.name(), ":output:2"));
- EXPECT_EQ(outer.node_def_size(), 1);
+ EXPECT_EQ(vectorized->node_def_size(), 1);
}
// Before:
@@ -486,7 +551,7 @@ TEST(VectorizeMapDefunTest, VectorizeDefunChainedConvertibleOps) {
{"ret1", "MyUnstack:output:1"},
{"ret2", "MyUnstack:output:2"}});
NodeDef* cast_op =
- AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT64, false, &inner);
+ AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT32, false, &inner);
CHECK_NOTNULL(cast_op);
NodeDef* unstack_op =
AddUnstackNode("MyUnstack", {"Cast:y:0"}, DT_INT32, 0, 3, &inner);
@@ -505,25 +570,30 @@ TEST(VectorizeMapDefunTest, VectorizeDefunChainedConvertibleOps) {
{{1}, {1}, {1}}, inner.signature().name(), &outer);
CHECK_NOTNULL(map_defun);
- VectorizeMapDefun(&outer, &inner, map_defun);
- EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
- const NodeDef& cast_node =
- outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer));
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ const NodeDef& cast_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Cast", *vectorized));
EXPECT_EQ(cast_node.input(0), "x");
- const NodeDef& unpack_node =
- outer.node_def(function_utils::FindFunctionNodeWithOp("Unpack", outer));
+ const NodeDef& unpack_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Unpack", *vectorized));
EXPECT_EQ(unpack_node.input(0), strings::StrCat(cast_node.name(), ":y:0"));
EXPECT_EQ(unpack_node.attr().at("axis").i(), 1);
EXPECT_EQ(unpack_node.attr().at("T").type(), DT_INT32);
EXPECT_EQ(unpack_node.attr().at("num").i(), 3);
- EXPECT_EQ(outer.ret().at("mapdefun"),
+ EXPECT_EQ(GetRetval(*vectorized, 0),
strings::StrCat(unpack_node.name(), ":output:0"));
- EXPECT_EQ(outer.ret().at("mapdefun_0"),
+ EXPECT_EQ(GetRetval(*vectorized, 1),
strings::StrCat(unpack_node.name(), ":output:1"));
- EXPECT_EQ(outer.ret().at("mapdefun_1"),
+ EXPECT_EQ(GetRetval(*vectorized, 2),
strings::StrCat(unpack_node.name(), ":output:2"));
- EXPECT_EQ(outer.node_def_size(), 2);
+ EXPECT_EQ(vectorized->node_def_size(), 2);
}
// Before:
@@ -561,9 +631,11 @@ TEST(VectorizeMapDefunTest, VectorizeDefunWithControlInputs) {
FunctionDef inner =
CreateFunction("inner_function", {{"arg0", DT_INT32}},
{{"ret0", DT_INT64}}, {{"ret0", "Cast:y:0"}});
- // The attrs aren't relevant
- NodeDef* print_op =
- function_utils::AddNode("Print", "Print", {"arg0", "arg0"}, {}, &inner);
+ NodeDef* print_op = function_utils::AddNode(
+ "Print", "Print", {"arg0", "arg0"}, {/*attrs*/}, &inner);
+ graph_transforms::SetNodeAttr("T", DT_INT32, print_op);
+ graph_transforms::SetNodeAttr("U", gtl::ArraySlice<DataType>({DT_INT32}),
+ print_op);
CHECK_NOTNULL(print_op);
NodeDef* cast_op = AddCastNode("Cast", {"arg0", "^Print"}, DT_INT32, DT_INT64,
false, &inner);
@@ -578,11 +650,278 @@ TEST(VectorizeMapDefunTest, VectorizeDefunWithControlInputs) {
inner.signature().name(), &outer);
CHECK_NOTNULL(map_defun);
- FunctionDef outer_copy(outer);
- FunctionDef inner_copy(inner);
- VectorizeMapDefun(&outer, &inner, map_defun);
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
// They should be unchanged
- EXPECT_TRUE(FunctionDefsEqual(outer_copy, outer));
+ // We check this somewhat manually as the names of nodes may have changed
+ EXPECT_EQ(vectorized->node_def_size(), 1);
+ const NodeDef& map_defun_node = vectorized->node_def(0);
+ EXPECT_EQ(map_defun_node.op(), "MapDefun");
+ FunctionLibraryDefinition lib_def(OpRegistry::Global(), lib);
+ const FunctionDef* map_defun_fn =
+ lib_def.Find(map_defun_node.attr().at("f").func().name());
+
+ const NodeDef& print_node = map_defun_fn->node_def(
+ function_utils::FindFunctionNodeWithOp("Print", *map_defun_fn));
+ const NodeDef& cast_node = map_defun_fn->node_def(
+ function_utils::FindFunctionNodeWithOp("Cast", *map_defun_fn));
+ string control_input = strings::StrCat("^", print_node.name());
+ EXPECT_TRUE(cast_node.input(0) == control_input ||
+ cast_node.input(1) == control_input);
+}
+
+// Before:
+//
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +-----------+ Arg0 +-----+ |
+// | | +------+ | |
+// | | | |
+// | | | |
+// | | +------+ | |
+// | | |Const | | |
+// | | +---+--+ | |
+// | | | | |
+// | | MapDefun +---v--+ | |
+// | +-----------+ Ret0 +-----+ |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+//
+//
+// After:
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +------+ |
+// | |
+// | +------+ |
+// | |Const | |
+// | +---+--+ |
+// | | |
+// | | |
+// | | |
+// | +---v--+ |
+// | |Stack*| |
+// | +---+--+ |
+// | | |
+// | | |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+// *Not actually a Stack node, but does the equivalent.
+//
+TEST(VectorizeMapDefunTest, VectorizeConst) {
+ FunctionDef inner = FunctionDefHelper::Create(
+ "inner_function", {"arg0: int32"}, {"ret0: int32"}, {/* attrs */},
+ {/* nodes */ FunctionDefHelper::Const("Const", 2)},
+ {{"ret0", "Const:output:0"}});
+ FunctionDef outer = FunctionDefHelper::Create(
+ "outer_function", {"outer_arg0: int32"}, {"mapdefun: int32"},
+ {/* attrs */}, {/* nodes */}, {{"mapdefun", "MapDefun:output:0"}});
+
+ NodeDef* map_defun =
+ AddMapDefunNode("MapDefun", {"outer_arg0"}, {DT_INT32}, {DT_INT32}, {{}},
+ inner.signature().name(), &outer);
+
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ EXPECT_TRUE(function_utils::ContainsFunctionNodeWithOp("Const", *vectorized));
+}
+
+// Before:
+//
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +-----------+ Arg0 +-----+ |
+// | | +------+ | |
+// | | | |
+// | | | |
+// | | +------+ | |
+// | | |Const | | |
+// | | +---+--+ | |
+// | | | | |
+// | | +---v--+ | |
+// | | | Cast | | |
+// | | +---+--+ | |
+// | | | | |
+// | | MapDefun +---v--+ | |
+// | +-----------+ Ret0 +-----+ |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+//
+//
+// After:
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +------+ |
+// | |
+// | +------+ |
+// | |Const | |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | | Cast | |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | |Stack*| |
+// | +---+--+ |
+// | | |
+// | | |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+// *Not actually a Stack node, but does the equivalent.
+//
+TEST(VectorizeMapDefunTest, VectorizeUnstackedOutput) {
+ FunctionDef inner = FunctionDefHelper::Create(
+ "inner_function", {"arg0: int32"}, {"ret0: int64"}, {/* attrs */},
+ {/* nodes */ FunctionDefHelper::Const("Const", 2)},
+ {{"ret0", "Cast:y:0"}});
+ AddCastNode("Cast", {"Const:output:0"}, DT_INT32, DT_INT64, false, &inner);
+
+ FunctionDef outer = FunctionDefHelper::Create(
+ "outer_function", {"outer_arg0: int32"}, {"mapdefun: int64"},
+ {/* attrs */}, {/* nodes */}, {{"mapdefun", "MapDefun:output:0"}});
+
+ NodeDef* map_defun =
+ AddMapDefunNode("MapDefun", {"outer_arg0"}, {DT_INT32}, {DT_INT64}, {{}},
+ inner.signature().name(), &outer);
+
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ auto const_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Const", *vectorized));
+ auto cast_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Cast", *vectorized));
+ EXPECT_EQ(cast_node.input(0).substr(0, cast_node.input(0).find(':')),
+ const_node.name());
+}
+
+// Before:
+//
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +-----------+ Arg0 +-----+ |
+// | | +------+ | |
+// | | | |
+// | | +------+ +------+ | |
+// | | |Const | |Const | | |
+// | | +---+--+ +---+--+ | |
+// | | : +---v--+ | |
+// | | ::::::> Cast | | |
+// | | +---+--+ | |
+// | | | | |
+// | | MapDefun +---v--+ | |
+// | +-----------+ Ret0 +-----+ |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+//
+//
+// After:
+//
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +------+ |
+// | |
+// | |
+// | +------+ |
+// | +------+ |Const | |
+// | |Const | +---+--+ |
+// | +---+--+ | |
+// | : +---v--+ |
+// | ::::::> Cast | |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +Stack*+ |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+// *Not actually a Stack node, but does the equivalent.
+//
+TEST(VectorizeMapDefunTest, VectorizeUnstackedControl) {
+ FunctionDef inner = FunctionDefHelper::Create(
+ "inner_function", {"arg0: int32"}, {"ret0: int64"}, {/* attrs */},
+ {/* nodes */ FunctionDefHelper::Const("Const", 2),
+ FunctionDefHelper::Const("ConstDep", 3)},
+ {{"ret0", "Cast:y:0"}});
+ AddCastNode("Cast", {"Const:output:0", "^ConstDep"}, DT_INT32, DT_INT64,
+ false, &inner);
+
+ FunctionDef outer = FunctionDefHelper::Create(
+ "outer_function", {"outer_arg0: int32"}, {"mapdefun: int64"},
+ {/* attrs */}, {/* nodes */}, {{"mapdefun", "MapDefun:output:0"}});
+
+ NodeDef* map_defun =
+ AddMapDefunNode("MapDefun", {"outer_arg0"}, {DT_INT32}, {DT_INT64}, {{}},
+ inner.signature().name(), &outer);
+
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+
+ auto find_const = [vectorized](int val) -> const NodeDef* {
+ for (const auto& n : vectorized->node_def()) {
+ if (n.attr().at("value").tensor().int_val(0) == val) {
+ return &n;
+ }
+ }
+ return nullptr;
+ };
+
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ auto const_node = find_const(2);
+ auto const_dep_node = find_const(3);
+ auto cast_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Cast", *vectorized));
+ EXPECT_EQ(cast_node.input(0).substr(0, cast_node.input(0).find(':')),
+ const_node->name());
+ EXPECT_EQ(cast_node.input(1), strings::StrCat("^", const_dep_node->name()));
}
// TODO(rachelim): More test cases when we get around to implementing them:
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index c59645e5f2..3f33b16ba8 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -37,6 +37,7 @@ limitations under the License.
#include "tensorflow/core/grappler/utils/functions.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
@@ -106,7 +107,8 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::MakeNewOptimizer(
MK_OPT("scoped_allocator",
new ScopedAllocatorOptimizer(cfg_.scoped_allocator_optimization(),
cfg_.scoped_allocator_opts()));
- MK_OPT("small_op", new PinToHostOptimizer(cfg_.pin_to_host_optimization()));
+ MK_OPT("pin_to_host",
+ new PinToHostOptimizer(cfg_.pin_to_host_optimization()));
return std::unique_ptr<GraphOptimizer>();
}
@@ -115,6 +117,9 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::MakeNewOptimizer(
Status MetaOptimizer::InitializeOptimizers(
std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
+ if (cfg_.disable_meta_optimizer()) {
+ return Status::OK();
+ }
if (!cfg_.disable_model_pruning()) {
optimizers->push_back(MakeUnique<ModelPruner>());
}
@@ -172,11 +177,12 @@ Status MetaOptimizer::InitializeOptimizers(
optimizers->push_back(MakeUnique<ScopedAllocatorOptimizer>(
cfg_.scoped_allocator_optimization(), cfg_.scoped_allocator_opts()));
}
- return InitializeCustomGraphOptimizers(optimizers);
+ return InitializeCustomGraphOptimizers(std::set<string>(), optimizers);
}
Status MetaOptimizer::InitializeOptimizersByName(
std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
+ std::set<string> initialized_custom_optimizers;
for (const string& optimizer_name : cfg_.optimizers()) {
auto optimizer = MakeNewOptimizer(optimizer_name);
if (optimizer) {
@@ -190,18 +196,26 @@ Status MetaOptimizer::InitializeOptimizersByName(
if (custom_optimizer) {
VLOG(2) << "Registered custom graph optimizer: " << optimizer_name;
- TF_RETURN_IF_ERROR(custom_optimizer->Init());
+ TF_RETURN_IF_ERROR(custom_optimizer->Init(
+ GetCustomGraphOptimizerConfig(optimizer_name)));
optimizers->push_back(std::move(custom_optimizer));
+ initialized_custom_optimizers.insert(optimizer_name);
} else {
VLOG(2) << "Can't register an optimizer by name: " << optimizer_name;
}
}
- return InitializeCustomGraphOptimizers(optimizers);
+ return InitializeCustomGraphOptimizers(initialized_custom_optimizers,
+ optimizers);
}
Status MetaOptimizer::InitializeCustomGraphOptimizers(
+ const std::set<string>& pre_initialized_optimizers,
std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
for (const auto& optimizer_config : cfg_.custom_optimizers()) {
+ if (pre_initialized_optimizers.find(optimizer_config.name()) !=
+ pre_initialized_optimizers.end()) {
+ continue;
+ }
// Initialize the ExperimentalImplementationSelector here instead of
// CustomizeOptimizer registry, due the static link issue in TensorRT for
// double registry.
@@ -237,6 +251,16 @@ Status MetaOptimizer::InitializeCustomGraphOptimizers(
return Status::OK();
}
+const RewriterConfig::CustomGraphOptimizer*
+MetaOptimizer::GetCustomGraphOptimizerConfig(const string& name) const {
+ for (const auto& config : cfg_.custom_optimizers()) {
+ if (config.name() == name) {
+ return &config;
+ }
+ }
+ return nullptr;
+}
+
Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph) {
int min_graph_nodes = cfg_.min_graph_nodes() == 0 ? kDefaultMinGraphNodes
@@ -391,6 +415,15 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
FunctionLibraryDefinition flib(OpRegistry::Global(),
optimized_graph->library());
+ // Find functions for which we might need to compute a gradient at runtime.
+ gtl::FlatSet<string> differentiable_functions;
+ for (const NodeDef& node : optimized_graph->node()) {
+ if (IsSymbolicGradient(node)) {
+ const auto* f_attr = gtl::FindOrNull(node.attr(), "f");
+ if (f_attr) differentiable_functions.insert(f_attr->func().name());
+ }
+ }
+
// Optimize each function only once.
std::unordered_set<string> optimized_funcs;
bool optimize_function_library = true;
@@ -406,6 +439,8 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
// Skip parametrized functions (function type or body is defined only at
// function call time by caller node attributes).
+ // They should be specialized to their instantiation type parameters by
+ // the function optimizer, before we can optimize function body.
if (IsParametrized(func)) continue;
VLOG(3) << "Optimize function: function=" << func_name;
@@ -420,6 +455,13 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(
func, flib, item.graph.versions().producer(), &func_item));
+ // If we need to compute the gradient of optimized function at runtime, we
+ // can't perform non-differentiable rewrites.
+ if (differentiable_functions.find(func_name) !=
+ differentiable_functions.end()) {
+ func_item.allowed_optimizations.non_differentiable_rewrites = false;
+ }
+
// Optimize function body graph.
GraphDef optimized_func_graph;
TF_RETURN_IF_ERROR(
@@ -470,6 +512,9 @@ void MetaOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item,
}
bool MetaOptimizerEnabled(const RewriterConfig& cfg) {
+ if (cfg.disable_meta_optimizer()) {
+ return false;
+ }
return !cfg.disable_model_pruning() ||
cfg.layout_optimizer() != RewriterConfig::OFF ||
cfg.function_optimization() != RewriterConfig::OFF ||
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.h b/tensorflow/core/grappler/optimizers/meta_optimizer.h
index 831c5e37c0..99a0a33ffa 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.h
@@ -54,7 +54,11 @@ class MetaOptimizer : public GraphOptimizer {
std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const;
// Initialize active optimizers from RewriterConfig.custom_optimizers.
Status InitializeCustomGraphOptimizers(
+ const std::set<string>& pre_initialized_optimizers,
std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const;
+ // Returns the config for a custom graph optimizer. Null if none was found.
+ const RewriterConfig::CustomGraphOptimizer* GetCustomGraphOptimizerConfig(
+ const string& name) const;
// Run optimization pass over a single GrapplerItem. Meta optimizer might run
// multiple such passes: 1) for the main graph 2) for the function library
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
index e74e0f7501..3f3f43382f 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@@ -71,6 +72,59 @@ class TestGraphOptimizer : public TestOptimizer {
REGISTER_GRAPH_OPTIMIZER(TestGraphOptimizer);
+class TestOptimizerWithParams : public TestOptimizer {
+ public:
+ Status Init(
+ const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
+ CHECK(config != nullptr);
+ return Status::OK();
+ }
+};
+
+REGISTER_GRAPH_OPTIMIZER(TestOptimizerWithParams);
+
+// Record various properties of the GrapplerItems passed for optimization.
+class GrapplerItemPropertiesAccumulator : public CustomGraphOptimizer {
+ public:
+ static void SetAllowedOptimizations(
+ gtl::FlatMap<string, GrapplerItem::AllowedOptimizations>*
+ allowed_optimizations) {
+ allowed_optimizations_ = allowed_optimizations;
+ }
+ static void ResetAllowedOptimizations() { allowed_optimizations_ = nullptr; }
+
+ GrapplerItemPropertiesAccumulator() {}
+ string name() const override {
+ return "grappler_item_properties_accumulator";
+ }
+
+ Status Init(
+ const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
+ return Status::OK();
+ }
+
+ Status Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* optimized_graph) override {
+ *optimized_graph = item.graph;
+ if (allowed_optimizations_) {
+ allowed_optimizations_->insert({item.id, item.allowed_optimizations});
+ }
+ return Status::OK();
+ }
+
+ void Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimized_graph, double result) override {}
+
+ private:
+ static gtl::FlatMap<string, GrapplerItem::AllowedOptimizations>*
+ allowed_optimizations_;
+};
+
+gtl::FlatMap<string, GrapplerItem::AllowedOptimizations>*
+ GrapplerItemPropertiesAccumulator::allowed_optimizations_;
+
+REGISTER_GRAPH_OPTIMIZER(GrapplerItemPropertiesAccumulator);
+
class MetaOptimizerTest : public GrapplerTest {};
TEST_F(MetaOptimizerTest, RunsCustomOptimizer) {
@@ -90,6 +144,25 @@ TEST_F(MetaOptimizerTest, RunsCustomOptimizer) {
EXPECT_TRUE(TestOptimizer::IsOptimized());
}
+TEST_F(MetaOptimizerTest, RunsCustomOptimizerWithParams) {
+ TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
+ GrapplerItem item;
+ CHECK(fake_input.NextItem(&item));
+
+ TestOptimizer::SetOptimized(false);
+ RewriterConfig rewriter_config;
+ rewriter_config.add_optimizers("TestOptimizerWithParams");
+ auto* custom_config = rewriter_config.add_custom_optimizers();
+ custom_config->set_name("TestOptimizerWithParams");
+ (*custom_config->mutable_parameter_map())["foo"] = AttrValue();
+
+ MetaOptimizer optimizer(nullptr, rewriter_config);
+ GraphDef output;
+ const Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+ EXPECT_TRUE(TestOptimizer::IsOptimized());
+}
+
TEST_F(MetaOptimizerTest, RunsCustomOptimizerAndCustomGraphOptimizer) {
TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
GrapplerItem item;
@@ -305,6 +378,89 @@ TEST_F(MetaOptimizerTest, OptimizeFunctionLibrary) {
test::ExpectTensorEqual<int>(tensors_expected[1], tensors[1]);
}
+TEST_F(MetaOptimizerTest, OptimizeFunctionLibraryWithRestrictions) {
+ using test::function::NDef;
+ using FDH = FunctionDefHelper;
+
+ // We will record what type of optimizations meta optimizer allows for each
+ // GrapplerItem (main graph and graphs for each function).
+ gtl::FlatMap<string, GrapplerItem::AllowedOptimizations>
+ allowed_optimizations;
+ GrapplerItemPropertiesAccumulator::SetAllowedOptimizations(
+ &allowed_optimizations);
+
+ // Just record properties of optimized Grappler items.
+ RewriterConfig rewriter_config;
+ rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO);
+ rewriter_config.add_optimizers("GrapplerItemPropertiesAccumulator");
+ rewriter_config.set_min_graph_nodes(-1);
+
+ MetaOptimizer optimizer(nullptr, rewriter_config);
+
+ // Define simple function library with two identical mul functions.
+ FunctionDef mul_func_1 = FunctionDefHelper::Create(
+ "MyMul1", {"x:float", "y:float"}, {"z:float"}, {},
+ {{{"mul"}, "Mul", {"x", "y"}, {}}},
+ /* Mapping between function returns and function node outputs. */
+ {{"z", "mul:z:0"}});
+
+ FunctionDef mul_func_2 = FunctionDefHelper::Create(
+ "MyMul2", {"x:float", "y:float"}, {"z:float"}, {},
+ {{{"mul"}, "Mul", {"x", "y"}, {}}},
+ /* Mapping between function returns and function node outputs. */
+ {{"z", "mul:z:0"}});
+
+ // Tensorflow graph:
+ //
+ // x0 = tf.Placeholder(tf.float);
+ // x1 = tf.Placeholder(tf.float);
+ // dy = tf.Placeholder(tf.float);
+ //
+ // mul_1 = MyMul1(x0, x1);
+ // mul_2 = MyMul2(x0, x1);
+ // dx = SymbolicGradient({x0, x1, dy}, f=MyMul2)
+ GrapplerItem item;
+ item.id = "main";
+ item.graph = test::function::GDef(
+ {NDef("x0", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
+ NDef("x1", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
+ NDef("dy", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
+ // Calls into function library
+ NDef("mul_1", "MyMul1", {"x0", "x1"}, {}, kDevice),
+ NDef("mul_2", "MyMul2", {"x0", "x1"}, {}, kDevice),
+ // Symbolic gradient of a MyMul2
+ NDef("dx", "SymbolicGradient", {"x0", "x1", "dy"},
+ {{"f", FDH::FunctionRef("MyMul2", {})},
+ {"Tin", DataTypeSlice{DT_FLOAT}},
+ {"Tout", DataTypeSlice{DT_FLOAT, DT_FLOAT}}},
+ kDevice)},
+ // FunctionLib
+ {mul_func_1, mul_func_2});
+ item.fetch = {"mul_1", "mul_2", "dx"};
+
+ GraphDef output;
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ // Our custom optimizer must be called for the main graph and for the two
+ // functions.
+ ASSERT_EQ(allowed_optimizations.size(), 3);
+
+ auto allowed_optimizations_main =
+ gtl::FindOrNull(allowed_optimizations, "main");
+ ASSERT_NE(allowed_optimizations_main, nullptr);
+ EXPECT_TRUE(allowed_optimizations_main->non_differentiable_rewrites);
+
+ auto allowed_optimizations_my_mul_1 =
+ gtl::FindOrNull(allowed_optimizations, "MyMul1");
+ ASSERT_NE(allowed_optimizations_my_mul_1, nullptr);
+ EXPECT_TRUE(allowed_optimizations_my_mul_1->non_differentiable_rewrites);
+
+ auto allowed_optimizations_my_mul_2 =
+ gtl::FindOrNull(allowed_optimizations, "MyMul2");
+ ASSERT_NE(allowed_optimizations_my_mul_2, nullptr);
+ EXPECT_FALSE(allowed_optimizations_my_mul_2->non_differentiable_rewrites);
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
index 2190d38937..29a3b2b74c 100644
--- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
@@ -25,23 +25,67 @@ limitations under the License.
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/str_util.h"
namespace tensorflow {
namespace grappler {
+
namespace internal {
+namespace {
// TODO(williamchan): Change this constant to be something smarter, maybe
// dynamically determined.
constexpr int64 kTensorMaxSize = 64;
-// Find KernelDef for `node`.
-Status TryFindKernelDef(const NodeDef& node, const KernelDef** kdef) {
- // Try find KernelDef for node.device, else GPU or CPU.
- for (const DeviceType& device :
- {node.device().c_str(), DEVICE_GPU, DEVICE_CPU}) {
- Status s = FindKernelDef(device, node, kdef, nullptr);
+struct OpDevicePortHasher {
+ std::size_t operator()(const std::tuple<string, string, int>& x) const {
+ uint64 code = Hash64Combine(Hash64(std::get<0>(x)), Hash64(std::get<1>(x)));
+
+ return Hash64Combine(code, hash<int>()(std::get<2>(x)));
+ }
+};
+using OpDevicePortOnHostMap =
+ gtl::FlatMap<std::tuple<string, string, int>, bool, OpDevicePortHasher>;
+
+// All the nodes that should be blacklisted and not swapped.
+bool IsBlacklisted(const NodeDef& node) {
+ return
+ // Collective ops should not be swapped.
+ IsCollective(node) ||
+ // ControlFlow ops should not be swapped.
+ IsControlFlow(node) ||
+ // NoOp ops should not be swapped (due to group dependencies).
+ IsNoOp(node);
+}
+
+// Check if Tensor is integer and small size.
+bool IsTensorIntegerAndSmall(const OpInfo::TensorProperties& prop) {
+ // Check type to be int32 or int64.
+ if (prop.dtype() != DataType::DT_INT32 &&
+ prop.dtype() != DataType::DT_INT64) {
+ return false;
+ }
+
+ // Check size known and small.
+ const int64 size = NumCoefficients(prop.shape());
+ if (size < 0 || size > kTensorMaxSize) {
+ return false;
+ }
+
+ return true;
+}
+
+// Find KernelDef for `node`, greedily return first found from `devices`.
+Status TryFindKernelDef(const std::vector<DeviceType>& devices,
+ const NodeDef& node, const KernelDef** kdef) {
+ for (const DeviceType& device : devices) {
+ const KernelDef* kernel = nullptr;
+ Status s = FindKernelDef(device, node, &kernel, nullptr);
if (s.ok()) {
+ if (kdef) {
+ *kdef = kernel;
+ }
return Status::OK();
}
}
@@ -49,96 +93,239 @@ Status TryFindKernelDef(const NodeDef& node, const KernelDef** kdef) {
return errors::NotFound("Could not find KernelDef for op: ", node.op());
}
-// Check if all node's inputs are pinned to CPU memory.
-bool AreAllNodeInputsPinnedToHost(const GraphView& graph, const NodeDef& node) {
- // Loop through all the inputs excluding the controlling nodes.
- for (const GraphView::OutputPort& fanin : graph.GetFanins(node, false)) {
- // Check if (the fanin) op's device is on CPU.
- if (str_util::StrContains(fanin.node->device(), DEVICE_CPU)) {
- continue;
- }
-
- // Check if (the fanin) op's output port is pinned to HostMemory.
- const OpDef* fanin_odef = nullptr;
- Status s = OpRegistry::Global()->LookUpOpDef(fanin.node->op(), &fanin_odef);
- if (!s.ok()) {
- LOG(INFO) << "Could not find OpDef for : " << fanin.node->op();
- return false;
- }
+// Checks if a node's output port is host friendly.
+// Roughly this means checking if the output port is on Host memory.
+Status IsNodeOutputPortHostFriendly(
+ const GraphView& graph, GraphProperties* properties, const NodeDef& node,
+ int port_id, OpDevicePortOnHostMap* op_device_outport_pinned_to_host_cache,
+ bool* is_candidate) {
+ *is_candidate = false;
- const int output_arg_id =
- OpOutputPortIdToArgId(*fanin.node, *fanin_odef, fanin.port_id);
- if (output_arg_id < 0) {
- LOG(WARNING) << "Invalid port: " << fanin.port_id << "!\n"
- << node.DebugString() << "\n"
- << fanin.node->DebugString() << "\n"
- << fanin_odef->DebugString();
- return false;
- }
+ // Make sure we are not a blacklisted op.
+ if (IsBlacklisted(node)) {
+ return Status::OK();
+ }
- const KernelDef* fanin_kdef = nullptr;
- s = TryFindKernelDef(*fanin.node, &fanin_kdef);
- if (!s.ok()) {
- LOG(INFO) << "Could not find KernelDef for : " << fanin.node->op();
- return false;
- }
+ // Check to make sure we have the right properties (i.e., statically shaped).
+ if (!properties->has_properties()) {
+ // This is an expensive call, call it lazily.
+ TF_RETURN_IF_ERROR(properties->InferStatically(
+ /*assume_valid_feeds=*/false));
+ }
+ const auto& output_properties = properties->GetOutputProperties(node.name());
+ if (port_id >= output_properties.size()) {
+ LOG(WARNING) << "port_id=" << port_id
+ << " but output_properties.size()=" << output_properties.size()
+ << "\n"
+ << node.DebugString();
+ return Status::OK();
+ }
+ if (!IsTensorIntegerAndSmall(output_properties[port_id])) {
+ return Status::OK();
+ }
- bool fanin_pinned = false;
- for (const string& host_memory_arg : fanin_kdef->host_memory_arg()) {
- if (fanin_odef->output_arg(output_arg_id).name() == host_memory_arg) {
- fanin_pinned = true;
- break;
+ // These nodes may be optimized away downstream (even if pinned to Host), we
+ // should (recursively) check their source.
+ if (IsIdentity(node)) {
+ for (const auto& fanin : graph.GetFanins(node, false)) {
+ bool fanin_candidate = false;
+ TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly(
+ graph, properties, *fanin.node, fanin.port_id,
+ op_device_outport_pinned_to_host_cache, &fanin_candidate));
+ if (!fanin_candidate) {
+ return Status::OK();
}
}
+ *is_candidate = true;
+ return Status::OK();
+ }
+
+ // Check if op's device is on CPU.
+ if (str_util::StrContains(node.device(), DEVICE_CPU)) {
+ *is_candidate = true;
+ return Status::OK();
+ }
+
+ // Check `op_device_outport_pinned_to_host_cache` for our
+ // {op, device, port_id} combo to see if the arg is pinned on Host.
+ const std::tuple<string, string, int> cache_key(node.op(), node.device(),
+ port_id);
+ auto it = op_device_outport_pinned_to_host_cache->find(cache_key);
+ if (it != op_device_outport_pinned_to_host_cache->end()) {
+ *is_candidate = it->second;
+ return Status::OK();
+ }
+
+ // Check if op's output port is pinned to HostMemory.
+ const OpDef* op = nullptr;
+ Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op);
+ if (!s.ok()) {
+ LOG(WARNING) << "Could not find OpDef for : " << node.op();
+ op_device_outport_pinned_to_host_cache->emplace(cache_key, false);
+ return Status::OK();
+ }
+
+ // Map the port_id to output_arg_id.
+ const int output_arg_id = OpOutputPortIdToArgId(node, *op, port_id);
+ if (output_arg_id < 0) {
+ LOG(WARNING) << "Invalid port: " << port_id << "!\n"
+ << node.DebugString() << "\n"
+ << op->DebugString();
+ op_device_outport_pinned_to_host_cache->emplace(cache_key, false);
+ return Status::OK();
+ }
- if (!fanin_pinned) {
- return false;
+ // Find the kernel.
+ const KernelDef* kernel = nullptr;
+ s = TryFindKernelDef({node.device().c_str(), DEVICE_GPU, DEVICE_CPU}, node,
+ &kernel);
+ if (!s.ok()) {
+ LOG(INFO) << "Could not find KernelDef for: " << node.op();
+ op_device_outport_pinned_to_host_cache->emplace(cache_key, false);
+ return Status::OK();
+ }
+
+ // Check if the output_arg is pinned to Host.
+ for (const string& host_memory_arg : kernel->host_memory_arg()) {
+ if (op->output_arg(output_arg_id).name() == host_memory_arg) {
+ *is_candidate = true;
+ break;
}
}
- return true;
+ op_device_outport_pinned_to_host_cache->emplace(cache_key, *is_candidate);
+
+ return Status::OK();
}
-bool IsTensorIntegerAndSmall(const OpInfo::TensorProperties& prop) {
- // Check if Tensor is integer and small size.
+// Checks if a node's input port is Host friendly.
+// Roughly this means checking if the input port is on Host memory.
+bool IsNodeInputPortHostFriendly(
+ const NodeDef& node, int port_id,
+ OpDevicePortOnHostMap* op_device_inport_pinned_to_host_cache) {
+ // If node is on Host, assume its inputs are Host friendly.
+ if (str_util::StrContains(node.device(), DEVICE_CPU)) {
+ return true;
+ }
- // Check type to be int32 or int64.
- if (prop.dtype() != DataType::DT_INT32 &&
- prop.dtype() != DataType::DT_INT64) {
- return false;
+ // Check `op_device_inport_pinned_to_host_cache` for our
+ // {op, device, port_id} combo to see if the arg is pinned on Host.
+ std::tuple<string, string, int> cache_key(node.op(), node.device(), port_id);
+ auto it = op_device_inport_pinned_to_host_cache->find(cache_key);
+ if (it != op_device_inport_pinned_to_host_cache->end()) {
+ return it->second;
}
- // Check size known and small.
- const int64 size = NumCoefficients(prop.shape());
- if (size < 0 || size > kTensorMaxSize) {
+ // Check if op's input port is pinned to HostMemory.
+ const OpDef* op = nullptr;
+ Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op);
+ if (!s.ok()) {
+ LOG(WARNING) << "Could not find OpDef for : " << node.op();
+ op_device_inport_pinned_to_host_cache->emplace(cache_key, false);
+ return false;
+ }
+ const int input_arg_id = OpInputPortIdToArgId(node, *op, port_id);
+
+ // Find the kernel.
+ const KernelDef* kernel = nullptr;
+ s = internal::TryFindKernelDef(
+ {node.device().c_str(), DEVICE_GPU, DEVICE_CPU}, node, &kernel);
+ if (!s.ok()) {
+ LOG(INFO) << "Could not find KernelDef for: " << node.op();
+ op_device_inport_pinned_to_host_cache->emplace(cache_key, false);
return false;
}
- return true;
+ // Check if the input_arg is pinned to Host.
+ for (const string& host_memory_arg : kernel->host_memory_arg()) {
+ if (op->input_arg(input_arg_id).name() == host_memory_arg) {
+ op_device_inport_pinned_to_host_cache->emplace(cache_key, true);
+ return true;
+ }
+ }
+
+ op_device_inport_pinned_to_host_cache->emplace(cache_key, false);
+
+ return false;
}
-bool AreAllNodeInputsAndOutputsIntsAndSmall(const GraphProperties& properties,
- const NodeDef& node) {
- for (const auto& prop : properties.GetInputProperties(node.name())) {
+// Checks if a node is a candidate to pin to Host.
+// The rough algorithm is as follows:
+// 1] Check if node is blacklisted.
+// 2] Check if node can run on Host.
+// 3] Check all input/outputs are Host "friendly" (at the moment, friendly means small,
+// ints, and pinned to Host).
+Status IsNodeHostCandidate(
+ const GraphView& graph, GraphProperties* properties, const NodeDef& node,
+ OpDevicePortOnHostMap* op_device_outport_pinned_to_host_cache,
+ bool* is_candidate) {
+ *is_candidate = false;
+
+ // Skip these node types.
+ if (IsBlacklisted(node)) {
+ return Status::OK();
+ }
+
+ // Check if node already on CPU.
+ if (str_util::StrContains(node.device(), DEVICE_CPU)) {
+ *is_candidate = true;
+ return Status::OK();
+ }
+
+ // Check the node can be run on CPU.
+ Status s = TryFindKernelDef({DEVICE_CPU}, node, nullptr);
+ if (!s.ok()) {
+ return Status::OK();
+ }
+
+ // Check all outputs are Host friendly.
+ if (!properties->has_properties()) {
+ // This is an expensive call, call it lazily.
+ TF_RETURN_IF_ERROR(properties->InferStatically(
+ /*assume_valid_feeds=*/false));
+ }
+ for (const auto& prop : properties->GetOutputProperties(node.name())) {
if (!IsTensorIntegerAndSmall(prop)) {
- return false;
+ return Status::OK();
}
}
- for (const auto& prop : properties.GetOutputProperties(node.name())) {
- if (!IsTensorIntegerAndSmall(prop)) {
- return false;
+ // Check all inputs are Host friendly.
+ for (const GraphView::OutputPort& fanin :
+ graph.GetFanins(node, /*include_controlling_nodes=*/false)) {
+ bool fanin_candidate = false;
+ TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly(
+ graph, properties, *fanin.node, fanin.port_id,
+ op_device_outport_pinned_to_host_cache, &fanin_candidate));
+ if (!fanin_candidate) {
+ return Status::OK();
}
}
- return true;
+
+ *is_candidate = true;
+ return Status::OK();
}
-string TryFindHostDevice(const gtl::FlatSet<string>& devices,
- bool has_device_cpu, const string& device) {
+bool IsTPUGraphDef(const GraphDef& def) {
+ for (const auto& node : def.node()) {
+ if (node.op() == "TPUCompile" || node.op() == "TPUExecute" ||
+ node.op() == "TPUPartitionedCall") {
+ return true;
+ }
+ }
+ return false;
+}
+} // end namespace
+
+// Tries to swap `device` to a Host device from `devices`. Returns true iff
+// there was a swap.
+bool TrySwapToHostDevice(const gtl::FlatSet<string>& devices,
+ bool has_device_cpu, string* device) {
// Force this node onto the CPU.
- if (device.empty() && has_device_cpu) {
- return "/device:CPU:0";
- } else if (str_util::StrContains(device, DEVICE_GPU)) {
+ if (device->empty() && has_device_cpu) {
+ *device = "/device:CPU:0";
+ return true;
+ } else if (str_util::StrContains(*device, DEVICE_GPU)) {
// Sometimes the cluster can have:
// devices = {"/device:CPU:0", "/device:XLA_GPU:0"}
// and we need to handle them properly.
@@ -146,30 +333,19 @@ string TryFindHostDevice(const gtl::FlatSet<string>& devices,
{std::pair<string, string>("GPU", "CPU:0"),
std::pair<string, string>("/device", "/device:CPU:0")}) {
const string device_host =
- strings::StrCat(device.substr(0, device.rfind(device_match.first)),
+ strings::StrCat(device->substr(0, device->rfind(device_match.first)),
device_match.second);
if (devices.find(device_host) != devices.end()) {
- return device_host;
+ *device = device_host;
+ return true;
}
}
}
- // We couldn't find an appropriate Host device, return original device.
- return device;
-}
-
-bool IsTPUGraphDef(const GraphDef& def) {
- for (const auto& node : def.node()) {
- if (node.op() == "TPUCompile" || node.op() == "TPUExecute" ||
- node.op() == "TPUPartitionedCall") {
- return true;
- }
- }
+ // We couldn't find an appropriate Host device, return false.
return false;
}
-// All the nodes that should be blacklisted and not swapped.
-bool IsBlacklisted(const NodeDef& node) { return IsCollective(node); }
} // end namespace internal
Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
@@ -182,7 +358,6 @@ Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
}
GraphProperties properties(item);
- bool has_properties = false;
GraphView graph(optimized_graph);
gtl::FlatSet<string> devices;
@@ -202,45 +377,26 @@ Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
// All the Const nodes, and their original devices in topological order.
std::vector<std::pair<NodeDef*, string>> const_nodes;
- for (auto& node : *optimized_graph->mutable_node()) {
- // Check if node already on CPU.
- if (str_util::StrContains(node.device(), DEVICE_CPU)) {
- continue;
- }
-
- // Skip these node types.
- if (internal::IsBlacklisted(node)) {
- continue;
- }
+ // Cache to map {op, device, port} -> bool on whether it is pinned to host.
+ internal::OpDevicePortOnHostMap op_device_outport_pinned_to_host_cache;
+ internal::OpDevicePortOnHostMap op_device_inport_pinned_to_host_cache;
- // Check the node can be run on CPU.
- Status s = FindKernelDef(DEVICE_CPU, node, nullptr, nullptr);
- if (!s.ok()) {
- continue;
- }
-
- // Check all input's are pinned to CPU.
- if (!internal::AreAllNodeInputsPinnedToHost(graph, node)) {
- continue;
- }
-
- if (!has_properties) {
- // This is an expensive call, call it lazily.
- TF_RETURN_IF_ERROR(properties.InferStatically(false));
- has_properties = true;
- }
-
- // Check all inputs and outputs are integers and small.
- if (!internal::AreAllNodeInputsAndOutputsIntsAndSmall(properties, node)) {
+ for (auto& node : *optimized_graph->mutable_node()) {
+ bool is_candidate = false;
+ TF_RETURN_IF_ERROR(internal::IsNodeHostCandidate(
+ graph, &properties, node, &op_device_outport_pinned_to_host_cache,
+ &is_candidate));
+ if (!is_candidate) {
continue;
}
- if (IsConstant(node)) {
- const_nodes.emplace_back(&node, node.device());
+ const string original_device = node.device();
+ const bool swapped = internal::TrySwapToHostDevice(devices, has_device_cpu,
+ node.mutable_device());
+ // Keep track of all Const nodes that we swapped.
+ if (swapped && IsConstant(node)) {
+ const_nodes.emplace_back(&node, original_device);
}
- // Try and swap the device to Host.
- node.set_device(
- internal::TryFindHostDevice(devices, has_device_cpu, node.device()));
}
// Traverse all `const_nodes`, and map them back to GPU greedily.
@@ -248,10 +404,13 @@ Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
NodeDef* node = it.first;
const string& device = it.second;
- // Check all the consumers of this node, if any of them are on the original
- // device, swap this node back onto the original device.
+ // Check all the consumers of this node, if any of them are not on CPU, swap
+ // this node back onto the original device.
for (const GraphView::InputPort& fanout : graph.GetFanouts(*node, false)) {
- if (fanout.node->device() == device) {
+ // The consumer is not Host friendly, swap it back to the original device.
+ if (!internal::IsNodeInputPortHostFriendly(
+ *fanout.node, fanout.port_id,
+ &op_device_inport_pinned_to_host_cache)) {
node->set_device(device);
break;
}
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h
index d557a03463..bed4a9ef95 100644
--- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h
@@ -26,8 +26,8 @@ namespace tensorflow {
namespace grappler {
namespace internal {
// Try and find an appropriate Host device in `devices` given `device`.
-string TryFindHostDevice(const gtl::FlatSet<string>& devices,
- bool has_device_cpu, const string& device);
+bool TrySwapToHostDevice(const gtl::FlatSet<string>& devices,
+ bool has_device_cpu, string* device);
} // end namespace internal
// Optimize TensorFlow ops that should be swapped into the CPU to avoid
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
index 173cb3fe3c..9bb030b220 100644
--- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
@@ -28,30 +28,60 @@ namespace {
class PinToHostOptimizerTest : public GrapplerTest {};
-TEST_F(PinToHostOptimizerTest, TryFindHostDevice) {
+TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceNoDevices) {
gtl::FlatSet<string> devices = {};
- EXPECT_EQ("ABC", internal::TryFindHostDevice(devices, false, "ABC"));
-
- devices = {"/device:CPU:0", "/device:XLA_GPU:0"};
- EXPECT_EQ(internal::TryFindHostDevice(devices, true, ""), "/device:CPU:0");
- EXPECT_EQ(internal::TryFindHostDevice(devices, true, "/device:XLA_GPU:0"),
- "/device:CPU:0");
- EXPECT_EQ(internal::TryFindHostDevice(devices, true, "/device:XLA_GPU:*"),
- "/device:CPU:0");
-
- devices = {"/device:XLA_CPU:0", "/device:XLA_GPU:0"};
- EXPECT_EQ(internal::TryFindHostDevice(devices, false, ""), "");
- EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"),
- "/device:XLA_CPU:0");
- EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"),
- "/device:XLA_CPU:0");
-
- devices = {"/device:XLA_GPU:0"};
- EXPECT_EQ(internal::TryFindHostDevice(devices, false, ""), "");
- EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"),
- "/device:XLA_GPU:0");
- EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"),
- "/device:XLA_GPU:*");
+
+ string device = "ABC";
+ EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device));
+ EXPECT_EQ(device, "ABC");
+}
+
+TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceCpuXlaGpu) {
+ gtl::FlatSet<string> devices = {"/device:CPU:0", "/device:XLA_GPU:0"};
+
+ string device = "";
+ EXPECT_TRUE(internal::TrySwapToHostDevice(devices, true, &device));
+ EXPECT_EQ(device, "/device:CPU:0");
+
+ device = "/device:XLA_GPU:0";
+ EXPECT_TRUE(internal::TrySwapToHostDevice(devices, true, &device));
+ EXPECT_EQ(device, "/device:CPU:0");
+
+ device = "/device:XLA_GPU:*";
+ EXPECT_TRUE(internal::TrySwapToHostDevice(devices, true, &device));
+ EXPECT_EQ(device, "/device:CPU:0");
+}
+
+TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceXlaCpuXlaGpu) {
+ gtl::FlatSet<string> devices = {"/device:XLA_CPU:0", "/device:XLA_GPU:0"};
+
+ string device = "";
+ EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device));
+ EXPECT_TRUE(device.empty());
+
+ device = "/device:XLA_GPU:0";
+ EXPECT_TRUE(internal::TrySwapToHostDevice(devices, false, &device));
+ EXPECT_EQ(device, "/device:XLA_CPU:0");
+
+ device = "/device:XLA_GPU:*";
+ EXPECT_TRUE(internal::TrySwapToHostDevice(devices, false, &device));
+ EXPECT_EQ(device, "/device:XLA_CPU:0");
+}
+
+TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceXlaGpu) {
+ gtl::FlatSet<string> devices = {"/device:XLA_GPU:0"};
+
+ string device = "";
+ EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device));
+ EXPECT_TRUE(device.empty());
+
+ device = "/device:XLA_GPU:0";
+ EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device));
+ EXPECT_EQ(device, "/device:XLA_GPU:0");
+
+ device = "/device:XLA_GPU:*";
+ EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device));
+ EXPECT_EQ(device, "/device:XLA_GPU:*");
}
TEST_F(PinToHostOptimizerTest, OptimizeSmallOpsToHost) {
@@ -160,6 +190,48 @@ TEST_F(PinToHostOptimizerTest, NoSwap) {
EXPECT_EQ(found, 3);
}
+TEST_F(PinToHostOptimizerTest, Identity) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ // `a,c` is on GPU, `e` is on CPU, consequently `e` should not be swapped.
+ // `b` should be placed onto Host since `c` pins the input to Host memory.
+ Output a =
+ ops::Const(s.WithOpName("a").WithDevice("/device:GPU:0"), 1, {64, 64});
+ Output b = ops::Const(s.WithOpName("b"), {0, 1}, {2});
+ Output c =
+ ops::ReduceProd(s.WithOpName("c").WithDevice("/device:GPU:0"), a, b);
+ Output d = ops::Identity(s.WithDevice("/device:CPU:0").WithOpName("d"), c);
+ Output e = ops::Multiply(s.WithOpName("e"), d, d);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ GraphDef output;
+ PinToHostOptimizer optimizer(RewriterConfig::ON);
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ int found = 0;
+ for (const NodeDef& node : output.node()) {
+ if (node.name() == "a" || node.name() == "c") {
+ EXPECT_EQ(node.device(), "/device:GPU:0");
+ } else if (node.name() == "b") {
+ // If CUDA, then there is a GPU kernel registration that is pinned to Host
+ // memory. Consequently, `b` will be mapped to Host correctly if there is
+ // a GPU kernel registered.
+#if GOOGLE_CUDA
+ EXPECT_EQ(node.device(), "/device:CPU:0");
+#else
+ EXPECT_TRUE(node.device().empty());
+#endif
+ } else if (node.name() == "d") {
+ EXPECT_EQ(node.device(), "/device:CPU:0");
+ } else if (node.name() == "e") {
+ EXPECT_TRUE(node.device().empty());
+ }
+ ++found;
+ }
+ EXPECT_EQ(found, 5);
+}
+
TEST_F(PinToHostOptimizerTest, PortIdToArgId) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output a = ops::Const(s.WithOpName("a"), 1, {1, 2, 3});
diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc
index 008a289cfd..9ada8b7ff9 100644
--- a/tensorflow/core/grappler/optimizers/remapper.cc
+++ b/tensorflow/core/grappler/optimizers/remapper.cc
@@ -168,11 +168,12 @@ void AddBatchNormNodes(GraphDef* optimized_graph, const NodeDef& fused_node) {
Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
GraphDef* optimized_graph) {
GraphProperties properties(item);
- TF_RETURN_IF_ERROR(properties.InferStatically(false));
+ bool inferred_properties = false;
GraphView graph(const_cast<GraphDef*>(&item.graph));
// During inference, most of the inputs to FusedBatchNorm are constant, and we
// can therefore replace the op with a much cheaper set of primitives.
+ optimized_graph->mutable_node()->Reserve(item.graph.node_size());
for (const NodeDef& node : item.graph.node()) {
if (node.op() == "FusedBatchNorm" || node.op() == "FusedBatchNormV2") {
bool optimizable = (node.attr().count("T") == 0 ||
@@ -181,6 +182,11 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
!node.attr().at("is_training").b());
if (optimizable) {
int const_inputs = 0;
+ if (!inferred_properties) {
+ // Infer properties lazily in case they are not needed.
+ TF_RETURN_IF_ERROR(properties.InferStatically(false));
+ inferred_properties = true;
+ }
const auto& props = properties.GetInputProperties(node.name());
for (const auto& prop : props) {
if (prop.has_value()) {
diff --git a/tensorflow/core/grappler/optimizers/shape_optimizer.cc b/tensorflow/core/grappler/optimizers/shape_optimizer.cc
index 4542d17ccc..6ccb1cd783 100644
--- a/tensorflow/core/grappler/optimizers/shape_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/shape_optimizer.cc
@@ -33,7 +33,7 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
*optimized_graph = item.graph;
GraphProperties properties(item);
- TF_RETURN_IF_ERROR(properties.InferStatically(false));
+ bool inferred_properties = false;
GraphView graph(optimized_graph);
// The product of all the dimensions in a tensor shape can be expressed more
@@ -55,6 +55,11 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
}
const GraphView::OutputPort reduce_indices =
graph.GetRegularFanin(GraphView::InputPort(fanout.node, 1));
+ if (!inferred_properties) {
+ // Infer properties lazily in case they are not needed.
+ TF_RETURN_IF_ERROR(properties.InferStatically(false));
+ inferred_properties = true;
+ }
const auto& prop =
properties.GetOutputProperties(reduce_indices.node->name());
if (prop.size() < reduce_indices.port_id) {
@@ -92,6 +97,11 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
if (!IsSize(*input1.node) || !IsSize(*input2.node)) {
continue;
}
+ if (!inferred_properties) {
+ // Infer properties lazily in case they are not needed.
+ TF_RETURN_IF_ERROR(properties.InferStatically(false));
+ inferred_properties = true;
+ }
const auto& prop1 = properties.GetInputProperties(input1.node->name());
const auto& prop2 = properties.GetInputProperties(input2.node->name());
if (prop1.size() != 1 || prop2.size() != 1) {
diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc
index db6e4e6852..5867d01324 100644
--- a/tensorflow/core/grappler/utils.cc
+++ b/tensorflow/core/grappler/utils.cc
@@ -156,45 +156,6 @@ bool IsControlInput(const string& name) {
return !name.empty() && name[0] == '^';
}
-string NodeName(const string& name) {
- int position;
- return ParseNodeName(name, &position);
-}
-
-int NodePosition(const string& name) {
- int position;
- ParseNodeNameAsStringPiece(name, &position);
- return position;
-}
-
-int NodePositionIfSameNode(const string& input_name, const string& node_name) {
- const bool is_ctrl = input_name[0] == '^';
- auto input_it = is_ctrl ? input_name.begin() + 1 : input_name.begin();
- auto node_it = node_name.begin();
- if (node_name.empty() ||
- std::distance(input_it, input_name.end()) < node_name.size()) {
- return -2;
- }
- while (node_it != node_name.end()) {
- if (*input_it++ != *node_it++) {
- return -2;
- }
- }
- if (input_it == input_name.end()) {
- return is_ctrl ? -1 : 0;
- } else if (*input_it++ == ':') {
- StringPiece remaining(&(*input_it),
- std::distance(input_it, input_name.end()));
- int position;
- if (!strings::safe_strto32(remaining, &position)) {
- return -2;
- }
- return is_ctrl ? -1 : position;
- } else {
- return -2;
- }
-}
-
string AddPrefixToNodeName(const string& name, const string& prefix,
const string& delimiter) {
if (!name.empty()) {
diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h
index 296ee1678e..95126d470c 100644
--- a/tensorflow/core/grappler/utils.h
+++ b/tensorflow/core/grappler/utils.h
@@ -29,7 +29,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
-#include "tensorflow/core/lib/strings/scanner.h"
namespace tensorflow {
namespace grappler {
@@ -102,40 +101,92 @@ bool IsControlInput(const string& name);
// True iff 'name1' and 'name2' refer to the same input.
bool IsSameInput(const string& name1, const string& name2);
+// Returns the trailing position number (or zero if no number is present) if
+// NodeName(input_name) is equal to node_name. Returns -1 for control inputs.
+// Returns -2 if NodeName(input_name) is not equal to node_name.
+// Note: This function is used very heavily, and this hand-optimized
+// version is 3-4x faster than the version using Scanner, which it replaced.
+// This is worth the reduction in readability.
+inline int NodePositionIfSameNode(const string& input_name,
+ const string& node_name) {
+ if (input_name.empty()) return -2;
+ const bool is_ctrl = input_name[0] == '^';
+ auto input_it = is_ctrl ? input_name.begin() + 1 : input_name.begin();
+ auto node_it = node_name.begin();
+ if (node_name.empty() ||
+ std::distance(input_it, input_name.end()) < node_name.size()) {
+ return -2;
+ }
+ while (node_it != node_name.end()) {
+ if (*input_it++ != *node_it++) {
+ return -2;
+ }
+ }
+ if (input_it == input_name.end()) {
+ return is_ctrl ? -1 : 0;
+ } else if (*input_it++ == ':') {
+ StringPiece remaining(&(*input_it),
+ std::distance(input_it, input_name.end()));
+ int position;
+ if (!strings::safe_strto32(remaining, &position)) {
+ return -2;
+ }
+ return is_ctrl ? -1 : position;
+ } else {
+ return -2;
+ }
+}
+
// Return the node name corresponding to 'name' if name is valid, or the empty
// string otherwise.
-string NodeName(const string& name);
+inline StringPiece NodeNameAsStringPiece(const string& name) {
+ static const string empty;
+ if (name.empty()) return StringPiece(empty);
+ const auto begin_it = name[0] == '^' ? name.begin() + 1 : name.begin();
+ auto end_it = begin_it;
+ while (end_it != name.end() && *end_it != ':') {
+ ++end_it;
+ }
+ if (end_it != name.end() && *end_it != ':') {
+ return StringPiece(empty);
+ }
+ return StringPiece(&(*begin_it), std::distance(begin_it, end_it));
+}
-// Get the trailing position number ":{digits}" (if any) of a node name.
-// Returns -1 for control inputs.
-int NodePosition(const string& name);
+// Return the node name corresponding to 'name' if name is valid, or the empty
+// string otherwise.
+inline string NodeName(const string& name) {
+ return string(NodeNameAsStringPiece(name));
+}
+// Returns the node name and position in a single call.
inline StringPiece ParseNodeNameAsStringPiece(const string& name,
int* position) {
- // Strip the prefix '^' (if any), and strip the trailing ":{digits} (if any)
- // to get a node name.
- strings::Scanner scan(name);
- scan.ZeroOrOneLiteral("^")
- .RestartCapture()
- .One(strings::Scanner::LETTER_DIGIT_DOT_UNDERSCORE)
- .Any(strings::Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
- StringPiece capture;
- StringPiece remaining;
- if (scan.Peek(':') != ':' || !scan.GetResult(&remaining, &capture)) {
+ static const string empty;
+ if (name.empty()) {
*position = 0;
- static const string empty;
return StringPiece(empty);
- } else {
- if (name[0] == '^') {
- *position = -1;
- } else if (remaining.empty()) {
- *position = 0;
- } else {
- // Skip the first ':' character.
- CHECK(strings::safe_strto32(remaining.substr(1), position));
+ }
+ const bool is_ctrl = name[0] == '^';
+ const auto begin_it = is_ctrl ? name.begin() + 1 : name.begin();
+ *position = is_ctrl ? -1 : 0;
+ auto end_it = begin_it;
+ while (end_it != name.end() && *end_it != ':') {
+ ++end_it;
+ }
+ const StringPiece node_name(&(*begin_it), std::distance(begin_it, end_it));
+ if (end_it != name.end()) {
+ if (*end_it != ':') {
+ return StringPiece(empty);
+ } else if (!is_ctrl) {
+ ++end_it;
+ StringPiece remaining(&(*end_it), std::distance(end_it, name.end()));
+ if (!strings::safe_strto32(remaining, position)) {
+ return StringPiece(empty);
+ }
}
- return capture;
}
+ return node_name;
}
// Returns the node name and position in a single call.
@@ -143,10 +194,11 @@ inline string ParseNodeName(const string& name, int* position) {
return string(ParseNodeNameAsStringPiece(name, position));
}
-// Returns NodePosition(input_name) if NodeName(input_name) == node_name.
-// Otherwise returns -2;
-// REQUIRES: inputs_name.size() > 0 && node_name.size() > 0.
-int NodePositionIfSameNode(const string& input_name, const string& node_name);
+inline int NodePosition(const string& name) {
+ int position;
+ ParseNodeNameAsStringPiece(name, &position);
+ return position;
+}
// Add a prefix to a node name with a custom delimiter.
string AddPrefixToNodeName(const string& name, const string& prefix,
diff --git a/tensorflow/core/grappler/utils/functions.cc b/tensorflow/core/grappler/utils/functions.cc
index a428aea7f5..6861fb423c 100644
--- a/tensorflow/core/grappler/utils/functions.cc
+++ b/tensorflow/core/grappler/utils/functions.cc
@@ -41,7 +41,8 @@ Status RegisterFunctionBodyOutputs(const OpRegistrationData& registration,
tensorflow::NameRangeMap outputs_range_map;
TF_RETURN_IF_ERROR(tensorflow::NameRangesForNode(
node, registration.op_def, nullptr, &outputs_range_map));
- connectivity->RegisterFunctionBodyOutputs(node.name(), outputs_range_map);
+ connectivity->RegisterFunctionBodyOutputs(node.name(),
+ std::move(outputs_range_map));
return Status::OK();
}
@@ -75,20 +76,22 @@ Status ResolveFunctionBodyNodeAttrPlaceholders(
} // namespace
void GrapplerFunctionConnectivity::RegisterInputArgExpansion(
- const InputArgExpansion& input_arg_expansion) {
- const auto& input_name = input_arg_expansion.input_name;
+ InputArgExpansion input_arg_expansion) {
+ string input_name = input_arg_expansion.input_name;
const auto& placeholders = input_arg_expansion.placeholders;
- input_arg_expansions_.emplace(input_name, input_arg_expansion);
+
for (int i = 0; i < placeholders.size(); ++i) {
const string& placeholder = input_arg_expansion.placeholders[i];
- input_arg_placeholders_.emplace(
- placeholder, InputArgPlaceholder{input_name, /*position=*/i});
+ input_arg_placeholders_.insert(
+ {placeholder, InputArgPlaceholder{input_name, /*position=*/i}});
}
+ input_arg_expansions_.insert(
+ {std::move(input_name), std::move(input_arg_expansion)});
}
void GrapplerFunctionConnectivity::RegisterFunctionBodyOutputs(
- const string& node_name, const tensorflow::NameRangeMap& outputs) {
- function_body_outputs_[node_name] = outputs;
+ const string& node_name, tensorflow::NameRangeMap&& outputs) {
+ function_body_outputs_[node_name] = std::move(outputs);
}
Status GrapplerFunctionConnectivity::ExpandFunctionDefInput(
@@ -174,11 +177,12 @@ Status GrapplerFunctionConnectivity::ExpandFunctionDefInput(
const auto& output_range = output->second;
if (position == -1) {
+ graph_def_inputs->reserve(graph_def_inputs->size() +
+ output_range.second - output_range.first);
// If position is not defined expand node output range
for (int i = output_range.first; i < output_range.second; ++i) {
- i == 0 ? graph_def_inputs->push_back(node_name)
- : graph_def_inputs->push_back(
- strings::StrCat(node_name, ":", i));
+ graph_def_inputs->push_back(
+ i == 0 ? node_name : strings::StrCat(node_name, ":", i));
}
} else {
if (position > (output_range.second - output_range.first)) {
@@ -187,9 +191,8 @@ Status GrapplerFunctionConnectivity::ExpandFunctionDefInput(
" position: ", position, " (out of range)");
}
int pos = output_range.first + position;
- pos == 0 ? graph_def_inputs->push_back(node_name)
- : graph_def_inputs->push_back(
- strings::StrCat(node_name, ":", pos));
+ graph_def_inputs->push_back(
+ pos == 0 ? node_name : strings::StrCat(node_name, ":", pos));
}
return Status::OK();
@@ -211,8 +214,8 @@ Status GrapplerFunctionConnectivity::ExpandNodeInputs(
}
function_body_node->clear_input();
- for (const string& expanded_input : expanded_inputs)
- function_body_node->add_input(expanded_input);
+ for (string& expanded_input : expanded_inputs)
+ function_body_node->add_input(std::move(expanded_input));
return Status::OK();
}
@@ -323,7 +326,7 @@ GrapplerFunctionItem::GrapplerFunctionItem(
// Fill the feed nodes with input placeholders.
for (const InputArgExpansion& input_arg : input_arg_expansions_) {
for (const string& placeholder : input_arg.placeholders) {
- feed.emplace_back(placeholder, Tensor());
+ feed.push_back({placeholder, Tensor()});
input_arg_placeholders_.insert(placeholder);
}
}
@@ -460,7 +463,7 @@ Status InstantiationBodyParameters(
auto it = func_instantiation_attr.find(placeholder);
if (it != func_instantiation_attr.end()) {
- body_parameters->emplace(placeholder, it->second);
+ body_parameters->insert({placeholder, it->second});
} else {
return errors::InvalidArgument("Can't resolve placeholder: ",
placeholder);
@@ -498,10 +501,6 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
// GraphDef input format (name[:position])
GrapplerFunctionConnectivity connectivity;
- std::vector<InputArgExpansion> inputs;
- std::vector<OutputArgExpansion> outputs;
- std::vector<string> keep_nodes;
-
// Function body shares the library with the graph that instantiated it.
GraphDef function_body;
*function_body.mutable_library() = flib.ToProto();
@@ -518,6 +517,9 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
}
}
+ std::vector<InputArgExpansion> inputs;
+ inputs.reserve(signature.input_arg_size());
+
// For each input argument create a placeholder in function body.
for (const OpDef::ArgDef& input : signature.input_arg()) {
if (!input.type_list_attr().empty() || !input.number_attr().empty()) {
@@ -542,9 +544,10 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
/*is_ref*/ input.is_ref(),
/*placeholders=*/{input.name()}};
connectivity.RegisterInputArgExpansion(input_expansion);
- inputs.push_back(input_expansion);
+ inputs.push_back(std::move(input_expansion));
}
+ std::vector<string> keep_nodes;
// Add all function nodes to the function body
for (const NodeDef& func_def_node : func.node_def()) {
NodeDef* new_node = function_body.add_node();
@@ -572,6 +575,8 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
TF_RETURN_IF_ERROR(connectivity.ExpandNodeInputs(&node));
}
+ std::vector<OutputArgExpansion> outputs;
+ outputs.reserve(signature.output_arg_size());
// Add function outputs
for (const OpDef::ArgDef& out : signature.output_arg()) {
std::vector<string> output_tensors;
@@ -589,8 +594,8 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
OutputArgExpansion output{/*output_name=*/out.name(),
/*data_type=*/output_data_type,
/*is_ref=*/out.is_ref(),
- /*output_tensors=*/output_tensors};
- outputs.push_back(output);
+ /*output_tensors=*/std::move(output_tensors)};
+ outputs.push_back(std::move(output));
}
bool is_stateful = signature.is_stateful();
diff --git a/tensorflow/core/grappler/utils/functions.h b/tensorflow/core/grappler/utils/functions.h
index 733caf325f..ef944ced09 100644
--- a/tensorflow/core/grappler/utils/functions.h
+++ b/tensorflow/core/grappler/utils/functions.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
#include <string>
+#include <unordered_map>
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
@@ -70,9 +71,9 @@ struct OutputArgExpansion {
// and fold it back when doing backward conversion.
class GrapplerFunctionConnectivity {
public:
- void RegisterInputArgExpansion(const InputArgExpansion& input_arg_expansion);
+ void RegisterInputArgExpansion(InputArgExpansion input_arg_expansion);
void RegisterFunctionBodyOutputs(const string& node_name,
- const tensorflow::NameRangeMap& outputs);
+ tensorflow::NameRangeMap&& outputs);
// Expand input encoded in FunctionDef format (name[:output][:position]) into
// multiple inputs in GraphDef format (name[:position]).
diff --git a/tensorflow/core/grappler/utils_test.cc b/tensorflow/core/grappler/utils_test.cc
index 6b787a6910..9b6c1f690b 100644
--- a/tensorflow/core/grappler/utils_test.cc
+++ b/tensorflow/core/grappler/utils_test.cc
@@ -371,6 +371,25 @@ BM_NodePositionIfSameNode("^foo/bar/baz", "foo/bar/baz", Match_Ctrl);
BM_NodePositionIfSameNode("blah", "foo/bar/baz", NoMatch_0);
BM_NodePositionIfSameNode("foo/bar/baz/gnu", "foo/bar/baz", NoMatch_end);
+#define BM_ParseNodeNameAsStringPiece(I, NAME) \
+ static void BM_ParseNodeNameAsStringPiece_##NAME(int iters) { \
+ string input = I; \
+ for (int i = 0; i < iters; ++i) { \
+ int position; \
+ const StringPiece name = ParseNodeNameAsStringPiece(input, &position); \
+ CHECK_GE(position, -1); \
+ CHECK(!name.empty()); \
+ } \
+ } \
+ BENCHMARK(BM_ParseNodeNameAsStringPiece_##NAME)
+
+BM_ParseNodeNameAsStringPiece("foo", foo);
+BM_ParseNodeNameAsStringPiece("foo/bar/baz", foo_bar_baz);
+BM_ParseNodeNameAsStringPiece("^foo/bar/baz", foo_bar_baz_ctrl);
+BM_ParseNodeNameAsStringPiece("foo:123", foo123);
+BM_ParseNodeNameAsStringPiece("foo/bar/baz:123", foo_bar_baz_123);
+BM_ParseNodeNameAsStringPiece("^foo/bar/baz:123", foo_bar_baz_123_ctrl);
+
} // namespace
} // namespace grappler
} // namespace tensorflow