aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Yao Zhang <yaozhang@google.com>2017-05-04 08:48:52 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-04 10:09:41 -0700
commit65044bc25981e4e060ad5c34d9a520a0561775c3 (patch)
tree10c3fff9cae2722dbf5bf1942a1a94a12ad9c7c6
parent3bee923c93f9624ce3abf8d55173be66a7755545 (diff)
Add an option to not convert layout if GEMM is used internally in Conv2D, Conv2DBackpropFilter, and Conv2DBackpropInput, because in such cases, NHWC is usually faster than NCHW. The cost of enabling this option is the overhead of more non-cancellable layout conversion nodes. We added auto tuning to choose a better option by estimating the overhead using the number of added layout conversion nodes.
Don't Convert the layout for Sum, because reduction along dimension 0, 2, 3 (in NCHW) is about 10x slower than along 0, 1, 2 (in NHWC). Change: 155089805
-rw-r--r--tensorflow/core/grappler/op_types.cc10
-rw-r--r--tensorflow/core/grappler/op_types.h2
-rw-r--r--tensorflow/core/grappler/optimizers/BUILD17
-rw-r--r--tensorflow/core/grappler/optimizers/layout_optimizer.cc188
-rw-r--r--tensorflow/core/grappler/optimizers/layout_optimizer.h6
-rw-r--r--tensorflow/core/grappler/optimizers/layout_optimizer_test.cc147
6 files changed, 340 insertions, 30 deletions
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index bafbcc200c..64bdd91077 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -18,6 +18,11 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
+bool IsConcat(const NodeDef& node) {
+ const auto op = node.op();
+ return op == "Concat" || op == "ConcatV2";
+}
+
bool IsDequeueOp(const NodeDef& node) {
static const std::set<std::string> dequeue_ops = {
"QueueDequeueManyV2", "QueueDequeueMany", "QueueDequeueV2",
@@ -30,6 +35,11 @@ bool IsPlaceholder(const NodeDef& node) {
return op == "Placeholder" || op == "PlaceholderV2";
}
+bool IsTranspose(const NodeDef& node) {
+ const auto op = node.op();
+ return op == "Transpose";
+}
+
bool IsVariable(const NodeDef& node) {
const auto op = node.op();
return op == "Variable" || op == "VariableV2" || op == "AutoReloadVariable" ||
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index 2f58835628..4f2bb2bc05 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -21,8 +21,10 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
+bool IsConcat(const NodeDef& node);
bool IsDequeueOp(const NodeDef& node);
bool IsPlaceholder(const NodeDef& node);
+bool IsTranspose(const NodeDef& node);
bool IsVariable(const NodeDef& node);
} // end namespace grappler
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index e3b36c8412..5f30dfbaa2 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -205,11 +205,28 @@ cc_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:devices",
"//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/clusters:cluster",
],
)
+cc_test(
+ name = "layout_optimizer_test",
+ srcs = ["layout_optimizer_test.cc"],
+ deps = [
+ ":layout_optimizer",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
+ ],
+)
+
cc_library(
name = "meta_optimizer",
srcs = ["meta_optimizer.cc"],
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
index 9570ec17d0..5fec89b698 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/devices.h"
#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/strings/numbers.h"
@@ -68,8 +69,7 @@ std::set<string> GetOpsFormatAgnostic() {
"Slice",
"SquaredDifference",
"Squeeze",
- "Sub",
- "Sum"};
+ "Sub"};
return ops_format_agnostic;
}
@@ -110,9 +110,9 @@ class NodeProcessor {
}
protected:
- bool IsDimsN(NodeDef* node, int n) const {
- if (node->attr().find("_output_shapes") != node->attr().end()) {
- auto shape = node->attr().at("_output_shapes").list().shape(0);
+ bool IsDimsN(const NodeDef& node, int n) const {
+ if (node.attr().find("_output_shapes") != node.attr().end()) {
+ auto shape = node.attr().at("_output_shapes").list().shape(0);
if (shape.dim_size() == n) {
return true;
}
@@ -120,7 +120,7 @@ class NodeProcessor {
return false;
}
- bool IsDimsFour(NodeDef* node) const { return IsDimsN(node, 4); }
+ bool IsDimsFour(const NodeDef& node) const { return IsDimsN(node, 4); }
bool IsNHWC() const {
if (node_->attr().find("data_format") != node_->attr().end()) {
@@ -145,7 +145,7 @@ class NodeProcessor {
}
virtual bool ShouldProcess() const {
- return IsNHWC() && IsDimsFour(node_) && HasOutputs();
+ return IsNHWC() && IsDimsFour(*node_) && HasOutputs();
}
void UpdateAttrDataFormat() {
@@ -268,6 +268,8 @@ class NodeProcessor {
for (const auto& output : outputs) {
string node_name_NCHWToNHWC = strings::StrCat(
kTransposeNCHWToNHWC, "-", node_->name(), "-", output->name());
+ // TODO (yaozhang): handle the rare case where node A is connected to more
+ // than one input of node B.
auto it = std::find_if(output->mutable_input()->begin(),
output->mutable_input()->end(),
[this](const string& input) {
@@ -341,7 +343,7 @@ class BiasAddGradProcessor : public NodeProcessor {
bool ShouldProcess() const override {
auto input = node_map_->GetNode(node_->input(0));
if (input) {
- if ((IsNHWC() && IsDimsFour(input)) || IsNodeNCHWToNHWC(input->name())) {
+ if ((IsNHWC() && IsDimsFour(*input)) || IsNodeNCHWToNHWC(input->name())) {
return true;
}
}
@@ -351,13 +353,89 @@ class BiasAddGradProcessor : public NodeProcessor {
Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
};
-class Conv2DBackpropFilterProcessor : public NodeProcessor {
+class Conv2DProcessor : public NodeProcessor {
+ public:
+ Conv2DProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map,
+ bool no_gemm)
+ : NodeProcessor(graph, node, node_map), no_gemm_(no_gemm) {}
+
+ protected:
+ bool ShouldProcess() const override {
+ return IsNHWC() && IsDimsFour(*node_) && HasOutputs() &&
+ (!IsGemmUsed() || no_gemm_);
+ }
+
+ TensorShapeProto GetShape(const string& input_name) const {
+ string node_name;
+ int output_pos;
+ node_name = ParseNodeName(input_name, &output_pos);
+ NodeDef* node = node_map_->GetNode(node_name);
+ if (node->attr().find("_output_shapes") != node->attr().end()) {
+ return node->attr().at("_output_shapes").list().shape(output_pos);
+ }
+ TensorShapeProto shape;
+ return shape;
+ }
+
+ bool IsStrideOne() const {
+ if (node_->attr().find("strides") != node_->attr().end()) {
+ auto list = node_->attr().at("strides").list();
+ return list.i(1) == 1 && list.i(2) == 1;
+ }
+ return false;
+ }
+
+ bool IsValidPadding() const {
+ if (node_->attr().find("padding") != node_->attr().end()) {
+ auto padding = node_->attr().at("padding").s();
+ return padding == "VALID";
+ }
+ return false;
+ }
+
+ // The logic inside this function is based on the internal implementation of
+ // Conv2D, Conv2DBackpropInput, and Conv2DBackpropFilter ops, and thus
+ // needs to be updated accordingly if the internal implementation changes.
+ bool IsGemmUsed(const TensorShapeProto& filter_shape,
+ const TensorShapeProto& input_shape) const {
+ if (filter_shape.dim_size() == 4) {
+ if (filter_shape.dim(0).size() == 1 && filter_shape.dim(1).size() == 1 &&
+ IsStrideOne()) {
+ return true;
+ }
+ }
+ if (input_shape.dim_size() == 4 && filter_shape.dim_size() == 4) {
+ if (input_shape.dim(1).size() == filter_shape.dim(0).size() == 1 &&
+ input_shape.dim(2).size() == filter_shape.dim(1).size() &&
+ IsValidPadding()) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ virtual bool IsGemmUsed() const {
+ auto filter_shape = GetShape(node_->input(1));
+ auto input_shape = GetShape(node_->input(0));
+ return IsGemmUsed(filter_shape, input_shape);
+ }
+
+ bool no_gemm_;
+};
+
+class Conv2DBackpropFilterProcessor : public Conv2DProcessor {
public:
Conv2DBackpropFilterProcessor(GraphDef* graph, NodeDef* node,
- NodeMap* node_map)
- : NodeProcessor(graph, node, node_map) {}
+ NodeMap* node_map, bool no_gemm)
+ : Conv2DProcessor(graph, node, node_map, no_gemm) {}
protected:
+ bool IsGemmUsed() const override {
+ auto filter_shape = GetShape(node_->name());
+ auto input_shape = GetShape(node_->input(0));
+ return Conv2DProcessor::IsGemmUsed(filter_shape, input_shape);
+ }
+
std::vector<int> GetInputPos() const override {
std::vector<int> input_pos = {0, 2};
return input_pos;
@@ -370,17 +448,24 @@ class Conv2DBackpropFilterProcessor : public NodeProcessor {
void UpdateAttrShape() override {}
};
-class Conv2DBackpropInputProcessor : public NodeProcessor {
+class Conv2DBackpropInputProcessor : public Conv2DProcessor {
public:
Conv2DBackpropInputProcessor(GraphDef* graph, NodeDef* node,
- NodeMap* node_map)
- : NodeProcessor(graph, node, node_map) {}
+ NodeMap* node_map, bool no_gemm)
+ : Conv2DProcessor(graph, node, node_map, no_gemm) {}
protected:
+ bool IsGemmUsed() const override {
+ auto filter_shape = GetShape(node_->input(1));
+ auto input_shape = GetShape(node_->name());
+ return Conv2DProcessor::IsGemmUsed(filter_shape, input_shape);
+ }
+
std::vector<int> GetInputPos() const override {
std::vector<int> input_pos = {2};
return input_pos;
}
+
Status CustomizedProcessing() override {
NodeDef* node = node_map_->GetNode(node_->input(0));
return UpdateAttrValue(node);
@@ -418,7 +503,7 @@ class AgnosticNodeProcessor : public NodeProcessor {
protected:
bool ShouldProcess() const override {
- return IsDimsFour(node_) && HasOutputs() && IsNodeAfterNCHWToNHWC();
+ return IsDimsFour(*node_) && HasOutputs() && IsNodeAfterNCHWToNHWC();
}
bool IsNodeAfterNCHWToNHWC() const {
@@ -467,7 +552,7 @@ class BinaryOpProcessor : public AgnosticNodeProcessor {
protected:
bool ShouldProcess() const override {
- return IsDimsFour(node_) && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
+ return IsDimsFour(*node_) && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
(Is4DOperateWithND(4) || Is4DOperateWithScalar() ||
Is4DOperateWithVector());
}
@@ -484,10 +569,10 @@ class BinaryOpProcessor : public AgnosticNodeProcessor {
auto input0 = node_map_->GetNode(node_->input(0));
auto input1 = node_map_->GetNode(node_->input(1));
if (input0 && input1) {
- return (IsDimsFour(input0) || IsNodeNCHWToNHWC(input0->name())) &&
+ return (IsDimsFour(*input0) || IsNodeNCHWToNHWC(input0->name())) &&
((n == 4)
- ? (IsDimsFour(input1) || IsNodeNCHWToNHWC(input1->name()))
- : IsDimsN(input1, n));
+ ? (IsDimsFour(*input1) || IsNodeNCHWToNHWC(input1->name()))
+ : IsDimsN(*input1, n));
}
return false;
}
@@ -571,7 +656,7 @@ class ConcatProcessor : public AgnosticNodeProcessor {
protected:
bool ShouldProcess() const override {
- return IsDimsFour(node_) && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
+ return IsDimsFour(*node_) && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
IsAlongDimC();
}
@@ -739,7 +824,7 @@ class SqueezeProcessor : public AgnosticNodeProcessor {
protected:
bool ShouldProcess() const override {
- return IsDimsN(node_, 2) && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
+ return IsDimsN(*node_, 2) && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
IsInputConvertible() && IsAlongDimHW();
}
@@ -790,7 +875,7 @@ class SumProcessor : public AgnosticNodeProcessor {
bool ShouldProcess() const override {
auto input0 = node_map_->GetNode(node_->input(0));
return HasOutputs() && IsNodeAfterNCHWToNHWC() &&
- (IsDimsFour(input0) || IsNodeNCHWToNHWC(input0->name())) &&
+ (IsDimsFour(*input0) || IsNodeNCHWToNHWC(input0->name())) &&
IsAlongDimNHW();
}
@@ -825,10 +910,21 @@ class SumProcessor : public AgnosticNodeProcessor {
}
};
+struct TuningConfig {
+ // If true, do not use the NHWC GEMM implementation. When filter size is
+ // one or filter size is equal to input image size,
+ // the NHWC implementation of Conv2D, Conv2DBackpropInput, and
+ // Conv2DBackpropFilter will use a specialized GEMM implementation, which is
+ // usually faster than the NCHW implementation. The downside is that this
+ // might result in more non-cancellable layout conversion nodes (implemented
+ // by the Tranpose op).
+ bool no_gemm;
+};
+
class DataLayoutOptimizer {
public:
- explicit DataLayoutOptimizer(GraphDef* graph)
- : graph_(graph), node_map_(graph_) {}
+ explicit DataLayoutOptimizer(GraphDef* graph, TuningConfig config)
+ : graph_(graph), node_map_(graph_), config_(config) {}
Status Optimize() {
LOG(INFO) << "Number of nodes for original graph: " << graph_->node_size();
@@ -908,12 +1004,15 @@ class DataLayoutOptimizer {
} else if (node->op().compare("BiasAddGrad") == 0) {
node_processor.reset(
new BiasAddGradProcessor(graph_, node, &node_map_));
- } else if (node->op().compare("Conv2DBackpropFilter") == 0) {
+ } else if (node->op().compare("Conv2D") == 0) {
node_processor.reset(
- new Conv2DBackpropFilterProcessor(graph_, node, &node_map_));
+ new Conv2DProcessor(graph_, node, &node_map_, config_.no_gemm));
+ } else if (node->op().compare("Conv2DBackpropFilter") == 0) {
+ node_processor.reset(new Conv2DBackpropFilterProcessor(
+ graph_, node, &node_map_, config_.no_gemm));
} else if (node->op().compare("Conv2DBackpropInput") == 0) {
- node_processor.reset(
- new Conv2DBackpropInputProcessor(graph_, node, &node_map_));
+ node_processor.reset(new Conv2DBackpropInputProcessor(
+ graph_, node, &node_map_, config_.no_gemm));
} else if (node->op().compare("FusedBatchNormGrad") == 0) {
node_processor.reset(
new FusedBatchNormGradProcessor(graph_, node, &node_map_));
@@ -1025,17 +1124,46 @@ class DataLayoutOptimizer {
GraphDef* graph_;
NodeMap node_map_;
+ TuningConfig config_;
};
+int GetNumTranspose(const GraphDef& graph) {
+ int number = 0;
+ for (const auto& node : graph.node()) {
+ if (IsTranspose(node)) {
+ number++;
+ }
+ }
+ LOG(INFO) << "Number of Transpose nodes: " << number;
+ return number;
+}
+
Status LayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* output) {
- if (GetNumAvailableGPUs() < 1) {
+ if (num_gpus_ == 0) {
+ num_gpus_ = GetNumAvailableGPUs();
+ }
+ if (num_gpus_ < 1) {
// LayoutOptimizer is currently only tuned for GPU.
return Status::OK();
}
+
*output = item.graph;
- DataLayoutOptimizer layout_optimizer(output);
+ TuningConfig config;
+ config.no_gemm = false;
+ DataLayoutOptimizer layout_optimizer(output, config);
auto status = layout_optimizer.Optimize();
+
+ // This is based on an empirical observation that if the introduced Transpose
+ // nodes is more than 30, not using GEMM implementation would result in better
+ // performance.
+ if (status.ok() && GetNumTranspose(*output) > 30) {
+ *output = item.graph;
+ config.no_gemm = true;
+ DataLayoutOptimizer layout_optimizer(output, config);
+ status = layout_optimizer.Optimize();
+ }
+
if (!status.ok()) {
*output = item.graph;
}
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.h b/tensorflow/core/grappler/optimizers/layout_optimizer.h
index 66dec17a35..1bd6f9544b 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer.h
@@ -29,11 +29,17 @@ class LayoutOptimizer : public GraphOptimizer {
string name() const override { return "layout"; };
+ // This is for testing only.
+ void set_num_gpus(int num_gpus) { num_gpus_ = num_gpus; };
+
Status Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* output) override;
void Feedback(Cluster* cluster, const GrapplerItem& item,
const GraphDef& optimize_output, double result) override;
+
+ private:
+ int num_gpus_ = 0;
};
} // end namespace grappler
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc
new file mode 100644
index 0000000000..be38ca1a69
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc
@@ -0,0 +1,147 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+void AddOutputShape(Node* node, const TensorShape& shape) {
+ std::vector<TensorShapeProto> output_shapes;
+ TensorShapeProto shape_proto;
+ shape.AsProto(&shape_proto);
+ output_shapes.push_back(shape_proto);
+ node->AddAttr("_output_shapes", output_shapes);
+}
+
+class LayoutOptimizerTest : public ::testing::Test {
+ protected:
+ Output SimpleConv(tensorflow::Scope* s, int input_size, int filter_size,
+ const string& padding) {
+ int batch_size = 128;
+ int input_height = input_size;
+ int input_width = input_size;
+ int input_depth = 3;
+ int filter_count = 2;
+ int stride = 1;
+ TensorShape input_shape(
+ {batch_size, input_height, input_width, input_depth});
+ Tensor input_data(DT_FLOAT, input_shape);
+ test::FillIota<float>(&input_data, 1.0f);
+ Output input =
+ ops::Const(s->WithOpName("Input"), Input::Initializer(input_data));
+ AddOutputShape(input.node(), input_shape);
+
+ TensorShape filter_shape(
+ {filter_size, filter_size, input_depth, filter_count});
+ Tensor filter_data(DT_FLOAT, filter_shape);
+ test::FillIota<float>(&filter_data, 1.0f);
+ Output filter =
+ ops::Const(s->WithOpName("Filter"), Input::Initializer(filter_data));
+ AddOutputShape(filter.node(), filter_shape);
+
+ Output conv = ops::Conv2D(s->WithOpName("Conv2D"), input, filter,
+ {1, stride, stride, 1}, padding);
+ AddOutputShape(conv.node(), input_shape);
+ return conv;
+ }
+};
+
+TEST_F(LayoutOptimizerTest, FilterSizeIsOne) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto conv = SimpleConv(&s, 2, 1, "SAME");
+ Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ LayoutOptimizer optimizer;
+ optimizer.set_num_gpus(1);
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ NodeMap node_map(&output);
+ EXPECT_FALSE(
+ node_map.GetNode("LayoutOptimizerTransposeNHWCToNCHW-Conv2D-Input"));
+}
+
+TEST_F(LayoutOptimizerTest, FilterSizeNotOne) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto conv = SimpleConv(&s, 2, 1, "SAME");
+ Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ LayoutOptimizer optimizer;
+ optimizer.set_num_gpus(1);
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ NodeMap node_map(&output);
+ EXPECT_FALSE(
+ node_map.GetNode("LayoutOptimizerTransposeNHWCToNCHW-Conv2D-Input"));
+}
+
+TEST_F(LayoutOptimizerTest, EqualSizeWithValidPadding) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto conv = SimpleConv(&s, 2, 2, "VALID");
+ Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ LayoutOptimizer optimizer;
+ optimizer.set_num_gpus(1);
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ NodeMap node_map(&output);
+ EXPECT_FALSE(
+ node_map.GetNode("LayoutOptimizerTransposeNHWCToNCHW-Conv2D-Input"));
+}
+
+TEST_F(LayoutOptimizerTest, EqualSizeWithSamePadding) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto conv = SimpleConv(&s, 2, 2, "SAME");
+ Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ LayoutOptimizer optimizer;
+ optimizer.set_num_gpus(1);
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ NodeMap node_map(&output);
+ EXPECT_TRUE(
+ node_map.GetNode("LayoutOptimizerTransposeNHWCToNCHW-Conv2D-Input"));
+}
+
+TEST_F(LayoutOptimizerTest, NotEqualSizeWithValidPadding) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto conv = SimpleConv(&s, 2, 3, "VALID");
+ Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ LayoutOptimizer optimizer;
+ optimizer.set_num_gpus(1);
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ NodeMap node_map(&output);
+ EXPECT_TRUE(
+ node_map.GetNode("LayoutOptimizerTransposeNHWCToNCHW-Conv2D-Input"));
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow