aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/optimizers/layout_optimizer_test.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/layout_optimizer_test.cc83
1 files changed, 76 insertions, 7 deletions
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc
index be38ca1a69..566ea1d87d 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc
@@ -36,8 +36,8 @@ void AddOutputShape(Node* node, const TensorShape& shape) {
class LayoutOptimizerTest : public ::testing::Test {
protected:
- Output SimpleConv(tensorflow::Scope* s, int input_size, int filter_size,
- const string& padding) {
+ Output SimpleConv2D(tensorflow::Scope* s, int input_size, int filter_size,
+ const string& padding) {
int batch_size = 128;
int input_height = input_size;
int input_width = input_size;
@@ -65,11 +65,80 @@ class LayoutOptimizerTest : public ::testing::Test {
AddOutputShape(conv.node(), input_shape);
return conv;
}
+
+ Output SimpleConv2DBackpropInput(tensorflow::Scope* s, int input_size,
+ int filter_size, const string& padding) {
+ int batch_size = 128;
+ int input_height = input_size;
+ int input_width = input_size;
+ int input_depth = 3;
+ int filter_count = 2;
+ int stride = 1;
+ TensorShape input_sizes_shape({4});
+ Tensor input_data(DT_INT32, input_sizes_shape);
+ test::FillValues<int>(&input_data,
+ {batch_size, input_height, input_width, input_depth});
+ Output input_sizes =
+ ops::Const(s->WithOpName("InputSizes"), Input::Initializer(input_data));
+ AddOutputShape(input_sizes.node(), input_sizes_shape);
+
+ TensorShape filter_shape(
+ {filter_size, filter_size, input_depth, filter_count});
+ Tensor filter_data(DT_FLOAT, filter_shape);
+ test::FillIota<float>(&filter_data, 1.0f);
+ Output filter =
+ ops::Const(s->WithOpName("Filter"), Input::Initializer(filter_data));
+ AddOutputShape(filter.node(), filter_shape);
+
+ int output_height = input_height;
+ int output_width = input_width;
+ TensorShape output_shape(
+ {batch_size, output_height, output_width, filter_count});
+ Tensor output_data(DT_FLOAT, output_shape);
+ test::FillIota<float>(&output_data, 1.0f);
+ Output output =
+ ops::Const(s->WithOpName("Output"), Input::Initializer(output_data));
+ AddOutputShape(output.node(), output_shape);
+
+ Output conv_backprop_input = ops::Conv2DBackpropInput(
+ s->WithOpName("Conv2DBackpropInput"), input_sizes, filter, output,
+ {1, stride, stride, 1}, padding);
+ TensorShape input_shape(
+ {batch_size, input_height, input_width, input_depth});
+ AddOutputShape(conv_backprop_input.node(), input_shape);
+ return conv_backprop_input;
+ }
+
+ Tensor GetAttrValue(const NodeDef& node) {
+ Tensor tensor;
+ CHECK(tensor.FromProto(node.attr().at({"value"}).tensor()));
+ return tensor;
+ }
};
+TEST_F(LayoutOptimizerTest, Conv2DBackpropInput) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto conv = SimpleConv2DBackpropInput(&s, 7, 2, "SAME");
+ Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ LayoutOptimizer optimizer;
+ optimizer.set_num_gpus(1);
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ NodeMap node_map(&output);
+ auto input_sizes_node = node_map.GetNode(
+ AddPrefixToNodeName("InputSizes", "LayoutOptimizer", "-"));
+ CHECK(input_sizes_node);
+ auto input_sizes = GetAttrValue(*input_sizes_node);
+ Tensor input_sizes_expected(DT_INT32, {4});
+ test::FillValues<int>(&input_sizes_expected, {128, 3, 7, 7});
+ test::ExpectTensorEqual<int>(input_sizes_expected, input_sizes);
+}
+
TEST_F(LayoutOptimizerTest, FilterSizeIsOne) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv = SimpleConv(&s, 2, 1, "SAME");
+ auto conv = SimpleConv2D(&s, 2, 1, "SAME");
Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
@@ -84,7 +153,7 @@ TEST_F(LayoutOptimizerTest, FilterSizeIsOne) {
TEST_F(LayoutOptimizerTest, FilterSizeNotOne) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv = SimpleConv(&s, 2, 1, "SAME");
+ auto conv = SimpleConv2D(&s, 2, 1, "SAME");
Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
@@ -99,7 +168,7 @@ TEST_F(LayoutOptimizerTest, FilterSizeNotOne) {
TEST_F(LayoutOptimizerTest, EqualSizeWithValidPadding) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv = SimpleConv(&s, 2, 2, "VALID");
+ auto conv = SimpleConv2D(&s, 2, 2, "VALID");
Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
@@ -114,7 +183,7 @@ TEST_F(LayoutOptimizerTest, EqualSizeWithValidPadding) {
TEST_F(LayoutOptimizerTest, EqualSizeWithSamePadding) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv = SimpleConv(&s, 2, 2, "SAME");
+ auto conv = SimpleConv2D(&s, 2, 2, "SAME");
Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
@@ -129,7 +198,7 @@ TEST_F(LayoutOptimizerTest, EqualSizeWithSamePadding) {
TEST_F(LayoutOptimizerTest, NotEqualSizeWithValidPadding) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv = SimpleConv(&s, 2, 3, "VALID");
+ auto conv = SimpleConv2D(&s, 2, 3, "VALID");
Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));