about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
authorGravatar Yao Zhang <yaozhang@google.com>2017-12-01 15:59:44 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-01 16:03:09 -0800
commit2270ecc169025c7fa33edd09700abcd72a777373 (patch)
treec0779534010552f7986abc3b03feb60d8123acb1
parent9232ef1bd38a6fbec2a62c2dcf373dcf0c01b6cb (diff)
Remove SliceProcessorConcatOffset, which is not robust as it modifies nodes
which could be used elsewhere in the graph; SliceProcessorConcatOffset is a historical implementation anyway, and the same functionality could be provided by more recently developed SliceProcessor and SliceProcessorConst. PiperOrigin-RevId: 177653218
-rw-r--r--tensorflow/core/grappler/optimizers/layout_optimizer.cc66
-rw-r--r--tensorflow/core/grappler/optimizers/layout_optimizer_test.cc72
2 files changed, 76 insertions(+), 62 deletions(-)
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
index e9436638f0..36e5047d61 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
@@ -1136,7 +1136,7 @@ class SliceProcessor : public AgnosticNodeProcessor {
string node_name =
AddPrefixToNodeName(base_name, kPermVecNHWCToNCHW, "-");
TF_RETURN_IF_ERROR(HasAttribute(*node_, "Index"));
- AddNodePermVec(node_name, node_->input(i),
+ AddNodePermVec(node_name, node_->input(i), node_->device(),
node_->attr().at("Index").type(), true);
node_map_->UpdateOutput(node_->input(i), node_->name(), node_name);
node_map_->AddOutput(node_name, node_->name());
@@ -1194,10 +1194,12 @@ class SliceProcessor : public AgnosticNodeProcessor {
}
void AddNodePermVec(const string& node_name, const string& input_name,
- DataType data_type, bool NHWCToNCHW) {
+ const string& device, DataType data_type,
+ bool NHWCToNCHW) {
NodeDef* node = graph_->add_node();
node_map_->AddNode(node_name, node);
node->set_name(node_name);
+ node->set_device(device);
*node->add_input() = input_name;
*node->add_input() = NHWCToNCHW ? GetOrAddNodePermNHWCToNCHW()
: GetOrAddNodePermNCHWToNHWC();
@@ -1215,10 +1217,6 @@ class SliceProcessor : public AgnosticNodeProcessor {
AttrValue attr_type_params;
attr_type_params.set_type(data_type);
node->mutable_attr()->insert({"Tparams", attr_type_params});
-
- AttrValue attr_validate;
- attr_validate.set_b(true);
- node->mutable_attr()->insert({"validate_indices", attr_validate});
}
};
@@ -1240,58 +1238,6 @@ class SliceProcessorConst : public AgnosticNodeProcessor {
}
};
-// Specialized SliceProcessor, used if the second input is ConcatOffset. An
-// example use case is in the gradient computation of Concat for InceptionV3.
-class SliceProcessorConcatOffset : public AgnosticNodeProcessor {
- public:
- explicit SliceProcessorConcatOffset(const OptimizeContext& opt_cxt)
- : AgnosticNodeProcessor(opt_cxt) {}
-
- protected:
- Status CustomizedProcessing() override {
- auto maybe_concatoffset_node =
- node_map_->GetNode(NodeName(node_->input(1)));
- if (IsConcatOffset(*maybe_concatoffset_node)) {
- auto maybe_axis_node =
- node_map_->GetNode(maybe_concatoffset_node->input(0));
- NodeDef* axis_node;
- if (IsConstant(*maybe_axis_node)) {
- axis_node = maybe_axis_node;
- // A FloorMod node might be added between ConcatOffset and the concat
- // dimension const node to handle a negative dimension index -1, meaning
- // the last dimension, which is consistent with the python's notation
- // for negative index.
- } else if (IsFloorMod(*maybe_axis_node)) {
- axis_node = node_map_->GetNode(maybe_axis_node->input(0));
- } else {
- return Status(error::INVALID_ARGUMENT,
- strings::StrCat("Expect either Const or FloorMod for the "
- "input 1 of ConcatOffset"));
- }
- // Need to process if the channel is at dimension 3, which indicates the
- // NHWC format is being used. As multiple Slice nodes may share the same
- // ConcatOffset node, the NHWC to NCHW conversion may have already
- // been performed when processing other Slice nodes.
- TF_RETURN_IF_ERROR(HasAttribute(*axis_node, "value"));
- int concat_dim = axis_node->attr().at("value").tensor().int_val(0);
- if (concat_dim == -1 || concat_dim == 3) {
- // Update the dimension order for shape input nodes. Note that the input
- // 2 of Slice also shares one of the shape nodes.
- for (int i = 1; i < maybe_concatoffset_node->input_size(); i++) {
- auto shape_node =
- node_map_->GetNode(maybe_concatoffset_node->input(i));
- TF_RETURN_IF_ERROR(UpdateAttrValue(shape_node));
- }
- // Set the channel dimension to 1, as we have converted the vector
- // element order from NHWC to NCHW.
- axis_node->mutable_attr()->at("value").mutable_tensor()->set_int_val(0,
- 1);
- }
- }
- return Status::OK();
- }
-};
-
class SqueezeProcessor : public AgnosticNodeProcessor {
public:
explicit SqueezeProcessor(const OptimizeContext& opt_cxt)
@@ -1496,9 +1442,7 @@ class DataLayoutOptimizer : GraphProcessor {
} else if (IsSlice(*node)) {
auto input1 = node_map_->GetNode(NodeName(node->input(1)));
auto input2 = node_map_->GetNode(NodeName(node->input(2)));
- if (IsConcatOffset(*input1)) {
- node_processor.reset(new SliceProcessorConcatOffset(opt_cxt));
- } else if (IsConstant(*input1) && IsConstant(*input2)) {
+ if (IsConstant(*input1) && IsConstant(*input2)) {
node_processor.reset(new SliceProcessorConst(opt_cxt));
} else {
node_processor.reset(new SliceProcessor(opt_cxt));
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc
index 363b4c3fd8..0b906485e7 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc
@@ -44,7 +44,7 @@ class LayoutOptimizerTest : public ::testing::Test {
Output SimpleConv2D(tensorflow::Scope* s, int input_size, int filter_size,
const string& padding, const string& device) {
- int batch_size = 128;
+ int batch_size = 8;
int input_height = input_size;
int input_width = input_size;
int input_depth = 3;
@@ -699,6 +699,76 @@ TEST_F(LayoutOptimizerTest, MulVectorAnd4D) {
"LayoutOptimizerTransposeNCHWToNHWC-Conv2D-mul-1");
}
+TEST_F(LayoutOptimizerTest, SliceConst) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto conv = SimpleConv2D(&s, 5, 2, "VALID");
+ auto begin = ops::Const(s.WithOpName("begin"), {0, 2, 3, 1}, {4});
+ auto size = ops::Const(s.WithOpName("size"), {4, 1, 2, 4}, {4});
+ auto slice = ops::Slice(s.WithOpName("slice"), conv, begin, size);
+ auto o = ops::Identity(s.WithOpName("o"), slice);
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ LayoutOptimizer optimizer;
+ GraphDef output;
+ Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
+ NodeMap node_map(&output);
+ auto slice_node = node_map.GetNode("slice");
+ EXPECT_EQ(slice_node->input(0), "Conv2D");
+ EXPECT_EQ(slice_node->input(1), "LayoutOptimizer-slice-begin");
+ EXPECT_EQ(slice_node->input(2), "LayoutOptimizer-slice-size");
+
+ auto begin_const = node_map.GetNode("LayoutOptimizer-slice-begin");
+ Tensor begin_tensor;
+ EXPECT_TRUE(begin_tensor.FromProto(
+ begin_const->mutable_attr()->at({"value"}).tensor()));
+ Tensor begin_tensor_expected(DT_INT32, {4});
+ test::FillValues<int>(&begin_tensor_expected, {0, 1, 2, 3});
+ test::ExpectTensorEqual<int>(begin_tensor_expected, begin_tensor);
+
+ auto size_const = node_map.GetNode("LayoutOptimizer-slice-size");
+ Tensor size_tensor;
+ EXPECT_TRUE(size_tensor.FromProto(
+ size_const->mutable_attr()->at({"value"}).tensor()));
+ Tensor size_tensor_expected(DT_INT32, {4});
+ test::FillValues<int>(&size_tensor_expected, {4, 4, 1, 2});
+ test::ExpectTensorEqual<int>(size_tensor_expected, size_tensor);
+}
+
+TEST_F(LayoutOptimizerTest, SliceNonConst) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto conv = SimpleConv2D(&s, 5, 2, "VALID");
+ auto begin = ops::Const(s.WithOpName("begin"), {0, 2, 3, 1}, {4});
+ auto ibegin = ops::Identity(s.WithOpName("ibegin"), begin);
+ auto size = ops::Const(s.WithOpName("size"), {4, 1, 2, 4}, {4});
+ auto isize = ops::Identity(s.WithOpName("isize"), size);
+ auto slice = ops::Slice(s.WithOpName("slice"), conv, ibegin, isize);
+ auto o = ops::Identity(s.WithOpName("o"), slice);
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ LayoutOptimizer optimizer;
+ GraphDef output;
+ Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
+ NodeMap node_map(&output);
+ auto slice_node = node_map.GetNode("slice");
+ EXPECT_EQ(slice_node->input(0), "Conv2D");
+ EXPECT_EQ(slice_node->input(1),
+ "LayoutOptimizerPermVecNHWCToNCHW-slice-input1");
+ EXPECT_EQ(slice_node->input(2),
+ "LayoutOptimizerPermVecNHWCToNCHW-slice-input2");
+
+ auto perm1 =
+ node_map.GetNode("LayoutOptimizerPermVecNHWCToNCHW-slice-input1");
+ EXPECT_EQ(perm1->input(0), "ibegin");
+ EXPECT_EQ(perm1->input(1), "LayoutOptimizerPermConstNHWCToNCHW");
+ EXPECT_EQ(perm1->input(2), "LayoutOptimizerGatherAxisConst");
+
+ auto perm2 =
+ node_map.GetNode("LayoutOptimizerPermVecNHWCToNCHW-slice-input2");
+ EXPECT_EQ(perm2->input(0), "isize");
+ EXPECT_EQ(perm2->input(1), "LayoutOptimizerPermConstNHWCToNCHW");
+ EXPECT_EQ(perm2->input(2), "LayoutOptimizerGatherAxisConst");
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow