diff options
author | 2018-09-21 13:27:50 -0700 | |
---|---|---|
committer | 2018-09-21 13:34:40 -0700 | |
commit | 86b4d8e65c62ff0be930e8c179f077cb83666aff (patch) | |
tree | b9f52d5061c4ddae964c056790581f161308d644 /tensorflow/core/grappler | |
parent | 07a1ab7c442b6985f7ed615c99c55ab68a187ef7 (diff) |
Don't crash on Pack nodes with no axis argument set.
PiperOrigin-RevId: 214035048
Diffstat (limited to 'tensorflow/core/grappler')
-rw-r--r-- | tensorflow/core/grappler/optimizers/constant_folding.cc | 3 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/constant_folding_test.cc | 23 |
2 files changed, 19 insertions, 7 deletions
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index cfbd298f11..ca5d3a6dfd 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -2106,7 +2106,8 @@ bool ConstantFolding::SimplifyPack(GraphDef* optimized_graph, NodeDef* node) { Tensor axis_t(DT_INT32, TensorShape({})); NodeDef* axis_node = optimized_graph->add_node(); axis_node->set_name(OptimizedNodeName(*node, "_const_axis")); - const int axis = node->attr().at("axis").i(); + const int axis = + node->attr().count("axis") == 0 ? 0 : node->attr().at("axis").i(); if (!SetTensorValue(DT_INT32, axis, &axis_t).ok() || !CreateNodeDef(axis_node->name(), TensorValue(&axis_t), axis_node) .ok()) { diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc index 2a19b3f95a..b09360a2c2 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc @@ -3015,37 +3015,48 @@ TEST_F(ConstantFoldingTest, TrivialPack) { auto stack = ops::Stack(scope.WithOpName("stack").WithControlDependencies({y}), {x}, ops::Stack::Axis(1)); + auto stack_no_axis = ops::Stack(scope.WithOpName("stack_no_axis"), {x}); GrapplerItem item; TF_CHECK_OK(scope.ToGraphDef(&item.graph)); - item.fetch.push_back("stack"); + item.fetch = {"stack", "stack_no_axis"}; ConstantFolding optimizer(nullptr /* cpu_device */); GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); - EXPECT_EQ(5, output.node_size()); + EXPECT_EQ(7, output.node_size()); + int found = 0; for (const auto& node : output.node()) { if (node.name() == "stack") { - EXPECT_EQ("stack", node.name()); EXPECT_EQ("ExpandDims", node.op()); EXPECT_EQ(3, node.input_size()); EXPECT_EQ("x", node.input(0)); EXPECT_EQ("ConstantFolding/stack_const_axis", node.input(1)); EXPECT_EQ("^y", node.input(2)); + ++found; + } else if (node.name() == "stack_no_axis") { + EXPECT_EQ("ExpandDims", node.op()); + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("x", node.input(0)); + EXPECT_EQ("ConstantFolding/stack_no_axis_const_axis", node.input(1)); + ++found; } else if (node.name() == "ConstantFolding/stack_const_axis") { EXPECT_EQ("Const", node.op()); EXPECT_EQ(1, node.input_size()); EXPECT_EQ("^x", node.input(0)); + ++found; } } + EXPECT_EQ(found, 3); - std::vector<string> fetch = {"stack"}; + std::vector<string> fetch = {"stack", "stack_no_axis"}; auto tensors_expected = EvaluateNodes(item.graph, fetch); auto tensors = EvaluateNodes(output, fetch); - EXPECT_EQ(1, tensors_expected.size()); - EXPECT_EQ(1, tensors.size()); + EXPECT_EQ(2, tensors_expected.size()); + EXPECT_EQ(2, tensors.size()); EXPECT_EQ(tensors_expected[0].shape(), tensors[0].shape()); + EXPECT_EQ(tensors_expected[1].shape(), tensors[1].shape()); } // The test does not evalute the optimized and original graphs to check if their |