diff options
Diffstat (limited to 'tensorflow/core/grappler/optimizers/layout_optimizer_test.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/layout_optimizer_test.cc | 83 |
1 file changed, 76 insertions, 7 deletions
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc index be38ca1a69..566ea1d87d 100644 --- a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc @@ -36,8 +36,8 @@ void AddOutputShape(Node* node, const TensorShape& shape) { class LayoutOptimizerTest : public ::testing::Test { protected: - Output SimpleConv(tensorflow::Scope* s, int input_size, int filter_size, - const string& padding) { + Output SimpleConv2D(tensorflow::Scope* s, int input_size, int filter_size, + const string& padding) { int batch_size = 128; int input_height = input_size; int input_width = input_size; @@ -65,11 +65,80 @@ class LayoutOptimizerTest : public ::testing::Test { AddOutputShape(conv.node(), input_shape); return conv; } + + Output SimpleConv2DBackpropInput(tensorflow::Scope* s, int input_size, + int filter_size, const string& padding) { + int batch_size = 128; + int input_height = input_size; + int input_width = input_size; + int input_depth = 3; + int filter_count = 2; + int stride = 1; + TensorShape input_sizes_shape({4}); + Tensor input_data(DT_INT32, input_sizes_shape); + test::FillValues<int>(&input_data, + {batch_size, input_height, input_width, input_depth}); + Output input_sizes = + ops::Const(s->WithOpName("InputSizes"), Input::Initializer(input_data)); + AddOutputShape(input_sizes.node(), input_sizes_shape); + + TensorShape filter_shape( + {filter_size, filter_size, input_depth, filter_count}); + Tensor filter_data(DT_FLOAT, filter_shape); + test::FillIota<float>(&filter_data, 1.0f); + Output filter = + ops::Const(s->WithOpName("Filter"), Input::Initializer(filter_data)); + AddOutputShape(filter.node(), filter_shape); + + int output_height = input_height; + int output_width = input_width; + TensorShape output_shape( + {batch_size, output_height, output_width, filter_count}); + Tensor output_data(DT_FLOAT, output_shape); 
+ test::FillIota<float>(&output_data, 1.0f); + Output output = + ops::Const(s->WithOpName("Output"), Input::Initializer(output_data)); + AddOutputShape(output.node(), output_shape); + + Output conv_backprop_input = ops::Conv2DBackpropInput( + s->WithOpName("Conv2DBackpropInput"), input_sizes, filter, output, + {1, stride, stride, 1}, padding); + TensorShape input_shape( + {batch_size, input_height, input_width, input_depth}); + AddOutputShape(conv_backprop_input.node(), input_shape); + return conv_backprop_input; + } + + Tensor GetAttrValue(const NodeDef& node) { + Tensor tensor; + CHECK(tensor.FromProto(node.attr().at({"value"}).tensor())); + return tensor; + } }; +TEST_F(LayoutOptimizerTest, Conv2DBackpropInput) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto conv = SimpleConv2DBackpropInput(&s, 7, 2, "SAME"); + Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv}); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + LayoutOptimizer optimizer; + optimizer.set_num_gpus(1); + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + NodeMap node_map(&output); + auto input_sizes_node = node_map.GetNode( + AddPrefixToNodeName("InputSizes", "LayoutOptimizer", "-")); + CHECK(input_sizes_node); + auto input_sizes = GetAttrValue(*input_sizes_node); + Tensor input_sizes_expected(DT_INT32, {4}); + test::FillValues<int>(&input_sizes_expected, {128, 3, 7, 7}); + test::ExpectTensorEqual<int>(input_sizes_expected, input_sizes); +} + TEST_F(LayoutOptimizerTest, FilterSizeIsOne) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - auto conv = SimpleConv(&s, 2, 1, "SAME"); + auto conv = SimpleConv2D(&s, 2, 1, "SAME"); Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv}); GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); @@ -84,7 +153,7 @@ TEST_F(LayoutOptimizerTest, FilterSizeIsOne) { TEST_F(LayoutOptimizerTest, FilterSizeNotOne) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - auto 
conv = SimpleConv(&s, 2, 1, "SAME"); + auto conv = SimpleConv2D(&s, 2, 1, "SAME"); Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv}); GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); @@ -99,7 +168,7 @@ TEST_F(LayoutOptimizerTest, FilterSizeNotOne) { TEST_F(LayoutOptimizerTest, EqualSizeWithValidPadding) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - auto conv = SimpleConv(&s, 2, 2, "VALID"); + auto conv = SimpleConv2D(&s, 2, 2, "VALID"); Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv}); GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); @@ -114,7 +183,7 @@ TEST_F(LayoutOptimizerTest, EqualSizeWithValidPadding) { TEST_F(LayoutOptimizerTest, EqualSizeWithSamePadding) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - auto conv = SimpleConv(&s, 2, 2, "SAME"); + auto conv = SimpleConv2D(&s, 2, 2, "SAME"); Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv}); GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); @@ -129,7 +198,7 @@ TEST_F(LayoutOptimizerTest, EqualSizeWithSamePadding) { TEST_F(LayoutOptimizerTest, NotEqualSizeWithValidPadding) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - auto conv = SimpleConv(&s, 2, 3, "VALID"); + auto conv = SimpleConv2D(&s, 2, 3, "VALID"); Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv}); GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); |