aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-10-04 15:17:46 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-04 15:22:01 -0700
commitbd99ed794264668ce77ed7527bc41df7aba3927b (patch)
treedde13a43e89ce175d888a02b34252aeaadf7ca54 /tensorflow/core
parent26d3617d2ab5f4874b73059be524e94b9535465b (diff)
Fix bug in Grappler constant folding: The logic detecting full reductions was flawed. Added better test coverage.
Also added an extra test for a related symbolic shape inference operation that I first suspected to be broken. PiperOrigin-RevId: 215812753
Diffstat (limited to 'tensorflow/core')
-rw-r--r--tensorflow/core/grappler/costs/graph_properties_test.cc6
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.cc47
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding_test.cc130
3 files changed, 118 insertions, 65 deletions
diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc
index 362092a6cf..db10f586bc 100644
--- a/tensorflow/core/grappler/costs/graph_properties_test.cc
+++ b/tensorflow/core/grappler/costs/graph_properties_test.cc
@@ -1340,6 +1340,8 @@ TEST_F(GraphPropertiesTest, SymbolicShapes) {
Output zero = ops::Const(s.WithOpName("zero"), 0.0f, {});
Output g = ops::Shape(s.WithOpName("g"), c);
Output h = ops::Fill(s.WithOpName("h"), g, zero);
+ Output zero_idx = ops::Const(s.WithOpName("zero_idx"), {0}, {1});
+ Output j = ops::Sum(s.WithOpName("j"), a, zero_idx);
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
@@ -1382,6 +1384,10 @@ TEST_F(GraphPropertiesTest, SymbolicShapes) {
ASSERT_EQ(2, shape_f.dim_size());
EXPECT_EQ(shape_h.dim(0).size(), shape_c.dim(0).size());
EXPECT_EQ(shape_h.dim(1).size(), shape_c.dim(1).size());
+
+ const auto shape_j = properties.GetOutputProperties("j").at(0).shape();
+ ASSERT_EQ(1, shape_j.dim_size());
+ EXPECT_EQ(shape_j.dim(0).size(), shape_a.dim(1).size());
}
TEST_F(GraphPropertiesTest, DoNotValidateColocationConstraints) {
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index ca5d3a6dfd..3d0d95bba7 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -616,28 +616,37 @@ Status ConstantFolding::MaterializeReductionIndices(
// We can't do anything if we don't know the rank of the input.
return Status::OK();
}
- const int rank = input_prop.shape().dim_size();
- if (rank == 0) {
+ const int input_rank = input_prop.shape().dim_size();
+ if (input_rank < 1) {
// Unexpected graph, don't try to change it.
return Status::OK();
}
+ const OpInfo::TensorProperties& reduction_indices_prop = input_props[1];
+ DataType dtype = reduction_indices_prop.dtype();
+ if (dtype != DT_INT32 && dtype != DT_INT64) {
+ return Status::OK();
+ }
+ PartialTensorShape reduction_indices_shape(reduction_indices_prop.shape());
+ const int num_reduction_indices = reduction_indices_shape.num_elements();
+
const std::vector<OpInfo::TensorProperties>& output_props =
properties.GetOutputProperties(node->name());
if (output_props.size() != 1) {
return Status::OK();
}
- const bool keep_dims =
- node->attr().count("keep_dims") && node->attr().at("keep_dims").b();
const OpInfo::TensorProperties& output_prop = output_props[0];
- PartialTensorShape output_shape(output_prop.shape());
- if (output_shape.num_elements() != 1) {
- bool full_reduction = false;
+ const int output_rank =
+ output_prop.shape().unknown_rank() ? -1 : output_prop.shape().dim_size();
+
+ bool full_reduction = output_rank == 0 || num_reduction_indices == input_rank;
+ if (!full_reduction) {
+ // A full reduction will generate a tensor of one of the shapes
+ // [], [1], [1, 1], [1, 1, ...]. Even if we do not know the number of
+ // elements in the output of the reduction, we may deduce it from reshape
+ // nodes following it.
for (const NodeDef* fanout : node_map_->GetOutputs(node->name())) {
- if (!IsReshape(*fanout) && !keep_dims) {
- // Depending on how it's setup, a full reduction will generate a tensor
- // of shape [], [1], [1, 1], [1, 1, ...]. If keep_dims isn't true, we
- // rely on the existence of a reshape node following the reduction to
- // ensure that the fanout is fed a scalar of the right shape.
+ full_reduction = false;
+ if (!IsReshape(*fanout)) {
return Status::OK();
}
const std::vector<OpInfo::TensorProperties>& reshape_props =
@@ -658,20 +667,15 @@ Status ConstantFolding::MaterializeReductionIndices(
}
}
- const OpInfo::TensorProperties& reduction_prop = input_props[1];
- DataType dtype = reduction_prop.dtype();
- if (dtype != DT_INT32 && dtype != DT_INT64) {
- return Status::OK();
- }
- // We know it's a full reduction. We can generate the set of indices to
- // reduce.
+ // We know it's a full reduction. We can generate the full set of indices to
+ // reduce as a constant node.
string const_name = OptimizedNodeName(*node, "-reduction_indices");
if (node_map_->GetNode(const_name)) {
return Status::OK();
}
NodeDef* reduction_indices = graph_->add_node();
- Tensor value(dtype, TensorShape({rank}));
- for (int i = 0; i < rank; ++i) {
+ Tensor value(dtype, TensorShape({input_rank}));
+ for (int i = 0; i < input_rank; ++i) {
if (dtype == DT_INT32) {
value.vec<int32>()(i) = i;
} else {
@@ -680,6 +684,7 @@ Status ConstantFolding::MaterializeReductionIndices(
}
TF_RETURN_IF_ERROR(
CreateNodeDef(const_name, TensorValue(&value), reduction_indices));
+
reduction_indices->set_device(node->device());
string ctrl_dep =
AddControlDependency(node->input(1), graph_, node_map_.get());
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index b09360a2c2..fab01edfed 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -2591,58 +2591,100 @@ TEST_F(ConstantFoldingTest, MaterializeBroadcastGradientArgs_InfiniteLoop) {
}
TEST_F(ConstantFoldingTest, MaterializeReductionIndices) {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- Output input =
- ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
- ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
- Output indices = ops::Placeholder(s.WithOpName("indices"), DT_INT32);
- Output sum = ops::Sum(s.WithOpName("sum"), input, indices);
- Output size = ops::Const(s.WithOpName("size"), 1, {1});
- Output reshape = ops::Reshape(s.WithOpName("reshape"), sum, size);
+ for (bool use_reshape : {true, false}) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output input =
+ ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
+ ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
+ // If use_reshape is false, we need to know the number of indices to apply
+ // the rewrite.
+ Output indices = ops::Placeholder(
+ s.WithOpName("indices"), DT_INT32,
+ ops::Placeholder::Shape(PartialTensorShape({use_reshape ? -1 : 2})));
+ Output sum = ops::Sum(s.WithOpName("sum"), input, indices);
+ if (use_reshape) {
+ Output size = ops::Const(s.WithOpName("size"), 1, {1});
+ Output reshape = ops::Reshape(s.WithOpName("reshape"), sum, size);
+ }
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- item.fetch.push_back("reshape");
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ item.fetch.push_back(use_reshape ? "reshape" : "sum");
- auto input_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 4}));
- Tensor indices_t(DT_INT32, TensorShape({2}));
- indices_t.flat<int>()(0) = 0;
- indices_t.flat<int>()(1) = 1;
- auto tensors_expected = EvaluateNodes(
- item.graph, item.fetch, {{"input", input_t}, {"indices", indices_t}});
- EXPECT_EQ(1, tensors_expected.size());
+ auto input_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 4}));
+ Tensor indices_t(DT_INT32, TensorShape({2}));
+ indices_t.flat<int>()(0) = 0;
+ indices_t.flat<int>()(1) = 1;
+ auto tensors_expected = EvaluateNodes(
+ item.graph, item.fetch, {{"input", input_t}, {"indices", indices_t}});
+ EXPECT_EQ(1, tensors_expected.size());
- ConstantFolding optimizer(nullptr /* cpu_device */);
- GraphDef output;
- Status status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
+ // Use aggressive mode to force the shape inference to propagate placeholder
+ // shapes.
+ ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
+ nullptr /* cpu_device */);
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
- // Run a second time to make sure the optimization is idempotent.
- item.graph.Swap(&output);
- status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
+ // Run a second time to make sure the optimization is idempotent.
+ item.graph.Swap(&output);
+ status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
- int found = 0;
- for (const auto& node : output.node()) {
- if (node.name() == "ConstantFolding/sum-reduction_indices") {
- ++found;
- EXPECT_EQ("Const", node.op());
- EXPECT_EQ("^indices", node.input(0));
- EXPECT_EQ(2, TensorShape(node.attr().at("value").tensor().tensor_shape())
- .num_elements());
- } else if (node.name() == "sum") {
- ++found;
- EXPECT_EQ("ConstantFolding/sum-reduction_indices", node.input(1));
- } else if (node.name() == "indices") {
- ++found;
+ int found = 0;
+ for (const auto& node : output.node()) {
+ if (node.name() == "ConstantFolding/sum-reduction_indices") {
+ ++found;
+ EXPECT_EQ("Const", node.op());
+ EXPECT_EQ("^indices", node.input(0));
+ EXPECT_EQ(2,
+ TensorShape(node.attr().at("value").tensor().tensor_shape())
+ .num_elements());
+ } else if (node.name() == "sum") {
+ ++found;
+ EXPECT_EQ("ConstantFolding/sum-reduction_indices", node.input(1));
+ } else if (node.name() == "indices") {
+ ++found;
+ }
}
+ EXPECT_EQ(3, found);
+
+ auto tensors = EvaluateNodes(output, item.fetch,
+ {{"input", input_t}, {"indices", indices_t}});
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
}
- EXPECT_EQ(3, found);
+}
- auto tensors = EvaluateNodes(output, item.fetch,
- {{"input", input_t}, {"indices", indices_t}});
- EXPECT_EQ(1, tensors.size());
- test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
+TEST_F(ConstantFoldingTest, MaterializeReductionIndices_NotFullReduction) {
+ for (bool input_rank_known : {true, false}) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output input =
+ (input_rank_known ? ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
+ ops::Placeholder::Shape(
+ PartialTensorShape({-1, -1})))
+ : ops::Placeholder(s.WithOpName("input"), DT_FLOAT));
+ Output indices =
+ ops::Placeholder(s.WithOpName("indices"), DT_INT32,
+ ops::Placeholder::Shape(
+ PartialTensorShape({input_rank_known ? 1 : 2})));
+ Output sum = ops::Sum(s.WithOpName("sum"), input, indices);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ item.fetch.push_back("sum");
+
+ // Use aggressive mode to force the shape inference to propagate placeholder
+ // shapes.
+ ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
+ nullptr /* cpu_device */);
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+
+ CompareGraphs(item.graph, output);
+ }
}
TEST_F(ConstantFoldingTest, LargeConstant) {