author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-09-21 13:27:50 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-09-21 13:34:40 -0700
commit    86b4d8e65c62ff0be930e8c179f077cb83666aff
tree      b9f52d5061c4ddae964c056790581f161308d644 /tensorflow/core/grappler
parent    07a1ab7c442b6985f7ed615c99c55ab68a187ef7
Don't crash on Pack nodes with no axis argument set.
PiperOrigin-RevId: 214035048
Diffstat (limited to 'tensorflow/core/grappler')
 -rw-r--r--  tensorflow/core/grappler/optimizers/constant_folding.cc      |  3
 -rw-r--r--  tensorflow/core/grappler/optimizers/constant_folding_test.cc | 23
 2 files changed, 19 insertions(+), 7 deletions(-)
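For context: NodeDef stores its attributes in a protobuf map, and calling at() on that map fails hard when the key is absent, so a Pack node serialized without an "axis" attribute crashed the constant-folding pass. Below is a minimal sketch of the guarded lookup, falling back to 0, which is Pack's registered default; the helper name is ours, not part of the change.

#include "tensorflow/core/framework/node_def.pb.h"

// Read the Pack node's "axis" attribute defensively: probe the attr map with
// count() and fall back to 0 when the attribute was never set, instead of
// calling at(), which fails on a missing key.
int PackAxisOrDefault(const tensorflow::NodeDef& node) {
  return node.attr().count("axis") == 0 ? 0 : node.attr().at("axis").i();
}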
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index cfbd298f11..ca5d3a6dfd 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -2106,7 +2106,8 @@ bool ConstantFolding::SimplifyPack(GraphDef* optimized_graph, NodeDef* node) {
     Tensor axis_t(DT_INT32, TensorShape({}));
     NodeDef* axis_node = optimized_graph->add_node();
     axis_node->set_name(OptimizedNodeName(*node, "_const_axis"));
-    const int axis = node->attr().at("axis").i();
+    const int axis =
+        node->attr().count("axis") == 0 ? 0 : node->attr().at("axis").i();
     if (!SetTensorValue(DT_INT32, axis, &axis_t).ok() ||
         !CreateNodeDef(axis_node->name(), TensorValue(&axis_t), axis_node)
              .ok()) {
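To illustrate the kind of input that used to trigger the crash, here is a hypothetical, hand-built Pack node with no "axis" attribute, as an older producer or a manually assembled GraphDef might emit; the function name and attribute values below are illustrative, not taken from the change.

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/types.pb.h"

// Build a Pack NodeDef that deliberately omits the "axis" attribute. Before
// this fix, ConstantFolding::SimplifyPack() crashed reading attr().at("axis")
// on such a node; with the fallback it is folded into ExpandDims along axis 0.
tensorflow::NodeDef MakePackWithoutAxis() {
  tensorflow::NodeDef pack;
  pack.set_name("stack_no_axis");  // same name the new test uses
  pack.set_op("Pack");
  pack.add_input("x");
  (*pack.mutable_attr())["N"].set_i(1);
  (*pack.mutable_attr())["T"].set_type(tensorflow::DT_FLOAT);
  // No "axis" attribute is set here.
  return pack;
}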
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index 2a19b3f95a..b09360a2c2 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -3015,37 +3015,48 @@ TEST_F(ConstantFoldingTest, TrivialPack) {
   auto stack =
       ops::Stack(scope.WithOpName("stack").WithControlDependencies({y}), {x},
                  ops::Stack::Axis(1));
+  auto stack_no_axis = ops::Stack(scope.WithOpName("stack_no_axis"), {x});
   GrapplerItem item;
   TF_CHECK_OK(scope.ToGraphDef(&item.graph));
-  item.fetch.push_back("stack");
+  item.fetch = {"stack", "stack_no_axis"};
   ConstantFolding optimizer(nullptr /* cpu_device */);
   GraphDef output;
   Status status = optimizer.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);
-  EXPECT_EQ(5, output.node_size());
+  EXPECT_EQ(7, output.node_size());
+  int found = 0;
   for (const auto& node : output.node()) {
     if (node.name() == "stack") {
-      EXPECT_EQ("stack", node.name());
       EXPECT_EQ("ExpandDims", node.op());
       EXPECT_EQ(3, node.input_size());
       EXPECT_EQ("x", node.input(0));
       EXPECT_EQ("ConstantFolding/stack_const_axis", node.input(1));
       EXPECT_EQ("^y", node.input(2));
+      ++found;
+    } else if (node.name() == "stack_no_axis") {
+      EXPECT_EQ("ExpandDims", node.op());
+      EXPECT_EQ(2, node.input_size());
+      EXPECT_EQ("x", node.input(0));
+      EXPECT_EQ("ConstantFolding/stack_no_axis_const_axis", node.input(1));
+      ++found;
     } else if (node.name() == "ConstantFolding/stack_const_axis") {
       EXPECT_EQ("Const", node.op());
       EXPECT_EQ(1, node.input_size());
       EXPECT_EQ("^x", node.input(0));
+      ++found;
     }
   }
+  EXPECT_EQ(found, 3);
-  std::vector<string> fetch = {"stack"};
+  std::vector<string> fetch = {"stack", "stack_no_axis"};
   auto tensors_expected = EvaluateNodes(item.graph, fetch);
   auto tensors = EvaluateNodes(output, fetch);
-  EXPECT_EQ(1, tensors_expected.size());
-  EXPECT_EQ(1, tensors.size());
+  EXPECT_EQ(2, tensors_expected.size());
+  EXPECT_EQ(2, tensors.size());
   EXPECT_EQ(tensors_expected[0].shape(), tensors[0].shape());
+  EXPECT_EQ(tensors_expected[1].shape(), tensors[1].shape());
 }
// The test does not evalute the optimized and original graphs to check if their