aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-12 10:37:20 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-12 10:42:39 -0700
commit89177f289e9467e04b205a1a3e705ad67d9854d2 (patch)
tree92c907bbf012eaac91eebf6147c52f2a2dea29a6
parentaab543c3013e3018d409ed2b8cd957f3465d1ab2 (diff)
Turn trivial Pack ops with a single input into ExpandDims ops to avoid copying the tensor.
PiperOrigin-RevId: 188742516
-rw-r--r--tensorflow/core/grappler/op_types.cc2
-rw-r--r--tensorflow/core/grappler/op_types.h1
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.cc70
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.h2
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding_test.cc42
5 files changed, 97 insertions, 20 deletions
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index ca56833ef6..53c177befc 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -217,6 +217,8 @@ bool IsNextIteration(const NodeDef& node) {
return op == "NextIteration" || op == "RefNextIteration";
}
+bool IsPack(const NodeDef& node) { return node.op() == "Pack"; }
+
bool IsPad(const NodeDef& node) {
const auto& op = node.op();
return op == "Pad" || op == "PadV2";
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index a0946ee1ad..cd5b464099 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -86,6 +86,7 @@ bool IsMod(const NodeDef& node);
bool IsMul(const NodeDef& node);
bool IsMatMul(const NodeDef& node);
bool IsNextIteration(const NodeDef& node);
+bool IsPack(const NodeDef& node);
bool IsPad(const NodeDef& node);
bool IsNoOp(const NodeDef& node);
bool IsNotEqual(const NodeDef& node);
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index 39cc4a9629..6cb0447355 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -1510,7 +1510,7 @@ Status ConstantFolding::ReplaceOperationWithConstant(
}
Status ConstantFolding::SimplifyGraph(GraphDef* output,
- const GraphProperties& properties,
+ GraphProperties* properties,
bool use_shape_info) {
const bool is_aggressive = opt_level_ == RewriterConfig::AGGRESSIVE;
for (int i = 0; i < output->node_size(); ++i) {
@@ -1520,7 +1520,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
if (use_shape_info &&
(IsShuffle(*node) || IsReverse(*node) || IsTranspose(*node))) {
const auto& shape =
- properties.GetInputProperties(node->name())[0].shape();
+ properties->GetInputProperties(node->name())[0].shape();
// The node is replaceable iff
// unknown_rank == false && (dim_size == 0 || all dims have size 1)
bool replaceable = !shape.unknown_rank();
@@ -1533,10 +1533,10 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
}
if (use_shape_info && IsSlice(*node) &&
- properties.GetInputProperties(node->name()).size() == 3) {
- const auto& input = properties.GetInputProperties(node->name())[0];
- const auto& b = properties.GetInputProperties(node->name())[1];
- const auto& s = properties.GetInputProperties(node->name())[2];
+ properties->GetInputProperties(node->name()).size() == 3) {
+ const auto& input = properties->GetInputProperties(node->name())[0];
+ const auto& b = properties->GetInputProperties(node->name())[1];
+ const auto& s = properties->GetInputProperties(node->name())[2];
if (TensorShape::IsValid(b.shape()) && b.has_value() &&
TensorShape::IsValid(s.shape()) && s.has_value()) {
Tensor begin(b.dtype(), b.shape());
@@ -1574,8 +1574,8 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
}
if (IsTile(*node) &&
- properties.GetInputProperties(node->name()).size() == 2) {
- const auto& m = properties.GetInputProperties(node->name())[1];
+ properties->GetInputProperties(node->name()).size() == 2) {
+ const auto& m = properties->GetInputProperties(node->name())[1];
if (TensorShape::IsValid(m.shape()) && m.has_value()) {
Tensor multiplies(m.dtype(), m.shape());
if (!multiplies.FromProto(m.value())) {
@@ -1602,8 +1602,8 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
}
if (IsPad(*node) &&
- properties.GetInputProperties(node->name()).size() >= 2) {
- const auto& p = properties.GetInputProperties(node->name())[1];
+ properties->GetInputProperties(node->name()).size() >= 2) {
+ const auto& p = properties->GetInputProperties(node->name())[1];
if (TensorShape::IsValid(p.shape()) && p.has_value()) {
Tensor paddings(p.dtype(), p.shape());
if (!paddings.FromProto(p.value())) {
@@ -1625,12 +1625,12 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
}
if (use_shape_info && IsSqueeze(*node) &&
- !properties.GetInputProperties(node->name()).empty()) {
+ !properties->GetInputProperties(node->name()).empty()) {
// https://www.tensorflow.org/api_docs/python/tf/squeeze mentions it's
// error to squeeze a dimension that is not 1, so we only need to check
// whether the input has > 1 size for each dimension.
const auto& shape =
- properties.GetInputProperties(node->name())[0].shape();
+ properties->GetInputProperties(node->name())[0].shape();
// The node is replaceable iff
// unknown_rank == false && (dim_size == 0 || all dims have size > 1)
bool replaceable = !shape.unknown_rank();
@@ -1642,6 +1642,38 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
}
}
+ if (IsPack(*node) && NumNonControlInputs(*node) == 1 &&
+ !OptimizedNodeExists(*node, "_const_axis")) {
+ // Create constant axis node.
+ Tensor axis_t(DT_INT32, TensorShape({}));
+ NodeDef* axis_node = output->add_node();
+ axis_node->set_name(OptimizedNodeName(*node, "_const_axis"));
+ const int axis = node->attr().at("axis").i();
+ if (!SetTensorValue(DT_INT32, axis, &axis_t).ok() ||
+ !CreateNodeDef(axis_node->name(), TensorValue(&axis_t), axis_node)
+ .ok()) {
+ continue;
+ }
+ VLOG(1) << "*** Rewriting trivial Pack node: " << node->DebugString();
+ // Add a control dependency to make sure axis_node is in the right frame.
+ const string ctrl_dep = ConstantFolding::AddControlDependency(
+ node->input(0), graph_, node_map_.get());
+ axis_node->add_input(ctrl_dep);
+ axis_node->set_device(node->device());
+ node->set_op("ExpandDims");
+ if (node->attr().count("axis") != 0) {
+ node->mutable_attr()->erase("axis");
+ }
+ if (node->attr().count("N") != 0) {
+ node->mutable_attr()->erase("N");
+ }
+ (*node->mutable_attr())["Tdim"].set_type(DT_INT32);
+ node->add_input(axis_node->name());
+ if (node->input_size() > 2) {
+ node->mutable_input()->SwapElements(1, node->input_size() - 1);
+ }
+ }
+
// Switch(x, x) will always feed false to its false branch and true to
// its true branch. By rewriting the graph a bit, we can propagate these
// constants down the two output branches, and just use control dependencies
@@ -1759,7 +1791,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
graph_modified_ = true;
continue;
}
- if (use_shape_info && IsSimplifiableReshape(*node, properties)) {
+ if (use_shape_info && IsSimplifiableReshape(*node, *properties)) {
DataType output_type = node->attr().at("T").type();
node->set_op("Identity");
node->clear_attr();
@@ -1777,8 +1809,8 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
// Simplify arithmetic operations with ones or zeros.
if (use_shape_info &&
(is_mul || is_matmul || is_add || is_sub || is_any_div) &&
- properties.HasInputProperties(node->name()) &&
- properties.HasOutputProperties(node->name())) {
+ properties->HasInputProperties(node->name()) &&
+ properties->HasOutputProperties(node->name())) {
const NodeDef* x = node_map_->GetNode(node->input(0));
const NodeDef* y = node_map_->GetNode(node->input(1));
if (x == nullptr || y == nullptr) {
@@ -1786,12 +1818,12 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
node->DebugString());
}
const TensorShapeProto& output_shape =
- properties.GetOutputProperties(node->name())[0].shape();
+ properties->GetOutputProperties(node->name())[0].shape();
// Simplify element-wise multiplication by ones or addition/subtraction
// of zeros.
const TensorShapeProto& y_shape =
- properties.GetInputProperties(node->name())[1].shape();
+ properties->GetInputProperties(node->name())[1].shape();
const bool x_is_zero = IsZeros(*x);
const bool x_is_one = IsOnes(*x);
const bool y_matches_output_shape = ShapesEqual(output_shape, y_shape);
@@ -1818,7 +1850,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
}
const TensorShapeProto& x_shape =
- properties.GetInputProperties(node->name())[0].shape();
+ properties->GetInputProperties(node->name())[0].shape();
const bool y_is_zero = IsZeros(*y);
const bool y_is_one = IsOnes(*y);
const bool x_matches_output_shape = ShapesEqual(output_shape, x_shape);
@@ -2139,7 +2171,7 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
}
TF_RETURN_IF_ERROR(FoldGraph(output));
node_map_.reset(new NodeMap(output));
- TF_RETURN_IF_ERROR(SimplifyGraph(output, properties, can_use_shape_info));
+ TF_RETURN_IF_ERROR(SimplifyGraph(output, &properties, can_use_shape_info));
return Status::OK();
}
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h
index 2fd59c7f9c..13ecfcd281 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.h
+++ b/tensorflow/core/grappler/optimizers/constant_folding.h
@@ -92,7 +92,7 @@ class ConstantFolding : public GraphOptimizer {
bool IsSimplifiableReduction(const NodeDef& node) const;
bool IsSimplifiableReshape(const NodeDef& node,
const GraphProperties& properties) const;
- Status SimplifyGraph(GraphDef* output, const GraphProperties& properties,
+ Status SimplifyGraph(GraphDef* output, GraphProperties* properties,
bool use_shape_info);
Status RunOptimizationPass(Cluster* cluster, const GrapplerItem& item,
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index f421a59989..724fb84f3e 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -1930,6 +1930,48 @@ TEST_F(ConstantFoldingTest, IdenticalN) {
EXPECT_EQ("^id_n", output.node(7).input(2));
}
+TEST_F(ConstantFoldingTest, TrivialPack) {
+ tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+ Output x =
+ ops::RandomNormal(scope.WithOpName("x"), {2, 2}, DataType::DT_FLOAT);
+ Output y = ops::Const(scope.WithOpName("y"), {2.0f}, {});
+ auto stack =
+ ops::Stack(scope.WithOpName("stack").WithControlDependencies({y}), {x},
+ ops::Stack::Axis(1));
+
+ GrapplerItem item;
+ TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+ item.fetch.push_back("stack");
+
+ ConstantFolding fold(nullptr /* cpu_device */);
+ GraphDef output;
+ Status status = fold.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+ LOG(INFO) << output.DebugString();
+ EXPECT_EQ(5, output.node_size());
+ for (const auto& node : output.node()) {
+ if (node.name() == "stack") {
+ EXPECT_EQ("stack", node.name());
+ EXPECT_EQ("ExpandDims", node.op());
+ EXPECT_EQ(3, node.input_size());
+ EXPECT_EQ("x", node.input(0));
+ EXPECT_EQ("ConstantFolding/stack_const_axis", node.input(1));
+ EXPECT_EQ("^y", node.input(2));
+ } else if (node.name() == "ConstantFolding/stack_const_axis") {
+ EXPECT_EQ("Const", node.op());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("^x", node.input(0));
+ }
+ }
+
+ std::vector<string> fetch = {"stack"};
+ auto tensors_expected = EvaluateNodes(item.graph, fetch);
+ auto tensors = EvaluateNodes(output, fetch);
+ EXPECT_EQ(1, tensors_expected.size());
+ EXPECT_EQ(1, tensors.size());
+ EXPECT_EQ(tensors_expected[0].shape(), tensors[0].shape());
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow