author     Yao Zhang <yaozhang@google.com>                   2017-11-29 13:46:24 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>   2017-11-29 13:50:09 -0800
commit     48347ee4105d78d8f36ba8645953b75cb5280c4c (patch)
tree       e71b8d3edbaee50fd21c7a00853b185c2e4c11f4
parent     19f62f62e5dab41b62b60ac66e7d07c09d55e17a (diff)
Simplify const node creation.
PiperOrigin-RevId: 177357416
-rw-r--r--  tensorflow/core/grappler/optimizers/BUILD                        5
-rw-r--r--  tensorflow/core/grappler/optimizers/layout_optimizer.cc        218
-rw-r--r--  tensorflow/core/grappler/optimizers/layout_optimizer_test.cc    75
-rw-r--r--  tensorflow/python/grappler/layout_optimizer_test.py              2
4 files changed, 169 insertions, 131 deletions
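
Sketched below, as a reading aid and not part of the original commit, is the pattern the diff moves toward: instead of building dedicated Const nodes from scratch (the removed AddNodeConcatConst and AddNodeReductionConst helpers) and wiring a control-dependency input by hand to keep them in the right frame, the ConcatProcessor and SumProcessor now copy the existing const input node and rewrite its value. The copy already carries the source node's dtype, device, and any control inputs, so it stays in the same frame as the node being processed. The helper name below is a hypothetical stand-in; the proto calls follow the standard TensorFlow GraphDef/NodeDef API.

    #include "tensorflow/core/framework/graph.pb.h"
    #include "tensorflow/core/framework/node_def.pb.h"

    // Hypothetical helper illustrating the "copy the existing const input and
    // rewrite its value" pattern used by the updated processors. Copying the
    // source node preserves its attrs ("dtype", "value"), device, and any
    // control-dependency inputs, so the copy stays in the same frame.
    tensorflow::NodeDef* CopyConstAndSetAxis(tensorflow::GraphDef* graph,
                                             const tensorflow::NodeDef& axis_node,
                                             const std::string& new_name,
                                             int new_axis) {
      tensorflow::NodeDef* copy = graph->add_node();
      *copy = axis_node;  // full copy: attrs, device, control inputs
      copy->set_name(new_name);
      copy->mutable_attr()->at("value").mutable_tensor()->set_int_val(0, new_axis);
      return copy;
    }

This is why the per-frame bookkeeping (is_in_frame_ branches, AsControlDependency calls) and the shared graph-level const nodes (kConcatConst, kReductionConst) could be deleted in the hunks that follow.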
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 5d9eb8e0b1..24e6f8847a 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -332,6 +332,11 @@ tf_cc_test(
deps = [
":layout_optimizer",
"//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:cc_ops_internal",
+ "//tensorflow/core:all_kernels",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
index 1b8046b787..ef4b015295 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
@@ -69,6 +69,8 @@ std::set<string> GetOpsFormatSupported() {
return ops_format_supported;
}
+// TODO(yaozhang): enable SumProcessor with auto-tuning. Currently disabled
+// because of the worse performance in some cases.
std::set<string> GetOpsFormatAgnostic() {
std::set<string> ops_format_agnostic = {"Add",
"AddN",
@@ -88,7 +90,7 @@ std::set<string> GetOpsFormatAgnostic() {
"Split",
"SquaredDifference",
"Squeeze",
- "Sub"};
+ /*"Sum",*/ "Sub"};
return ops_format_agnostic;
}
@@ -186,33 +188,6 @@ class GraphProcessor {
return node;
}
- NodeDef* AddNodeReductionConst(const string& name, const string& device) {
- NodeDef* node = graph_->add_node();
- node_map_->AddNode(name, node);
- node->set_name(name);
- node->set_op("Const");
- AttrValue attr_data_type;
- attr_data_type.set_type(DT_INT32);
- node->mutable_attr()->insert({"dtype", attr_data_type});
-
- AttrValue attr_tensor;
- Tensor tensor(DT_INT32, TensorShape({3}));
- std::vector<int> axis = {0, 2, 3};
- for (int i = 0; static_cast<size_t>(i) < axis.size(); i++) {
- tensor.flat<int>()(i) = axis[i];
- }
- tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
- node->mutable_attr()->insert({"value", attr_tensor});
- string device_name;
- if (device.empty()) {
- device_name = virtual_placer_.get_canonical_device_name(*node);
- } else {
- device_name = device;
- }
- node->set_device(device_name);
- return node;
- }
-
const VirtualPlacer& virtual_placer_;
const std::unordered_set<string>& nodes_to_preserve_;
GraphDef* graph_;
@@ -370,10 +345,20 @@ class NodeProcessor : public GraphProcessor {
LOG(ERROR) << "Failed to parse TensorProto.";
}
if (tensor.dims() == 1) {
- int c = tensor.flat<int>()(3);
- tensor.flat<int>()(3) = tensor.flat<int>()(2);
- tensor.flat<int>()(2) = tensor.flat<int>()(1);
- tensor.flat<int>()(1) = c;
+ if (tensor.flat<int>().size() == 4) {
+ int c = tensor.flat<int>()(3);
+ tensor.flat<int>()(3) = tensor.flat<int>()(2);
+ tensor.flat<int>()(2) = tensor.flat<int>()(1);
+ tensor.flat<int>()(1) = c;
+ } else if (tensor.flat<int>().size() == 3) {
+ tensor.flat<int>()(0) = 0;
+ tensor.flat<int>()(1) = 2;
+ tensor.flat<int>()(2) = 3;
+ } else {
+ return Status(error::INVALID_ARGUMENT,
+ strings::StrCat("Unsupported tensor size: ",
+ tensor.flat<int>().size()));
+ }
} else if (tensor.dims() == 2) {
for (int i = 0; i < 2; i++) {
int c = tensor.matrix<int>()(3, i);
@@ -394,7 +379,9 @@ class NodeProcessor : public GraphProcessor {
Status UpdateAttrValueOfInput(int input_index) {
auto input_node = node_map_->GetNode(node_->input(input_index));
// We created a copy of the node, so that we don't modify the original node,
- // which might be used elsewhere.
+ // which might be used elsewhere. Note that this copy also copies the
+ // control dependency input in the case this node is inside a loop,
+ // to ensure added_node is in the same frame with node_.
NodeDef* added_node = graph_->add_node();
*added_node = *input_node;
string base_name = strings::StrCat(node_->name(), "-", input_node->name());
@@ -411,6 +398,14 @@ class NodeProcessor : public GraphProcessor {
return input_pos;
}
+ virtual std::set<int> GetOutputPos() const {
+ // For most nodes, no need to process control nodes or nodes that use an
+ // output other than the first output: only the first output is of
+ // 4D NCHW/NHWC format and thus relevant here.
+ std::set<int> output_pos = {0};
+ return output_pos;
+ }
+
NodeDef* AddNodeTranspose(const string& node_name, const string& input_name,
const string& const_name, DataType data_type,
const TensorShapeProto& input_shape,
@@ -476,37 +471,28 @@ class NodeProcessor : public GraphProcessor {
auto outputs = node_map_->GetOutputs(node_->name());
string const_name = GetOrAddNodePermNCHWToNHWC();
for (const auto& output : outputs) {
- string base_name = strings::StrCat(node_->name(), "-", output->name());
- string node_name =
- AddPrefixToNodeName(base_name, kTransposeNCHWToNHWC, "-");
- // TODO(yaozhang): handle the rare case where node A is connected to more
- // than one input of node B.
- auto it = std::find_if(output->mutable_input()->begin(),
- output->mutable_input()->end(),
- [this](const string& input) {
- string node_name = NodeName(input);
- return node_name.compare(node_->name()) == 0;
- });
- if (it == output->mutable_input()->end()) {
- return Status(error::INVALID_ARGUMENT,
- strings::StrCat("Expect ", node_->name(),
- " to be an input of ", output->name()));
- }
- int output_pos = NodePosition(*it);
- // No need to process control nodes or nodes that use an output
- // other than the first output: only the first output is of 4D NCHW/NHWC
- // format and thus relevant here.
- if (output_pos != 0) {
- continue;
+ for (int i = 0; i < output->input_size(); i++) {
+ auto& input = *output->mutable_input(i);
+ int input_port;
+ string input_name = ParseNodeName(input, &input_port);
+ auto output_pos = GetOutputPos();
+ if (input_name == node_->name() &&
+ output_pos.find(input_port) != output_pos.end()) {
+ string base_name =
+ strings::StrCat(node_->name(), "-", output->name(), "-", i);
+ string node_name =
+ AddPrefixToNodeName(base_name, kTransposeNCHWToNHWC, "-");
+ TF_RETURN_IF_ERROR(HasAttribute(*node_, "T"));
+ TF_RETURN_IF_ERROR(HasAttribute(*node_, "_output_shapes"));
+ AddNodeTranspose(
+ node_name, input, const_name, node_->attr().at("T").type(),
+ node_->attr().at("_output_shapes").list().shape(0), false);
+ input = node_name;
+ node_map_->AddOutput(node_->name(), node_name);
+ node_map_->AddOutput(node_name, output->name());
+ }
}
- TF_RETURN_IF_ERROR(HasAttribute(*node_, "T"));
- TF_RETURN_IF_ERROR(HasAttribute(*node_, "_output_shapes"));
- AddNodeTranspose(
- node_name, node_->name(), const_name, node_->attr().at("T").type(),
- node_->attr().at("_output_shapes").list().shape(0), false);
- *it = node_name;
- node_map_->UpdateOutput(node_->name(), output->name(), node_name);
- node_map_->AddOutput(node_name, output->name());
+ node_map_->RemoveOutput(node_->name(), output->name());
}
return Status::OK();
}
@@ -948,7 +934,7 @@ class ConcatProcessor : public AgnosticNodeProcessor {
}
Status CustomizedProcessing() override {
- string concat_const_name = GetOrAddNodeConcatConst();
+ string concat_const_name = AddNodeConcatConst()->name();
node_map_->AddOutput(concat_const_name, node_->name());
*node_->mutable_input(axis_node_pos_) = concat_const_name;
return Status::OK();
@@ -956,8 +942,14 @@ class ConcatProcessor : public AgnosticNodeProcessor {
bool IsAlongDimC() const {
auto axis_node = node_map_->GetNode(node_->input(axis_node_pos_));
+ if (!IsConstant(*axis_node)) {
+ return false;
+ }
if (axis_node->attr().find("value") != axis_node->attr().end()) {
- return axis_node->attr().at("value").tensor().int_val(0) == 3;
+ auto tensor = axis_node->attr().at({"value"}).tensor();
+ if (tensor.tensor_shape().dim_size() == 0 && tensor.int_val_size() == 1) {
+ return tensor.int_val(0) == 3;
+ }
}
return false;
}
@@ -965,28 +957,18 @@ class ConcatProcessor : public AgnosticNodeProcessor {
int axis_node_pos_;
private:
- NodeDef* AddNodeConcatConst(const string& suffix, const string& depended_node,
- const string& device) {
- auto const_node = AddNodeConstScalar(
- strings::StrCat(kConcatConst, "-", suffix), device, DT_INT32, 1);
- // This is to ensure the concat node and the const node are
- // in the same frame.
- *const_node->add_input() = AsControlDependency(depended_node);
- return const_node;
- }
-
- string GetOrAddNodeConcatConst() {
- string const_name;
- if (is_in_frame_) {
- int value_node_pos = (axis_node_pos_ == 0) ? 1 : 0;
- auto const_node = AddNodeConcatConst(
- node_->name(), NodeName(node_->input(value_node_pos)),
- node_->device());
- const_name = const_node->name();
- } else {
- const_name = kConcatConst;
- }
- return const_name;
+ NodeDef* AddNodeConcatConst() {
+ auto axis_node = node_map_->GetNode(node_->input(axis_node_pos_));
+ // We created a copy of the node, so that we don't modify the original node,
+ // which might be used elsewhere. Note that this copy also copies the
+ // control dependency input in the case this node is inside a loop,
+ // to ensure added_node is in the same frame with node_.
+ auto added_node = graph_->add_node();
+ *added_node = *axis_node;
+ added_node->set_name(strings::StrCat(kConcatConst, "-", node_->name()));
+ added_node->mutable_attr()->at({"value"}).mutable_tensor()->set_int_val(0,
+ 1);
+ return added_node;
}
};
@@ -1036,6 +1018,16 @@ class SplitProcessor : public AgnosticNodeProcessor {
return input_pos;
}
+ std::set<int> GetOutputPos() const override {
+ std::set<int> output_pos{0};
+ if (HasAttribute(*node_, "num_split").ok()) {
+ for (int i = 1; i < node_->attr().at("num_split").i(); i++) {
+ output_pos.insert(i);
+ }
+ }
+ return output_pos;
+ }
+
Status CustomizedProcessing() override {
string split_const_name = AddNodeSplitConst()->name();
node_map_->AddOutput(split_const_name, node_->name());
@@ -1073,7 +1065,7 @@ class SplitProcessor : public AgnosticNodeProcessor {
// We created a copy of the node, so that we don't modify the original node,
// which might be used elsewhere. Note that this copy also copies the
// control dependency input in the case this node is inside a loop,
- // to ensure added_node is in the same frame with the Split node.
+ // to ensure added_node is in the same frame with node_.
NodeDef* added_node = graph_->add_node();
*added_node = *dim_node;
added_node->set_name(strings::StrCat(kSplitConst, "-", node_->name()));
@@ -1329,20 +1321,21 @@ class SumProcessor : public AgnosticNodeProcessor {
Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
- Status CustomizedProcessing() override {
- node_map_->AddOutput(kReductionConst, node_->name());
- *node_->mutable_input(1) = GetOrAddNodeReductionConst();
- return Status::OK();
- }
+ Status CustomizedProcessing() override { return UpdateAttrValueOfInput(1); }
private:
bool IsAlongDimNHW() const {
- NodeDef* node = node_map_->GetNode(node_->input(1));
+ NodeDef* reduction_indices = node_map_->GetNode(node_->input(1));
+ if (!IsConstant(*reduction_indices)) {
+ return false;
+ }
Tensor tensor;
- if (node->attr().find({"value"}) == node->attr().end()) {
+ if (reduction_indices->attr().find({"value"}) ==
+ reduction_indices->attr().end()) {
return false;
}
- auto success = tensor.FromProto(node->attr().at({"value"}).tensor());
+ auto success =
+ tensor.FromProto(reduction_indices->attr().at({"value"}).tensor());
if (!success) {
LOG(ERROR) << "Failed to parse TensorProto.";
return false;
@@ -1356,29 +1349,6 @@ class SumProcessor : public AgnosticNodeProcessor {
}
return false;
}
-
- NodeDef* AddNodeReductionConst(const string& suffix,
- const string& depended_node,
- const string& device) {
- auto const_node = GraphProcessor::AddNodeReductionConst(
- strings::StrCat(kReductionConst, "-", suffix), device);
- // This is to ensure the Sum node and the const node are in the
- // same frame.
- *const_node->add_input() = AsControlDependency(depended_node);
- return const_node;
- }
-
- string GetOrAddNodeReductionConst() {
- string const_name;
- if (is_in_frame_) {
- auto const_node = AddNodeReductionConst(
- node_->name(), NodeName(node_->input(0)), node_->device());
- const_name = const_node->name();
- } else {
- const_name = kReductionConst;
- }
- return const_name;
- }
};
class DataLayoutOptimizer : GraphProcessor {
@@ -1409,18 +1379,10 @@ class DataLayoutOptimizer : GraphProcessor {
return AddNodePermConst(kPermNCHWToNHWC, "", {0, 2, 3, 1});
}
- NodeDef* AddNodeConcatConst() {
- return AddNodeConstScalar(kConcatConst, "", DT_INT32, 1);
- }
-
NodeDef* AddNodeGatherAxisConst() {
return AddNodeConstScalar(kGatherAxisConst, "", DT_INT32, 0);
}
- NodeDef* AddNodeReductionConst() {
- return GraphProcessor::AddNodeReductionConst(kReductionConst, "");
- }
-
// Expand all nodes which is in NHWC, but supports NCHW or is layout agnostic.
Status Expand() {
int node_size_original = graph_->node_size();
@@ -1474,9 +1436,7 @@ class DataLayoutOptimizer : GraphProcessor {
if (graph_->node_size() > node_size_original) {
NodeDef* n = AddNodePermNHWCToNCHW();
n = AddNodePermNCHWToNHWC();
- n = AddNodeConcatConst();
n = AddNodeGatherAxisConst();
- n = AddNodeReductionConst();
std::set<string> ops_format_agnostic = GetOpsFormatAgnostic();
for (int i = 0; i < graph_->node_size(); i++) {
if (ops_format_agnostic.find(graph_->node(i).op()) !=
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc
index 8c89f6744b..e8f7b8ac3c 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc
@@ -495,7 +495,80 @@ TEST_F(LayoutOptimizerTest, SplitNonConstDim) {
auto split_node = node_map.GetNode("split");
EXPECT_EQ(split_node->input(0), "i1");
EXPECT_EQ(split_node->input(1),
- "LayoutOptimizerTransposeNCHWToNHWC-Conv2D-split");
+ "LayoutOptimizerTransposeNCHWToNHWC-Conv2D-split-1");
+}
+
+TEST_F(LayoutOptimizerTest, SplitSamePortToMultipleInputsOfSameNode) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto conv = SimpleConv2D(&s, 3, 2, "VALID");
+ auto axis = ops::Const(s.WithOpName("axis"), 3);
+ auto split = ops::Split(s.WithOpName("split"), axis, conv, 2);
+ auto concat =
+ ops::Concat(s.WithOpName("concat"), {split[1], split[1], split[1]}, axis);
+ auto o = ops::Identity(s.WithOpName("o"), concat);
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ LayoutOptimizer optimizer;
+ GraphDef output;
+ Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
+ NodeMap node_map(&output);
+ auto concat_node = node_map.GetNode("concat");
+ EXPECT_EQ(concat_node->input(0), "split:1");
+ EXPECT_EQ(concat_node->input(1), "split:1");
+ EXPECT_EQ(concat_node->input(2), "split:1");
+ EXPECT_EQ(concat_node->input(3), "LayoutOptimizerConcatConst-concat");
+ auto concat_dim = node_map.GetNode("LayoutOptimizerConcatConst-concat");
+ EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 1);
+}
+
+TEST_F(LayoutOptimizerTest, Concat) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto conv = SimpleConv2D(&s, 3, 2, "VALID");
+ auto axis = ops::Const(s.WithOpName("axis"), 3);
+ auto split = ops::Split(s.WithOpName("split"), axis, conv, 2);
+ auto concat = ops::Concat(s.WithOpName("concat"), {split[0], split[1]}, axis);
+ auto o = ops::Identity(s.WithOpName("o"), concat);
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ LayoutOptimizer optimizer;
+ GraphDef output;
+ Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
+ NodeMap node_map(&output);
+ auto concat_node = node_map.GetNode("concat");
+ EXPECT_EQ(concat_node->input(0), "split");
+ EXPECT_EQ(concat_node->input(1), "split:1");
+ EXPECT_EQ(concat_node->input(2), "LayoutOptimizerConcatConst-concat");
+ auto concat_dim = node_map.GetNode("LayoutOptimizerConcatConst-concat");
+ EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 1);
+}
+
+TEST_F(LayoutOptimizerTest, Sum) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto conv = SimpleConv2D(&s, 3, 2, "VALID");
+ auto reduction_indices =
+ ops::Const(s.WithOpName("reduction_indices"), {0, 1, 2}, {3});
+ auto sum = ops::Sum(s.WithOpName("sum"), conv, reduction_indices);
+ auto o = ops::Identity(s.WithOpName("o"), sum);
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ LayoutOptimizer optimizer;
+ GraphDef output;
+ Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
+ // TODO(yaozhang): enable SumProcessor with auto-tuning. Currently disabled
+ // because of the worse performance in some cases.
+ /*
+ NodeMap node_map(&output);
+ auto sum_node = node_map.GetNode("sum");
+ EXPECT_EQ(sum_node->input(0), "Conv2D");
+ EXPECT_EQ(sum_node->input(1), "LayoutOptimizer-sum-reduction_indices");
+ auto sum_const = node_map.GetNode("LayoutOptimizer-sum-reduction_indices");
+ Tensor tensor;
+ EXPECT_TRUE(
+ tensor.FromProto(sum_const->mutable_attr()->at({"value"}).tensor()));
+ Tensor tensor_expected(DT_INT32, {3});
+ test::FillValues<int>(&tensor_expected, {0, 2, 3});
+ test::ExpectTensorEqual<int>(tensor_expected, tensor);
+ */
}
} // namespace
diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py
index 626e0502cb..50735fb567 100644
--- a/tensorflow/python/grappler/layout_optimizer_test.py
+++ b/tensorflow/python/grappler/layout_optimizer_test.py
@@ -190,7 +190,7 @@ class LayoutOptimizerTest(test.TestCase):
self.assertEqual(expected_num_transposes, num_transposes)
self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-Conv2D-Reshape-0',
nodes)
- self.assertIn('LayoutOptimizerTransposeNCHWToNHWC-Relu_1-MaxPool_1',
+ self.assertIn('LayoutOptimizerTransposeNCHWToNHWC-Relu_1-MaxPool_1-0',
nodes)
self.assertAllClose(output_val_ref, output_val, atol=1e-3)