diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-10-04 15:17:46 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-04 15:22:01 -0700 |
commit | bd99ed794264668ce77ed7527bc41df7aba3927b (patch) | |
tree | dde13a43e89ce175d888a02b34252aeaadf7ca54 /tensorflow/core | |
parent | 26d3617d2ab5f4874b73059be524e94b9535465b (diff) |
Fix bug in Grappler constant folding: The logic detecting full reductions was flawed. Added better test coverage.
Also added an extra test for a related symbolic shape inference operation that I initially suspected was broken.
PiperOrigin-RevId: 215812753
Diffstat (limited to 'tensorflow/core')
3 files changed, 118 insertions, 65 deletions
diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc index 362092a6cf..db10f586bc 100644 --- a/tensorflow/core/grappler/costs/graph_properties_test.cc +++ b/tensorflow/core/grappler/costs/graph_properties_test.cc @@ -1340,6 +1340,8 @@ TEST_F(GraphPropertiesTest, SymbolicShapes) { Output zero = ops::Const(s.WithOpName("zero"), 0.0f, {}); Output g = ops::Shape(s.WithOpName("g"), c); Output h = ops::Fill(s.WithOpName("h"), g, zero); + Output zero_idx = ops::Const(s.WithOpName("zero_idx"), {0}, {1}); + Output j = ops::Sum(s.WithOpName("j"), a, zero_idx); GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); @@ -1382,6 +1384,10 @@ TEST_F(GraphPropertiesTest, SymbolicShapes) { ASSERT_EQ(2, shape_f.dim_size()); EXPECT_EQ(shape_h.dim(0).size(), shape_c.dim(0).size()); EXPECT_EQ(shape_h.dim(1).size(), shape_c.dim(1).size()); + + const auto shape_j = properties.GetOutputProperties("j").at(0).shape(); + ASSERT_EQ(1, shape_j.dim_size()); + EXPECT_EQ(shape_j.dim(0).size(), shape_a.dim(1).size()); } TEST_F(GraphPropertiesTest, DoNotValidateColocationConstraints) { diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index ca5d3a6dfd..3d0d95bba7 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -616,28 +616,37 @@ Status ConstantFolding::MaterializeReductionIndices( // We can't do anything if we don't know the rank of the input. return Status::OK(); } - const int rank = input_prop.shape().dim_size(); - if (rank == 0) { + const int input_rank = input_prop.shape().dim_size(); + if (input_rank < 1) { // Unexpected graph, don't try to change it. 
return Status::OK(); } + const OpInfo::TensorProperties& reduction_indices_prop = input_props[1]; + DataType dtype = reduction_indices_prop.dtype(); + if (dtype != DT_INT32 && dtype != DT_INT64) { + return Status::OK(); + } + PartialTensorShape reduction_indices_shape(reduction_indices_prop.shape()); + const int num_reduction_indices = reduction_indices_shape.num_elements(); + const std::vector<OpInfo::TensorProperties>& output_props = properties.GetOutputProperties(node->name()); if (output_props.size() != 1) { return Status::OK(); } - const bool keep_dims = - node->attr().count("keep_dims") && node->attr().at("keep_dims").b(); const OpInfo::TensorProperties& output_prop = output_props[0]; - PartialTensorShape output_shape(output_prop.shape()); - if (output_shape.num_elements() != 1) { - bool full_reduction = false; + const int output_rank = + output_prop.shape().unknown_rank() ? -1 : output_prop.shape().dim_size(); + + bool full_reduction = output_rank == 0 || num_reduction_indices == input_rank; + if (!full_reduction) { + // A full reduction will generate a tensor of one of the shapes + // [], [1], [1, 1], [1, 1, ...]. Even if we do not know the number of + // elements in the output of the reduction, we may deduce it from reshape + // nodes following it. for (const NodeDef* fanout : node_map_->GetOutputs(node->name())) { - if (!IsReshape(*fanout) && !keep_dims) { - // Depending on how it's setup, a full reduction will generate a tensor - // of shape [], [1], [1, 1], [1, 1, ...]. If keep_dims isn't true, we - // rely on the existence of a reshape node following the reduction to - // ensure that the fanout is fed a scalar of the right shape. 
+ full_reduction = false; + if (!IsReshape(*fanout)) { return Status::OK(); } const std::vector<OpInfo::TensorProperties>& reshape_props = @@ -658,20 +667,15 @@ Status ConstantFolding::MaterializeReductionIndices( } } - const OpInfo::TensorProperties& reduction_prop = input_props[1]; - DataType dtype = reduction_prop.dtype(); - if (dtype != DT_INT32 && dtype != DT_INT64) { - return Status::OK(); - } - // We know it's a full reduction. We can generate the set of indices to - // reduce. + // We know it's a full reduction. We can generate the full set of indices to + // reduce as a constant node. string const_name = OptimizedNodeName(*node, "-reduction_indices"); if (node_map_->GetNode(const_name)) { return Status::OK(); } NodeDef* reduction_indices = graph_->add_node(); - Tensor value(dtype, TensorShape({rank})); - for (int i = 0; i < rank; ++i) { + Tensor value(dtype, TensorShape({input_rank})); + for (int i = 0; i < input_rank; ++i) { if (dtype == DT_INT32) { value.vec<int32>()(i) = i; } else { @@ -680,6 +684,7 @@ Status ConstantFolding::MaterializeReductionIndices( } TF_RETURN_IF_ERROR( CreateNodeDef(const_name, TensorValue(&value), reduction_indices)); + reduction_indices->set_device(node->device()); string ctrl_dep = AddControlDependency(node->input(1), graph_, node_map_.get()); diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc index b09360a2c2..fab01edfed 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc @@ -2591,58 +2591,100 @@ TEST_F(ConstantFoldingTest, MaterializeBroadcastGradientArgs_InfiniteLoop) { } TEST_F(ConstantFoldingTest, MaterializeReductionIndices) { - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output input = - ops::Placeholder(s.WithOpName("input"), DT_FLOAT, - ops::Placeholder::Shape(PartialTensorShape({-1, -1}))); - Output indices = 
ops::Placeholder(s.WithOpName("indices"), DT_INT32); - Output sum = ops::Sum(s.WithOpName("sum"), input, indices); - Output size = ops::Const(s.WithOpName("size"), 1, {1}); - Output reshape = ops::Reshape(s.WithOpName("reshape"), sum, size); + for (bool use_reshape : {true, false}) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output input = + ops::Placeholder(s.WithOpName("input"), DT_FLOAT, + ops::Placeholder::Shape(PartialTensorShape({-1, -1}))); + // If use_reshape is false, we need to now the number of indices to apply + // the rewrite. + Output indices = ops::Placeholder( + s.WithOpName("indices"), DT_INT32, + ops::Placeholder::Shape(PartialTensorShape({use_reshape ? -1 : 2}))); + Output sum = ops::Sum(s.WithOpName("sum"), input, indices); + if (use_reshape) { + Output size = ops::Const(s.WithOpName("size"), 1, {1}); + Output reshape = ops::Reshape(s.WithOpName("reshape"), sum, size); + } - GrapplerItem item; - TF_CHECK_OK(s.ToGraphDef(&item.graph)); - item.fetch.push_back("reshape"); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + item.fetch.push_back(use_reshape ? 
"reshape" : "sum"); - auto input_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 4})); - Tensor indices_t(DT_INT32, TensorShape({2})); - indices_t.flat<int>()(0) = 0; - indices_t.flat<int>()(1) = 1; - auto tensors_expected = EvaluateNodes( - item.graph, item.fetch, {{"input", input_t}, {"indices", indices_t}}); - EXPECT_EQ(1, tensors_expected.size()); + auto input_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 4})); + Tensor indices_t(DT_INT32, TensorShape({2})); + indices_t.flat<int>()(0) = 0; + indices_t.flat<int>()(1) = 1; + auto tensors_expected = EvaluateNodes( + item.graph, item.fetch, {{"input", input_t}, {"indices", indices_t}}); + EXPECT_EQ(1, tensors_expected.size()); - ConstantFolding optimizer(nullptr /* cpu_device */); - GraphDef output; - Status status = optimizer.Optimize(nullptr, item, &output); - TF_EXPECT_OK(status); + // Use aggressive mode to force the shape inference to propagate placeholder + // shapes. + ConstantFolding optimizer(RewriterConfig::AGGRESSIVE, + nullptr /* cpu_device */); + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); - // Run a second time to make sure the optimization is idempotent. - item.graph.Swap(&output); - status = optimizer.Optimize(nullptr, item, &output); - TF_EXPECT_OK(status); + // Run a second time to make sure the optimization is idempotent. 
+ item.graph.Swap(&output); + status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); - int found = 0; - for (const auto& node : output.node()) { - if (node.name() == "ConstantFolding/sum-reduction_indices") { - ++found; - EXPECT_EQ("Const", node.op()); - EXPECT_EQ("^indices", node.input(0)); - EXPECT_EQ(2, TensorShape(node.attr().at("value").tensor().tensor_shape()) - .num_elements()); - } else if (node.name() == "sum") { - ++found; - EXPECT_EQ("ConstantFolding/sum-reduction_indices", node.input(1)); - } else if (node.name() == "indices") { - ++found; + int found = 0; + for (const auto& node : output.node()) { + if (node.name() == "ConstantFolding/sum-reduction_indices") { + ++found; + EXPECT_EQ("Const", node.op()); + EXPECT_EQ("^indices", node.input(0)); + EXPECT_EQ(2, + TensorShape(node.attr().at("value").tensor().tensor_shape()) + .num_elements()); + } else if (node.name() == "sum") { + ++found; + EXPECT_EQ("ConstantFolding/sum-reduction_indices", node.input(1)); + } else if (node.name() == "indices") { + ++found; + } } + EXPECT_EQ(3, found); + + auto tensors = EvaluateNodes(output, item.fetch, + {{"input", input_t}, {"indices", indices_t}}); + EXPECT_EQ(1, tensors.size()); + test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5); } - EXPECT_EQ(3, found); +} - auto tensors = EvaluateNodes(output, item.fetch, - {{"input", input_t}, {"indices", indices_t}}); - EXPECT_EQ(1, tensors.size()); - test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5); +TEST_F(ConstantFoldingTest, MaterializeReductionIndices_NotFullReduction) { + for (bool input_rank_known : {true, false}) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output input = + (input_rank_known ? 
ops::Placeholder(s.WithOpName("input"), DT_FLOAT, + ops::Placeholder::Shape( + PartialTensorShape({-1, -1}))) + : ops::Placeholder(s.WithOpName("input"), DT_FLOAT)); + Output indices = + ops::Placeholder(s.WithOpName("indices"), DT_INT32, + ops::Placeholder::Shape( + PartialTensorShape({input_rank_known ? 1 : 2}))); + Output sum = ops::Sum(s.WithOpName("sum"), input, indices); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + item.fetch.push_back("sum"); + + // Use aggressive mode to force the shape inference to propagate placeholder + // shapes. + ConstantFolding optimizer(RewriterConfig::AGGRESSIVE, + nullptr /* cpu_device */); + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + CompareGraphs(item.graph, output); + } } TEST_F(ConstantFoldingTest, LargeConstant) { |