author    | Yao Zhang <yaozhang@google.com>                 | 2017-11-29 13:46:24 -0800
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-11-29 13:50:09 -0800
commit    | 48347ee4105d78d8f36ba8645953b75cb5280c4c (patch)
tree      | e71b8d3edbaee50fd21c7a00853b185c2e4c11f4
parent    | 19f62f62e5dab41b62b60ac66e7d07c09d55e17a (diff)
Simplify const node creation.
PiperOrigin-RevId: 177357416
4 files changed, 169 insertions, 131 deletions
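The core change: rather than synthesizing a fresh reduction/concat-axis Const node and manually wiring a control dependency to keep it in the consumer's loop frame (the removed AddNodeReductionConst and GetOrAddNodeConcatConst below), the optimizer now clones the existing const input and rewrites its value; the clone inherits the original's control inputs and therefore its frame. A minimal sketch of that pattern, using the TensorFlow proto types visible in the diff; the helper name CopyConstWithNewValue is hypothetical and not part of the commit:

  #include "tensorflow/core/framework/attr_value.pb.h"
  #include "tensorflow/core/framework/graph.pb.h"
  #include "tensorflow/core/framework/node_def.pb.h"
  #include "tensorflow/core/lib/strings/strcat.h"
  #include "tensorflow/core/platform/types.h"

  namespace tensorflow {

  // Hypothetical helper showing the clone-and-rewrite pattern used by
  // AddNodeConcatConst and UpdateAttrValueOfInput in the diff below.
  NodeDef* CopyConstWithNewValue(GraphDef* graph, const NodeDef& const_node,
                                 const string& prefix, int new_value) {
    NodeDef* added_node = graph->add_node();
    // A full copy keeps op, dtype, device, and -- crucially -- any control
    // dependency inputs, so the new Const stays in the same frame as the
    // node that consumes it.
    *added_node = const_node;
    added_node->set_name(strings::StrCat(prefix, "-", const_node.name()));
    // Rewrite only the payload, e.g. a concat axis of 3 (NHWC) becomes
    // 1 (NCHW).
    added_node->mutable_attr()->at("value").mutable_tensor()->set_int_val(
        0, new_value);
    return added_node;
  }

  }  // namespace tensorflow

With this pattern in place, the per-processor AddNode*Const/GetOrAddNode*Const pairs and their is_in_frame_ special cases become unnecessary, which accounts for most of the 131 deleted lines below.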
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 5d9eb8e0b1..24e6f8847a 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -332,6 +332,11 @@ tf_cc_test(
     deps = [
         ":layout_optimizer",
         "//tensorflow/cc:cc_ops",
+        "//tensorflow/cc:cc_ops_internal",
+        "//tensorflow/core:all_kernels",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
index 1b8046b787..ef4b015295 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
@@ -69,6 +69,8 @@ std::set<string> GetOpsFormatSupported() {
   return ops_format_supported;
 }
 
+// TODO(yaozhang): enable SumProcessor with auto-tuning. Currently disabled
+// because of the worse performance in some cases.
 std::set<string> GetOpsFormatAgnostic() {
   std::set<string> ops_format_agnostic = {"Add",
                                           "AddN",
@@ -88,7 +90,7 @@ std::set<string> GetOpsFormatAgnostic() {
                                           "Split",
                                           "SquaredDifference",
                                           "Squeeze",
-                                          "Sub"};
+                                          /*"Sum",*/ "Sub"};
   return ops_format_agnostic;
 }
 
@@ -186,33 +188,6 @@ class GraphProcessor {
     return node;
   }
 
-  NodeDef* AddNodeReductionConst(const string& name, const string& device) {
-    NodeDef* node = graph_->add_node();
-    node_map_->AddNode(name, node);
-    node->set_name(name);
-    node->set_op("Const");
-    AttrValue attr_data_type;
-    attr_data_type.set_type(DT_INT32);
-    node->mutable_attr()->insert({"dtype", attr_data_type});
-
-    AttrValue attr_tensor;
-    Tensor tensor(DT_INT32, TensorShape({3}));
-    std::vector<int> axis = {0, 2, 3};
-    for (int i = 0; static_cast<size_t>(i) < axis.size(); i++) {
-      tensor.flat<int>()(i) = axis[i];
-    }
-    tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
-    node->mutable_attr()->insert({"value", attr_tensor});
-    string device_name;
-    if (device.empty()) {
-      device_name = virtual_placer_.get_canonical_device_name(*node);
-    } else {
-      device_name = device;
-    }
-    node->set_device(device_name);
-    return node;
-  }
-
   const VirtualPlacer& virtual_placer_;
   const std::unordered_set<string>& nodes_to_preserve_;
   GraphDef* graph_;
@@ -370,10 +345,20 @@ class NodeProcessor : public GraphProcessor {
       LOG(ERROR) << "Failed to parse TensorProto.";
     }
     if (tensor.dims() == 1) {
-      int c = tensor.flat<int>()(3);
-      tensor.flat<int>()(3) = tensor.flat<int>()(2);
-      tensor.flat<int>()(2) = tensor.flat<int>()(1);
-      tensor.flat<int>()(1) = c;
+      if (tensor.flat<int>().size() == 4) {
+        int c = tensor.flat<int>()(3);
+        tensor.flat<int>()(3) = tensor.flat<int>()(2);
+        tensor.flat<int>()(2) = tensor.flat<int>()(1);
+        tensor.flat<int>()(1) = c;
+      } else if (tensor.flat<int>().size() == 3) {
+        tensor.flat<int>()(0) = 0;
+        tensor.flat<int>()(1) = 2;
+        tensor.flat<int>()(2) = 3;
+      } else {
+        return Status(error::INVALID_ARGUMENT,
+                      strings::StrCat("Unsupported tensor size: ",
+                                      tensor.flat<int>().size()));
+      }
     } else if (tensor.dims() == 2) {
       for (int i = 0; i < 2; i++) {
         int c = tensor.matrix<int>()(3, i);
@@ -394,7 +379,9 @@
   Status UpdateAttrValueOfInput(int input_index) {
     auto input_node = node_map_->GetNode(node_->input(input_index));
     // We created a copy of the node, so that we don't modify the original node,
-    // which might be used elsewhere.
+    // which might be used elsewhere. Note that this copy also copies the
+    // control dependency input in the case this node is inside a loop,
+    // to ensure added_node is in the same frame with node_.
     NodeDef* added_node = graph_->add_node();
     *added_node = *input_node;
     string base_name = strings::StrCat(node_->name(), "-", input_node->name());
@@ -411,6 +398,14 @@
     return input_pos;
   }
 
+  virtual std::set<int> GetOutputPos() const {
+    // For most nodes, no need to process control nodes or nodes that use an
+    // output other than the first output: only the first output is of
+    // 4D NCHW/NHWC format and thus relevant here.
+    std::set<int> output_pos = {0};
+    return output_pos;
+  }
+
   NodeDef* AddNodeTranspose(const string& node_name, const string& input_name,
                             const string& const_name, DataType data_type,
                             const TensorShapeProto& input_shape,
@@ -476,37 +471,28 @@
     auto outputs = node_map_->GetOutputs(node_->name());
     string const_name = GetOrAddNodePermNCHWToNHWC();
     for (const auto& output : outputs) {
-      string base_name = strings::StrCat(node_->name(), "-", output->name());
-      string node_name =
-          AddPrefixToNodeName(base_name, kTransposeNCHWToNHWC, "-");
-      // TODO(yaozhang): handle the rare case where node A is connected to more
-      // than one input of node B.
-      auto it = std::find_if(output->mutable_input()->begin(),
-                             output->mutable_input()->end(),
-                             [this](const string& input) {
-                               string node_name = NodeName(input);
-                               return node_name.compare(node_->name()) == 0;
-                             });
-      if (it == output->mutable_input()->end()) {
-        return Status(error::INVALID_ARGUMENT,
-                      strings::StrCat("Expect ", node_->name(),
-                                      " to be an input of ", output->name()));
-      }
-      int output_pos = NodePosition(*it);
-      // No need to process control nodes or nodes that use an output
-      // other than the first output: only the first output is of 4D NCHW/NHWC
-      // format and thus relevant here.
-      if (output_pos != 0) {
-        continue;
+      for (int i = 0; i < output->input_size(); i++) {
+        auto& input = *output->mutable_input(i);
+        int input_port;
+        string input_name = ParseNodeName(input, &input_port);
+        auto output_pos = GetOutputPos();
+        if (input_name == node_->name() &&
+            output_pos.find(input_port) != output_pos.end()) {
+          string base_name =
+              strings::StrCat(node_->name(), "-", output->name(), "-", i);
+          string node_name =
+              AddPrefixToNodeName(base_name, kTransposeNCHWToNHWC, "-");
+          TF_RETURN_IF_ERROR(HasAttribute(*node_, "T"));
+          TF_RETURN_IF_ERROR(HasAttribute(*node_, "_output_shapes"));
+          AddNodeTranspose(
+              node_name, input, const_name, node_->attr().at("T").type(),
+              node_->attr().at("_output_shapes").list().shape(0), false);
+          input = node_name;
+          node_map_->AddOutput(node_->name(), node_name);
+          node_map_->AddOutput(node_name, output->name());
+        }
       }
-      TF_RETURN_IF_ERROR(HasAttribute(*node_, "T"));
-      TF_RETURN_IF_ERROR(HasAttribute(*node_, "_output_shapes"));
-      AddNodeTranspose(
-          node_name, node_->name(), const_name, node_->attr().at("T").type(),
-          node_->attr().at("_output_shapes").list().shape(0), false);
-      *it = node_name;
-      node_map_->UpdateOutput(node_->name(), output->name(), node_name);
-      node_map_->AddOutput(node_name, output->name());
+      node_map_->RemoveOutput(node_->name(), output->name());
     }
     return Status::OK();
   }
@@ -948,7 +934,7 @@ class ConcatProcessor : public AgnosticNodeProcessor {
   }
 
   Status CustomizedProcessing() override {
-    string concat_const_name = GetOrAddNodeConcatConst();
+    string concat_const_name = AddNodeConcatConst()->name();
     node_map_->AddOutput(concat_const_name, node_->name());
     *node_->mutable_input(axis_node_pos_) = concat_const_name;
     return Status::OK();
@@ -956,8 +942,14 @@
 
   bool IsAlongDimC() const {
     auto axis_node = node_map_->GetNode(node_->input(axis_node_pos_));
+    if (!IsConstant(*axis_node)) {
+      return false;
+    }
     if (axis_node->attr().find("value") != axis_node->attr().end()) {
-      return axis_node->attr().at("value").tensor().int_val(0) == 3;
+      auto tensor = axis_node->attr().at({"value"}).tensor();
+      if (tensor.tensor_shape().dim_size() == 0 && tensor.int_val_size() == 1) {
+        return tensor.int_val(0) == 3;
+      }
     }
     return false;
   }
@@ -965,28 +957,18 @@
   int axis_node_pos_;
 
 private:
-  NodeDef* AddNodeConcatConst(const string& suffix, const string& depended_node,
-                              const string& device) {
-    auto const_node = AddNodeConstScalar(
-        strings::StrCat(kConcatConst, "-", suffix), device, DT_INT32, 1);
-    // This is to ensure the concat node and the const node are
-    // in the same frame.
-    *const_node->add_input() = AsControlDependency(depended_node);
-    return const_node;
-  }
-
-  string GetOrAddNodeConcatConst() {
-    string const_name;
-    if (is_in_frame_) {
-      int value_node_pos = (axis_node_pos_ == 0) ? 1 : 0;
-      auto const_node = AddNodeConcatConst(
-          node_->name(), NodeName(node_->input(value_node_pos)),
-          node_->device());
-      const_name = const_node->name();
-    } else {
-      const_name = kConcatConst;
-    }
-    return const_name;
+  NodeDef* AddNodeConcatConst() {
+    auto axis_node = node_map_->GetNode(node_->input(axis_node_pos_));
+    // We created a copy of the node, so that we don't modify the original node,
+    // which might be used elsewhere. Note that this copy also copies the
+    // control dependency input in the case this node is inside a loop,
+    // to ensure added_node is in the same frame with node_.
+    auto added_node = graph_->add_node();
+    *added_node = *axis_node;
+    added_node->set_name(strings::StrCat(kConcatConst, "-", node_->name()));
+    added_node->mutable_attr()->at({"value"}).mutable_tensor()->set_int_val(0,
+                                                                            1);
+    return added_node;
   }
 };
 
@@ -1036,6 +1018,16 @@ class SplitProcessor : public AgnosticNodeProcessor {
     return input_pos;
   }
 
+  std::set<int> GetOutputPos() const override {
+    std::set<int> output_pos{0};
+    if (HasAttribute(*node_, "num_split").ok()) {
+      for (int i = 1; i < node_->attr().at("num_split").i(); i++) {
+        output_pos.insert(i);
+      }
+    }
+    return output_pos;
+  }
+
   Status CustomizedProcessing() override {
     string split_const_name = AddNodeSplitConst()->name();
     node_map_->AddOutput(split_const_name, node_->name());
@@ -1073,7 +1065,7 @@
     // We created a copy of the node, so that we don't modify the original node,
     // which might be used elsewhere. Note that this copy also copies the
    // control dependency input in the case this node is inside a loop,
-    // to ensure added_node is in the same frame with the Split node.
+    // to ensure added_node is in the same frame with node_.
     NodeDef* added_node = graph_->add_node();
     *added_node = *dim_node;
     added_node->set_name(strings::StrCat(kSplitConst, "-", node_->name()));
@@ -1329,20 +1321,21 @@ class SumProcessor : public AgnosticNodeProcessor {
 
   Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
 
-  Status CustomizedProcessing() override {
-    node_map_->AddOutput(kReductionConst, node_->name());
-    *node_->mutable_input(1) = GetOrAddNodeReductionConst();
-    return Status::OK();
-  }
+  Status CustomizedProcessing() override { return UpdateAttrValueOfInput(1); }
 
  private:
   bool IsAlongDimNHW() const {
-    NodeDef* node = node_map_->GetNode(node_->input(1));
+    NodeDef* reduction_indices = node_map_->GetNode(node_->input(1));
+    if (!IsConstant(*reduction_indices)) {
+      return false;
+    }
     Tensor tensor;
-    if (node->attr().find({"value"}) == node->attr().end()) {
+    if (reduction_indices->attr().find({"value"}) ==
+        reduction_indices->attr().end()) {
       return false;
     }
-    auto success = tensor.FromProto(node->attr().at({"value"}).tensor());
+    auto success =
+        tensor.FromProto(reduction_indices->attr().at({"value"}).tensor());
     if (!success) {
       LOG(ERROR) << "Failed to parse TensorProto.";
       return false;
@@ -1356,29 +1349,6 @@
     }
     return false;
   }
-
-  NodeDef* AddNodeReductionConst(const string& suffix,
-                                 const string& depended_node,
-                                 const string& device) {
-    auto const_node = GraphProcessor::AddNodeReductionConst(
-        strings::StrCat(kReductionConst, "-", suffix), device);
-    // This is to ensure the Sum node and the const node are in the
-    // same frame.
-    *const_node->add_input() = AsControlDependency(depended_node);
-    return const_node;
-  }
-
-  string GetOrAddNodeReductionConst() {
-    string const_name;
-    if (is_in_frame_) {
-      auto const_node = AddNodeReductionConst(
-          node_->name(), NodeName(node_->input(0)), node_->device());
-      const_name = const_node->name();
-    } else {
-      const_name = kReductionConst;
-    }
-    return const_name;
-  }
 };
 
 class DataLayoutOptimizer : GraphProcessor {
@@ -1409,18 +1379,10 @@
     return AddNodePermConst(kPermNCHWToNHWC, "", {0, 2, 3, 1});
   }
 
-  NodeDef* AddNodeConcatConst() {
-    return AddNodeConstScalar(kConcatConst, "", DT_INT32, 1);
-  }
-
   NodeDef* AddNodeGatherAxisConst() {
     return AddNodeConstScalar(kGatherAxisConst, "", DT_INT32, 0);
  }
 
-  NodeDef* AddNodeReductionConst() {
-    return GraphProcessor::AddNodeReductionConst(kReductionConst, "");
-  }
-
   // Expand all nodes which is in NHWC, but supports NCHW or is layout agnostic.
   Status Expand() {
     int node_size_original = graph_->node_size();
@@ -1474,9 +1436,7 @@
     if (graph_->node_size() > node_size_original) {
       NodeDef* n = AddNodePermNHWCToNCHW();
       n = AddNodePermNCHWToNHWC();
-      n = AddNodeConcatConst();
       n = AddNodeGatherAxisConst();
-      n = AddNodeReductionConst();
       std::set<string> ops_format_agnostic = GetOpsFormatAgnostic();
       for (int i = 0; i < graph_->node_size(); i++) {
         if (ops_format_agnostic.find(graph_->node(i).op()) !=
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc
index 8c89f6744b..e8f7b8ac3c 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc
@@ -495,7 +495,80 @@ TEST_F(LayoutOptimizerTest, SplitNonConstDim) {
   auto split_node = node_map.GetNode("split");
   EXPECT_EQ(split_node->input(0), "i1");
   EXPECT_EQ(split_node->input(1),
-            "LayoutOptimizerTransposeNCHWToNHWC-Conv2D-split");
+            "LayoutOptimizerTransposeNCHWToNHWC-Conv2D-split-1");
+}
+
+TEST_F(LayoutOptimizerTest, SplitSamePortToMultipleInputsOfSameNode) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  auto conv = SimpleConv2D(&s, 3, 2, "VALID");
+  auto axis = ops::Const(s.WithOpName("axis"), 3);
+  auto split = ops::Split(s.WithOpName("split"), axis, conv, 2);
+  auto concat =
+      ops::Concat(s.WithOpName("concat"), {split[1], split[1], split[1]}, axis);
+  auto o = ops::Identity(s.WithOpName("o"), concat);
+  GrapplerItem item;
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  LayoutOptimizer optimizer;
+  GraphDef output;
+  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
+  NodeMap node_map(&output);
+  auto concat_node = node_map.GetNode("concat");
+  EXPECT_EQ(concat_node->input(0), "split:1");
+  EXPECT_EQ(concat_node->input(1), "split:1");
+  EXPECT_EQ(concat_node->input(2), "split:1");
+  EXPECT_EQ(concat_node->input(3), "LayoutOptimizerConcatConst-concat");
+  auto concat_dim = node_map.GetNode("LayoutOptimizerConcatConst-concat");
+  EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 1);
+}
+
+TEST_F(LayoutOptimizerTest, Concat) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  auto conv = SimpleConv2D(&s, 3, 2, "VALID");
+  auto axis = ops::Const(s.WithOpName("axis"), 3);
+  auto split = ops::Split(s.WithOpName("split"), axis, conv, 2);
+  auto concat = ops::Concat(s.WithOpName("concat"), {split[0], split[1]}, axis);
+  auto o = ops::Identity(s.WithOpName("o"), concat);
+  GrapplerItem item;
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  LayoutOptimizer optimizer;
+  GraphDef output;
+  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
+  NodeMap node_map(&output);
+  auto concat_node = node_map.GetNode("concat");
+  EXPECT_EQ(concat_node->input(0), "split");
+  EXPECT_EQ(concat_node->input(1), "split:1");
+  EXPECT_EQ(concat_node->input(2), "LayoutOptimizerConcatConst-concat");
+  auto concat_dim = node_map.GetNode("LayoutOptimizerConcatConst-concat");
+  EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 1);
+}
+
+TEST_F(LayoutOptimizerTest, Sum) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  auto conv = SimpleConv2D(&s, 3, 2, "VALID");
+  auto reduction_indices =
+      ops::Const(s.WithOpName("reduction_indices"), {0, 1, 2}, {3});
+  auto sum = ops::Sum(s.WithOpName("sum"), conv, reduction_indices);
+  auto o = ops::Identity(s.WithOpName("o"), sum);
+  GrapplerItem item;
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  LayoutOptimizer optimizer;
+  GraphDef output;
+  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
+  // TODO(yaozhang): enable SumProcessor with auto-tuning. Currently disabled
+  // because of the worse performance in some cases.
+  /*
+  NodeMap node_map(&output);
+  auto sum_node = node_map.GetNode("sum");
+  EXPECT_EQ(sum_node->input(0), "Conv2D");
+  EXPECT_EQ(sum_node->input(1), "LayoutOptimizer-sum-reduction_indices");
+  auto sum_const = node_map.GetNode("LayoutOptimizer-sum-reduction_indices");
+  Tensor tensor;
+  EXPECT_TRUE(
+      tensor.FromProto(sum_const->mutable_attr()->at({"value"}).tensor()));
+  Tensor tensor_expected(DT_INT32, {3});
+  test::FillValues<int>(&tensor_expected, {0, 2, 3});
+  test::ExpectTensorEqual<int>(tensor_expected, tensor);
+  */
 }
 
 }  // namespace
diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py
index 626e0502cb..50735fb567 100644
--- a/tensorflow/python/grappler/layout_optimizer_test.py
+++ b/tensorflow/python/grappler/layout_optimizer_test.py
@@ -190,7 +190,7 @@ class LayoutOptimizerTest(test.TestCase):
       self.assertEqual(expected_num_transposes, num_transposes)
      self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-Conv2D-Reshape-0',
                    nodes)
-      self.assertIn('LayoutOptimizerTransposeNCHWToNHWC-Relu_1-MaxPool_1',
+      self.assertIn('LayoutOptimizerTransposeNCHWToNHWC-Relu_1-MaxPool_1-0',
                     nodes)
      self.assertAllClose(output_val_ref, output_val, atol=1e-3)
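The other half of the change generalizes output handling: AddLayoutTransposeToOutputs now scans every input slot of every consumer and inserts one transpose per slot whose port appears in GetOutputPos(), instead of locating a single matching input. That is why the expected transpose names in the tests gain an input-slot suffix (-0, -1). A toy, self-contained C++ sketch of the per-slot rewiring; the names and the simplified transpose naming are illustrative only, the real code also creates the transpose NodeDef and updates the node map, and control inputs such as "^node" would need separate handling:

  #include <iostream>
  #include <set>
  #include <string>
  #include <utility>
  #include <vector>

  // Splits "node:2" into ("node", 2); a bare "node" reads output port 0.
  std::pair<std::string, int> ParsePort(const std::string& input) {
    auto colon = input.rfind(':');
    if (colon == std::string::npos) return {input, 0};
    return {input.substr(0, colon), std::stoi(input.substr(colon + 1))};
  }

  int main() {
    // A consumer reading port 1 of a two-output "split" through three slots.
    std::vector<std::string> inputs = {"split:1", "split:1", "split:1", "axis"};
    // What SplitProcessor::GetOutputPos() would return for num_split = 2.
    std::set<int> output_pos = {0, 1};
    for (int i = 0; i < static_cast<int>(inputs.size()); ++i) {
      auto [name, port] = ParsePort(inputs[i]);
      if (name == "split" && output_pos.count(port) != 0) {
        // One transpose per matching input slot; appending the slot index i
        // keeps the generated names unique when the same output port feeds
        // several inputs of one consumer.
        inputs[i] = "TransposeNCHWToNHWC-split-concat-" + std::to_string(i);
      }
    }
    // Prints three distinct transpose names followed by "axis".
    for (const auto& input : inputs) std::cout << input << "\n";
  }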