aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-09 10:22:16 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-09 10:26:07 -0800
commit48fc3bc388b09c67482db9751b6eab1d89ae140e (patch)
tree17fc5a2881129b33fe5aa03546b056f61171a27e
parent41a12df5de7d767a1a872348f3ba630350fcc78e (diff)
Implement partial constant folding for Concat.
PiperOrigin-RevId: 188501394
-rw-r--r--tensorflow/core/grappler/costs/graph_properties.cc7
-rw-r--r--tensorflow/core/grappler/costs/graph_properties.h2
-rw-r--r--tensorflow/core/grappler/costs/graph_properties_test.cc27
-rw-r--r--tensorflow/core/grappler/op_types.cc6
-rw-r--r--tensorflow/core/grappler/op_types.h2
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.cc143
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.h2
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding_test.cc100
8 files changed, 261 insertions, 28 deletions
diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc
index 243ca9121c..817247e379 100644
--- a/tensorflow/core/grappler/costs/graph_properties.cc
+++ b/tensorflow/core/grappler/costs/graph_properties.cc
@@ -1182,5 +1182,12 @@ GraphProperties::GetOutputProperties(const string& node_name) const {
return missing_properties_;
}
+void GraphProperties::ClearInputProperties(const string& node_name) {
+ input_properties_.erase(node_name);
+}
+void GraphProperties::ClearOutputProperties(const string& node_name) {
+ output_properties_.erase(node_name);
+}
+
} // end namespace grappler
} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/costs/graph_properties.h b/tensorflow/core/grappler/costs/graph_properties.h
index 6fc53a7f2e..5aa4962072 100644
--- a/tensorflow/core/grappler/costs/graph_properties.h
+++ b/tensorflow/core/grappler/costs/graph_properties.h
@@ -64,6 +64,8 @@ class GraphProperties {
const string& node_name) const;
const std::vector<OpInfo::TensorProperties>& GetOutputProperties(
const string& node_name) const;
+ void ClearInputProperties(const string& node_name);
+ void ClearOutputProperties(const string& node_name);
static void FillTensorPropertiesFromContext(
const shape_inference::ShapeHandle&, const DataType&,
diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc
index 5012069118..284d9d409b 100644
--- a/tensorflow/core/grappler/costs/graph_properties_test.cc
+++ b/tensorflow/core/grappler/costs/graph_properties_test.cc
@@ -113,6 +113,33 @@ TEST_F(GraphPropertiesTest, StaticProperties) {
}
}
+TEST_F(GraphPropertiesTest, ClearProperties) {
+ TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
+ cluster_->GetDeviceNames());
+ GrapplerItem item;
+ CHECK(fake_input.NextItem(&item));
+
+ GraphProperties properties(item);
+ Status s = properties.InferStatically(true);
+ TF_CHECK_OK(s);
+
+ for (const auto& node : item.graph.node()) {
+ if (node.op() == "RandomStandardNormal") {
+ EXPECT_EQ(1, properties.GetInputProperties(node.name()).size());
+ const auto props = properties.GetOutputProperties(node.name());
+ properties.ClearOutputProperties(node.name());
+ const auto cleared_props = properties.GetOutputProperties(node.name());
+ EXPECT_TRUE(cleared_props.empty());
+ } else if (node.op() == "AddN") {
+ const auto in_props = properties.GetInputProperties(node.name());
+ EXPECT_EQ(1, in_props.size());
+ properties.ClearInputProperties(node.name());
+ const auto cleared_props = properties.GetInputProperties(node.name());
+ EXPECT_TRUE(cleared_props.empty());
+ }
+ }
+}
+
TEST_F(GraphPropertiesTest, DynamicProperties) {
TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
cluster_->GetDeviceNames());
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index 8cf1402ae8..ae71094079 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -72,6 +72,10 @@ bool IsComplex(const NodeDef& node) { return node.op() == "Complex"; }
bool IsComplexAbs(const NodeDef& node) { return node.op() == "ComplexAbs"; }
+bool IsConcat(const NodeDef& node) {
+ return node.op() == "Concat" || node.op() == "ConcatV2";
+}
+
bool IsConcatOffset(const NodeDef& node) { return node.op() == "ConcatOffset"; }
bool IsConstant(const NodeDef& node) { return node.op() == "Const"; }
@@ -213,6 +217,8 @@ bool IsNextIteration(const NodeDef& node) {
return op == "NextIteration" || op == "RefNextIteration";
}
+bool IsPack(const NodeDef& node) { return node.op() == "Pack"; }
+
bool IsPad(const NodeDef& node) {
const auto& op = node.op();
return op == "Pad" || op == "PadV2";
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index a7c33ef97b..690275da7c 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -40,6 +40,7 @@ bool IsCast(const NodeDef& node);
bool IsComplex(const NodeDef& node);
bool IsComplexAbs(const NodeDef& node);
bool IsConj(const NodeDef& node);
+bool IsConcat(const NodeDef& node);
bool IsConcatOffset(const NodeDef& node);
bool IsConstant(const NodeDef& node);
bool IsConv2D(const NodeDef& node);
@@ -85,6 +86,7 @@ bool IsMul(const NodeDef& node);
bool IsMatMul(const NodeDef& node);
bool IsNextIteration(const NodeDef& node);
bool IsPad(const NodeDef& node);
+bool IsPack(const NodeDef& node);
bool IsNoOp(const NodeDef& node);
bool IsNotEqual(const NodeDef& node);
bool IsPlaceholder(const NodeDef& node);
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index 31dc1b73e1..4036ea3f16 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -1510,7 +1510,7 @@ Status ConstantFolding::ReplaceOperationWithConstant(
}
Status ConstantFolding::SimplifyGraph(GraphDef* output,
- const GraphProperties& properties,
+ GraphProperties* properties,
bool use_shape_info) {
const bool is_aggressive = opt_level_ == RewriterConfig::AGGRESSIVE;
for (int i = 0; i < output->node_size(); ++i) {
@@ -1520,7 +1520,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
if (use_shape_info &&
(IsShuffle(*node) || IsReverse(*node) || IsTranspose(*node))) {
const auto& shape =
- properties.GetInputProperties(node->name())[0].shape();
+ properties->GetInputProperties(node->name())[0].shape();
// The node is replaceable iff
// unknown_rank == false && (dim_size == 0 || all dims have size 1)
bool replaceable = !shape.unknown_rank();
@@ -1649,7 +1649,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
graph_modified_ = true;
continue;
}
- if (use_shape_info && IsSimplifiableReshape(*node, properties)) {
+ if (use_shape_info && IsSimplifiableReshape(*node, *properties)) {
DataType output_type = node->attr().at("T").type();
node->set_op("Identity");
node->clear_attr();
@@ -1667,8 +1667,8 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
// Simplify arithmetic operations with ones or zeros.
if (use_shape_info &&
(is_mul || is_matmul || is_add || is_sub || is_any_div) &&
- properties.HasInputProperties(node->name()) &&
- properties.HasOutputProperties(node->name())) {
+ properties->HasInputProperties(node->name()) &&
+ properties->HasOutputProperties(node->name())) {
const NodeDef* x = node_map_->GetNode(node->input(0));
const NodeDef* y = node_map_->GetNode(node->input(1));
if (x == nullptr || y == nullptr) {
@@ -1676,12 +1676,12 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
node->DebugString());
}
const TensorShapeProto& output_shape =
- properties.GetOutputProperties(node->name())[0].shape();
+ properties->GetOutputProperties(node->name())[0].shape();
// Simplify element-wise multiplication by ones or addition/subtraction
// of zeros.
const TensorShapeProto& y_shape =
- properties.GetInputProperties(node->name())[1].shape();
+ properties->GetInputProperties(node->name())[1].shape();
const bool x_is_zero = IsZeros(*x);
const bool x_is_one = IsOnes(*x);
const bool y_matches_output_shape = ShapesEqual(output_shape, y_shape);
@@ -1708,7 +1708,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
}
const TensorShapeProto& x_shape =
- properties.GetInputProperties(node->name())[0].shape();
+ properties->GetInputProperties(node->name())[0].shape();
const bool y_is_zero = IsZeros(*y);
const bool y_is_one = IsOnes(*y);
const bool x_matches_output_shape = ShapesEqual(output_shape, x_shape);
@@ -1921,13 +1921,11 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
// folding of ops when more than one but not all inputs are constant.
// For AddN and AccumulateNV2, we may furthermore reorder inputs, since
// addition is commutative.
- // TODO(rmlarsen): Concat/Pack/ParallelConcat which are not commutative, so
- // we have to preserve order and can only push consecutive runs of constant
- // inputs into sub-nodes.
+ const int num_non_control_inputs = NumNonControlInputs(*node);
if (IsAggregate(*node) && IsCommutative(*node) &&
- NumNonControlInputs(*node) > 2) {
+ num_non_control_inputs > 2) {
const int num_control_inputs =
- node->input_size() - NumNonControlInputs(*node);
+ node->input_size() - num_non_control_inputs;
std::vector<int> const_inputs;
std::vector<int> nonconst_inputs;
for (int i = 0; i < node->input_size(); ++i) {
@@ -1943,7 +1941,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
}
// Promote AccumulateNV2 with all constant inputs to AddN, since it is
// a fake node that cannot be constant folded by itself.
- if (const_inputs.size() == NumNonControlInputs(*node) &&
+ if (const_inputs.size() == num_non_control_inputs &&
node->op() == "AccumulateNV2") {
node->set_op("AddN");
node->mutable_attr()->erase("shape");
@@ -1953,7 +1951,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
const string new_node_name = OptimizedNodeName(
*node, strings::StrCat("_partial_split_", const_inputs.size()));
if (1 < const_inputs.size() &&
- const_inputs.size() < NumNonControlInputs(*node) &&
+ const_inputs.size() < num_non_control_inputs &&
!node_map_->NodeExists(new_node_name)) {
NodeDef* added_node = output->add_node();
*added_node = *node;
@@ -1987,8 +1985,121 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
const_inputs.size() - 1);
(*node->mutable_attr())["N"].set_i(node->input_size() -
num_control_inputs);
+ properties->ClearInputProperties(node->name());
(*added_node->mutable_attr())["N"].set_i(const_inputs.size());
graph_modified_ = true;
+ continue;
+ }
+ }
+
+ // Partial constant folding for Concat which is not commutative, so
+ // we have to preserve order and can only push consecutive runs of constant
+ // inputs into sub-nodes.
+ if (IsConcat(*node) && num_non_control_inputs > 3) {
+ bool already_optimized = false;
+ const string optimized = strings::StrCat(node->name(), "_partial_split_");
+ for (const string& input : node->input()) {
+ if (input.rfind(optimized) != string::npos) {
+ already_optimized = true;
+ break;
+ }
+ }
+ if (already_optimized) {
+ continue;
+ }
+ int axis_arg = -1;
+ int begin = 0;
+ int end = num_non_control_inputs;
+ if (node->op() == "Concat") {
+ begin = 1;
+ axis_arg = 0;
+ } else if (node->op() == "ConcatV2") {
+ end = num_non_control_inputs - 1;
+ axis_arg = num_non_control_inputs - 1;
+ } else {
+ continue;
+ }
+
+ const NodeDef* axis_arg_node =
+ node_map_->GetNode(NodeName(node->input(axis_arg)));
+ if (axis_arg_node == nullptr || !IsReallyConstant(*axis_arg_node)) {
+ // We cannot constant fold Concat unless we know the axis.
+ // Skip node.
+ continue;
+ }
+
+ // We search for consecutive runs of constant inputs in the range
+ // [begin:end[ and push then down into child nodes.
+ std::vector<std::pair<int, int>> constant_input_runs;
+ int first = begin;
+ int last = begin;
+ while (last < end) {
+ while (first < end && !IsReallyConstant(*node_map_->GetNode(
+ NodeName(node->input(first))))) {
+ ++first;
+ }
+ // Invariant: node[first] is constant || first >= end.
+ last = first + 1;
+ while (last < end && IsReallyConstant(*node_map_->GetNode(
+ NodeName(node->input(last))))) {
+ ++last;
+ }
+ // Invariant: node[last] is not constant || last >= end
+ // Discard intervals shorter than 2 elements.
+ if (first < end && (last - first) > 1) {
+ constant_input_runs.emplace_back(first, last);
+ }
+ first = last;
+ }
+
+ std::set<int> inputs_to_delete;
+ for (auto interval : constant_input_runs) {
+ // Push the constant inputs in the interval to a child node than can be
+ // constant folded.
+ const string new_node_name = OptimizedNodeName(
+ *node, strings::StrCat("_partial_split_", interval.first));
+ if (node_map_->NodeExists(new_node_name)) {
+ break;
+ }
+ NodeDef* added_node = output->add_node();
+ *added_node = *node;
+ added_node->set_name(new_node_name);
+ node_map_->AddNode(added_node->name(), added_node);
+ added_node->clear_input();
+ for (int i = interval.first; i < interval.second; ++i) {
+ added_node->add_input(node->input(i));
+ node_map_->UpdateOutput(NodeName(node->input(i)), node->name(),
+ added_node->name());
+ if (i != interval.first) {
+ inputs_to_delete.insert(i);
+ }
+ }
+ added_node->add_input(node->input(axis_arg));
+ (*added_node->mutable_attr())["N"].set_i(interval.second -
+ interval.first);
+ node_map_->AddOutput(NodeName(node->input(axis_arg)),
+ added_node->name());
+
+ // Overwrite the first constant input with the result of the added
+ // child node.
+ node->set_input(interval.first, added_node->name());
+ node_map_->AddOutput(added_node->name(), node->name());
+ }
+ if (!constant_input_runs.empty()) {
+ graph_modified_ = true;
+ if (!inputs_to_delete.empty()) {
+ // Fix up the inputs to the original node.
+ std::vector<string> tmp(node->input().begin(), node->input().end());
+ node->clear_input();
+ for (int i = 0; i < tmp.size(); ++i) {
+ if (inputs_to_delete.find(i) == inputs_to_delete.end()) {
+ node->add_input(tmp[i]);
+ }
+ }
+ (*node->mutable_attr())["N"].set_i(node->input_size() - 1);
+ properties->ClearInputProperties(node->name());
+ }
+ continue;
}
}
}
@@ -2030,7 +2141,7 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
TF_RETURN_IF_ERROR(FoldGraph(output));
node_map_.reset(new NodeMap(output));
- TF_RETURN_IF_ERROR(SimplifyGraph(output, properties, can_use_shape_info));
+ TF_RETURN_IF_ERROR(SimplifyGraph(output, &properties, can_use_shape_info));
return Status::OK();
}
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h
index 2fd59c7f9c..13ecfcd281 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.h
+++ b/tensorflow/core/grappler/optimizers/constant_folding.h
@@ -92,7 +92,7 @@ class ConstantFolding : public GraphOptimizer {
bool IsSimplifiableReduction(const NodeDef& node) const;
bool IsSimplifiableReshape(const NodeDef& node,
const GraphProperties& properties) const;
- Status SimplifyGraph(GraphDef* output, const GraphProperties& properties,
+ Status SimplifyGraph(GraphDef* output, GraphProperties* properties,
bool use_shape_info);
Status RunOptimizationPass(Cluster* cluster, const GrapplerItem& item,
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index 4b9770889f..9214695eb6 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -188,20 +188,19 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
Output sub1 = ops::Sub(s.WithOpName("sub1"), x, zeros);
Output sub2 = ops::Sub(s.WithOpName("sub2"), zeros, y);
Output concat =
- ops::Concat(s.WithOpName("concat"),
- {mul1, mul2, mul3, mul4, mul5, mul6, div1, div2, matmul1,
- matmul2, add1, add2, bias_add1, bias_add2, sub1, sub2},
- 0);
+ ops::Stack(s.WithOpName("stack"),
+ {mul1, mul2, mul3, mul4, mul5, mul6, div1, div2, matmul1,
+ matmul2, add1, add2, bias_add1, bias_add2, sub1, sub2});
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
- item.fetch = {"concat", "matmul3", "matmul4"};
+ item.fetch = {"stack", "matmul3", "matmul4"};
ConstantFolding optimizer(nullptr /* cpu_device */);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
- EXPECT_EQ(28, output.node_size());
+ EXPECT_EQ(27, output.node_size());
for (int i = 0; i < output.node_size(); ++i) {
const NodeDef& node = output.node(i);
const string& name = node.name();
@@ -1626,19 +1625,19 @@ TEST_F(ConstantFoldingTest, PartialFolding_AssociativeAndCommutative) {
Output acc4 = fun(s.WithOpName("acc4"), {c1, y, c2});
Output acc5 = fun(s.WithOpName("acc5"), {x, c1, c2});
Output acc6 = fun(s.WithOpName("acc6"), {x, c1, y, c2});
- Output concat = ops::Concat(s.WithOpName("concat"),
- {acc0, acc1, acc2, acc3, acc4, acc5, acc6}, 0);
+ Output stack = ops::Stack(s.WithOpName("stack"),
+ {acc0, acc1, acc2, acc3, acc4, acc5, acc6});
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
- item.fetch = {"concat"};
+ item.fetch = {"stack"};
ConstantFolding optimizer(nullptr /* cpu_device */);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
- EXPECT_EQ(17, output.node_size());
+ EXPECT_EQ(16, output.node_size());
for (const NodeDef& node : output.node()) {
if (node.name() == "acc0") {
EXPECT_EQ("Const", node.op());
@@ -1696,7 +1695,86 @@ TEST_F(ConstantFoldingTest, PartialFolding_AssociativeAndCommutative) {
}
}
-TEST_F(ConstantFoldingTest, IdenticalN) {
+TEST_F(ConstantFoldingTest, PartialFolding_Concat) {
+ Scope s = Scope::NewRootScope();
+ Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
+ ops::Placeholder::Shape(TensorShape({2, 2})));
+ Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
+ ops::Placeholder::Shape(TensorShape({2, 2})));
+ Output z = ops::Placeholder(s.WithOpName("z"), DT_FLOAT,
+ ops::Placeholder::Shape(TensorShape({2, 2})));
+ Output axis = ops::Const(s.WithOpName("axis"), 0, {});
+ Output c1 = ops::Const(s.WithOpName("c1"), 1.0f, {2, 2});
+ Output c2 = ops::Const(s.WithOpName("c2"), 2.0f, {2, 2});
+ Output concat0 = ops::Concat(s.WithOpName("concat0"), {c1, c2, c1}, axis);
+ Output concat1 = ops::Concat(s.WithOpName("concat1"), {x, y, z}, axis);
+ Output concat2 = ops::Concat(s.WithOpName("concat2"), {c1, x, y}, axis);
+ Output concat3 = ops::Concat(s.WithOpName("concat3"), {c1, c2, z}, axis);
+ Output concat4 = ops::Concat(s.WithOpName("concat4"), {c1, y, c2}, axis);
+ Output concat5 = ops::Concat(s.WithOpName("concat5"), {x, c1, c2}, axis);
+ Output concat6 = ops::Concat(s.WithOpName("concat6"), {x, c1, y, c2}, axis);
+ Output concat7 = ops::Concat(s.WithOpName("concat7"), {x, y, c1, c2}, axis);
+ Output concat8 = ops::Concat(s.WithOpName("concat8"), {x, c1, c2, y}, axis);
+ Output concat9 = ops::Concat(s.WithOpName("concat9"), {c1, c2, x, y}, axis);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ item.fetch = {"concat0", "concat1", "concat2", "concat3", "concat4",
+ "concat5", "concat6", "concat7", "concat8", "concat9"};
+
+ ConstantFolding optimizer(nullptr /* cpu_device */);
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+
+ EXPECT_EQ(21, output.node_size());
+ for (int i = 0; i < output.node_size(); ++i) {
+ const NodeDef& node = output.node(i);
+ if (node.name() == "concat0") {
+ EXPECT_EQ("Const", node.op());
+ } else if (node.name() == "concat3") {
+ EXPECT_EQ(3, node.input_size());
+ EXPECT_EQ("ConstantFolding/concat3_partial_split_0", node.input(0));
+ EXPECT_EQ("z", node.input(1));
+ EXPECT_EQ("axis", node.input(2));
+ } else if (node.name() == "concat5") {
+ EXPECT_EQ(3, node.input_size());
+ EXPECT_EQ("x", node.input(0));
+ EXPECT_EQ("ConstantFolding/concat5_partial_split_1", node.input(1));
+ EXPECT_EQ("axis", node.input(2));
+ } else if (node.name() == "concat7") {
+ EXPECT_EQ(4, node.input_size());
+ EXPECT_EQ("x", node.input(0));
+ EXPECT_EQ("y", node.input(1));
+ EXPECT_EQ("ConstantFolding/concat7_partial_split_2", node.input(2));
+ EXPECT_EQ("axis", node.input(3));
+ } else if (node.name() == "concat8") {
+ EXPECT_EQ(4, node.input_size());
+ EXPECT_EQ("x", node.input(0));
+ EXPECT_EQ("ConstantFolding/concat8_partial_split_1", node.input(1));
+ EXPECT_EQ("y", node.input(2));
+ EXPECT_EQ("axis", node.input(3));
+ } else if (node.name() == "concat9") {
+ EXPECT_EQ(4, node.input_size());
+ EXPECT_EQ("ConstantFolding/concat9_partial_split_0", node.input(0));
+ EXPECT_EQ("x", node.input(1));
+ EXPECT_EQ("y", node.input(2));
+ EXPECT_EQ("axis", node.input(3));
+ } else if (StringPiece(node.name()).starts_with("ConstantFolding/")) {
+ EXPECT_EQ("Const", node.op());
+ } else {
+ EXPECT_EQ(item.graph.node(i).DebugString(), node.DebugString());
+ }
+ }
+
+ auto tensors_expected = EvaluateNodes(item.graph, {"concat0"});
+ auto tensors = EvaluateNodes(output, {"concat0"});
+ EXPECT_EQ(1, tensors_expected.size());
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
+}
+
+TEST_F(ConstantFoldingTest, PartialFolding_IdentityN) {
tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
Output x = ops::Placeholder(scope.WithOpName("x"), DT_FLOAT,
ops::Placeholder::Shape(TensorShape({})));