diff options
author | Yao Zhang <yaozhang@google.com> | 2017-12-01 15:59:44 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-12-01 16:03:09 -0800 |
commit | 2270ecc169025c7fa33edd09700abcd72a777373 (patch) | |
tree | c0779534010552f7986abc3b03feb60d8123acb1 | |
parent | 9232ef1bd38a6fbec2a62c2dcf373dcf0c01b6cb (diff) |
Remove SliceProcessorConcatOffset, which is not robust as it modifies nodes
which could be used elsewhere in the graph; SliceProcessorConcatOffset is a
historical implementation anyway, and the same functionality could be
provided by more recently developed SliceProcessor and SliceProcessorConst.
PiperOrigin-RevId: 177653218
-rw-r--r-- | tensorflow/core/grappler/optimizers/layout_optimizer.cc | 66 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/layout_optimizer_test.cc | 72 |
2 files changed, 76 insertions, 62 deletions
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc index e9436638f0..36e5047d61 100644 --- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc @@ -1136,7 +1136,7 @@ class SliceProcessor : public AgnosticNodeProcessor { string node_name = AddPrefixToNodeName(base_name, kPermVecNHWCToNCHW, "-"); TF_RETURN_IF_ERROR(HasAttribute(*node_, "Index")); - AddNodePermVec(node_name, node_->input(i), + AddNodePermVec(node_name, node_->input(i), node_->device(), node_->attr().at("Index").type(), true); node_map_->UpdateOutput(node_->input(i), node_->name(), node_name); node_map_->AddOutput(node_name, node_->name()); @@ -1194,10 +1194,12 @@ class SliceProcessor : public AgnosticNodeProcessor { } void AddNodePermVec(const string& node_name, const string& input_name, - DataType data_type, bool NHWCToNCHW) { + const string& device, DataType data_type, + bool NHWCToNCHW) { NodeDef* node = graph_->add_node(); node_map_->AddNode(node_name, node); node->set_name(node_name); + node->set_device(device); *node->add_input() = input_name; *node->add_input() = NHWCToNCHW ? GetOrAddNodePermNHWCToNCHW() : GetOrAddNodePermNCHWToNHWC(); @@ -1215,10 +1217,6 @@ class SliceProcessor : public AgnosticNodeProcessor { AttrValue attr_type_params; attr_type_params.set_type(data_type); node->mutable_attr()->insert({"Tparams", attr_type_params}); - - AttrValue attr_validate; - attr_validate.set_b(true); - node->mutable_attr()->insert({"validate_indices", attr_validate}); } }; @@ -1240,58 +1238,6 @@ class SliceProcessorConst : public AgnosticNodeProcessor { } }; -// Specialized SliceProcessor, used if the second input is ConcatOffset. An -// example use case is in the gradient computation of Concat for InceptionV3. -class SliceProcessorConcatOffset : public AgnosticNodeProcessor { - public: - explicit SliceProcessorConcatOffset(const OptimizeContext& opt_cxt) - : AgnosticNodeProcessor(opt_cxt) {} - - protected: - Status CustomizedProcessing() override { - auto maybe_concatoffset_node = - node_map_->GetNode(NodeName(node_->input(1))); - if (IsConcatOffset(*maybe_concatoffset_node)) { - auto maybe_axis_node = - node_map_->GetNode(maybe_concatoffset_node->input(0)); - NodeDef* axis_node; - if (IsConstant(*maybe_axis_node)) { - axis_node = maybe_axis_node; - // A FloorMod node might be added between ConcatOffset and the concat - // dimension const node to handle a negative dimension index -1, meaning - // the last dimension, which is consistent with the python's notation - // for negative index. - } else if (IsFloorMod(*maybe_axis_node)) { - axis_node = node_map_->GetNode(maybe_axis_node->input(0)); - } else { - return Status(error::INVALID_ARGUMENT, - strings::StrCat("Expect either Const or FloorMod for the " - "input 1 of ConcatOffset")); - } - // Need to process if the channel is at dimension 3, which indicates the - // NHWC format is being used. As multiple Slice nodes may share the same - // ConcatOffset node, the NHWC to NCHW conversion may have already - // been performed when processing other Slice nodes. - TF_RETURN_IF_ERROR(HasAttribute(*axis_node, "value")); - int concat_dim = axis_node->attr().at("value").tensor().int_val(0); - if (concat_dim == -1 || concat_dim == 3) { - // Update the dimension order for shape input nodes. Note that the input - // 2 of Slice also shares one of the shape nodes. - for (int i = 1; i < maybe_concatoffset_node->input_size(); i++) { - auto shape_node = - node_map_->GetNode(maybe_concatoffset_node->input(i)); - TF_RETURN_IF_ERROR(UpdateAttrValue(shape_node)); - } - // Set the channel dimension to 1, as we have converted the vector - // element order from NHWC to NCHW. - axis_node->mutable_attr()->at("value").mutable_tensor()->set_int_val(0, - 1); - } - } - return Status::OK(); - } -}; - class SqueezeProcessor : public AgnosticNodeProcessor { public: explicit SqueezeProcessor(const OptimizeContext& opt_cxt) @@ -1496,9 +1442,7 @@ class DataLayoutOptimizer : GraphProcessor { } else if (IsSlice(*node)) { auto input1 = node_map_->GetNode(NodeName(node->input(1))); auto input2 = node_map_->GetNode(NodeName(node->input(2))); - if (IsConcatOffset(*input1)) { - node_processor.reset(new SliceProcessorConcatOffset(opt_cxt)); - } else if (IsConstant(*input1) && IsConstant(*input2)) { + if (IsConstant(*input1) && IsConstant(*input2)) { node_processor.reset(new SliceProcessorConst(opt_cxt)); } else { node_processor.reset(new SliceProcessor(opt_cxt)); diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc index 363b4c3fd8..0b906485e7 100644 --- a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc @@ -44,7 +44,7 @@ class LayoutOptimizerTest : public ::testing::Test { Output SimpleConv2D(tensorflow::Scope* s, int input_size, int filter_size, const string& padding, const string& device) { - int batch_size = 128; + int batch_size = 8; int input_height = input_size; int input_width = input_size; int input_depth = 3; @@ -699,6 +699,76 @@ TEST_F(LayoutOptimizerTest, MulVectorAnd4D) { "LayoutOptimizerTransposeNCHWToNHWC-Conv2D-mul-1"); } +TEST_F(LayoutOptimizerTest, SliceConst) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto conv = SimpleConv2D(&s, 5, 2, "VALID"); + auto begin = ops::Const(s.WithOpName("begin"), {0, 2, 3, 1}, {4}); + auto size = ops::Const(s.WithOpName("size"), {4, 1, 2, 4}, {4}); + auto slice = ops::Slice(s.WithOpName("slice"), conv, begin, size); + auto o = ops::Identity(s.WithOpName("o"), slice); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + LayoutOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output); + NodeMap node_map(&output); + auto slice_node = node_map.GetNode("slice"); + EXPECT_EQ(slice_node->input(0), "Conv2D"); + EXPECT_EQ(slice_node->input(1), "LayoutOptimizer-slice-begin"); + EXPECT_EQ(slice_node->input(2), "LayoutOptimizer-slice-size"); + + auto begin_const = node_map.GetNode("LayoutOptimizer-slice-begin"); + Tensor begin_tensor; + EXPECT_TRUE(begin_tensor.FromProto( + begin_const->mutable_attr()->at({"value"}).tensor())); + Tensor begin_tensor_expected(DT_INT32, {4}); + test::FillValues<int>(&begin_tensor_expected, {0, 1, 2, 3}); + test::ExpectTensorEqual<int>(begin_tensor_expected, begin_tensor); + + auto size_const = node_map.GetNode("LayoutOptimizer-slice-size"); + Tensor size_tensor; + EXPECT_TRUE(size_tensor.FromProto( + size_const->mutable_attr()->at({"value"}).tensor())); + Tensor size_tensor_expected(DT_INT32, {4}); + test::FillValues<int>(&size_tensor_expected, {4, 4, 1, 2}); + test::ExpectTensorEqual<int>(size_tensor_expected, size_tensor); +} + +TEST_F(LayoutOptimizerTest, SliceNonConst) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto conv = SimpleConv2D(&s, 5, 2, "VALID"); + auto begin = ops::Const(s.WithOpName("begin"), {0, 2, 3, 1}, {4}); + auto ibegin = ops::Identity(s.WithOpName("ibegin"), begin); + auto size = ops::Const(s.WithOpName("size"), {4, 1, 2, 4}, {4}); + auto isize = ops::Identity(s.WithOpName("isize"), size); + auto slice = ops::Slice(s.WithOpName("slice"), conv, ibegin, isize); + auto o = ops::Identity(s.WithOpName("o"), slice); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + LayoutOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output); + NodeMap node_map(&output); + auto slice_node = node_map.GetNode("slice"); + EXPECT_EQ(slice_node->input(0), "Conv2D"); + EXPECT_EQ(slice_node->input(1), + "LayoutOptimizerPermVecNHWCToNCHW-slice-input1"); + EXPECT_EQ(slice_node->input(2), + "LayoutOptimizerPermVecNHWCToNCHW-slice-input2"); + + auto perm1 = + node_map.GetNode("LayoutOptimizerPermVecNHWCToNCHW-slice-input1"); + EXPECT_EQ(perm1->input(0), "ibegin"); + EXPECT_EQ(perm1->input(1), "LayoutOptimizerPermConstNHWCToNCHW"); + EXPECT_EQ(perm1->input(2), "LayoutOptimizerGatherAxisConst"); + + auto perm2 = + node_map.GetNode("LayoutOptimizerPermVecNHWCToNCHW-slice-input2"); + EXPECT_EQ(perm2->input(0), "isize"); + EXPECT_EQ(perm2->input(1), "LayoutOptimizerPermConstNHWCToNCHW"); + EXPECT_EQ(perm2->input(2), "LayoutOptimizerGatherAxisConst"); +} + } // namespace } // namespace grappler } // namespace tensorflow |