aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/shape_inference_test.cc
diff options
context:
space:
mode:
authorGravatar David Majnemer <majnemer@google.com>2017-11-27 22:31:25 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-27 22:34:31 -0800
commit102bfdfd830f4dab6e00371e63a82561e1246518 (patch)
tree8dd5143e0a86adfaac353a3a24824a7941c04a13 /tensorflow/compiler/xla/service/shape_inference_test.cc
parent8781d69b2e619e64555cb00b13783a7eee524b81 (diff)
[XLA] Separate input and output spatial dimensions for convolution
This lets us reason about input spatial dimensions as distinct from output spatial dimensions. By doing this, it opens up more opportunities for assigning more interesting, different, layouts for the activations and the output. PiperOrigin-RevId: 177117140
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/shape_inference_test.cc24
1 files changed, 16 insertions, 8 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index d12f7bd145..be93c879c0 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -395,8 +395,10 @@ TEST_F(ShapeInferenceTest, Convolve) {
dnums.set_output_batch_dimension(0);
dnums.set_input_feature_dimension(1);
dnums.set_output_feature_dimension(1);
- dnums.add_spatial_dimensions(2);
- dnums.add_spatial_dimensions(3);
+ dnums.add_input_spatial_dimensions(2);
+ dnums.add_output_spatial_dimensions(2);
+ dnums.add_input_spatial_dimensions(3);
+ dnums.add_output_spatial_dimensions(3);
// Dimension order: x1, batch, feature, x0
Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 3});
@@ -437,8 +439,10 @@ TEST_F(ShapeInferenceTest, ConvolveWithWindowDilation) {
dnums.set_output_batch_dimension(0);
dnums.set_input_feature_dimension(1);
dnums.set_output_feature_dimension(1);
- dnums.add_spatial_dimensions(2);
- dnums.add_spatial_dimensions(3);
+ dnums.add_input_spatial_dimensions(2);
+ dnums.add_output_spatial_dimensions(2);
+ dnums.add_input_spatial_dimensions(3);
+ dnums.add_output_spatial_dimensions(3);
// Dimension order: x1, batch, feature, x0
Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 3});
@@ -480,8 +484,10 @@ TEST_F(ShapeInferenceTest, ConvolveWithBaseDilation) {
dnums.set_output_batch_dimension(0);
dnums.set_input_feature_dimension(1);
dnums.set_output_feature_dimension(1);
- dnums.add_spatial_dimensions(2);
- dnums.add_spatial_dimensions(3);
+ dnums.add_input_spatial_dimensions(2);
+ dnums.add_output_spatial_dimensions(2);
+ dnums.add_input_spatial_dimensions(3);
+ dnums.add_output_spatial_dimensions(3);
// Dimension order: x1, batch, feature, x0
Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 4});
@@ -524,8 +530,10 @@ TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) {
dnums.set_output_batch_dimension(3);
dnums.set_input_feature_dimension(2);
dnums.set_output_feature_dimension(2);
- dnums.add_spatial_dimensions(0);
- dnums.add_spatial_dimensions(1);
+ dnums.add_input_spatial_dimensions(0);
+ dnums.add_output_spatial_dimensions(0);
+ dnums.add_input_spatial_dimensions(1);
+ dnums.add_output_spatial_dimensions(1);
dnums.set_kernel_input_feature_dimension(0); // duplicated with kernel_x0
dnums.set_kernel_output_feature_dimension(3);
dnums.add_kernel_spatial_dimensions(0);