aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/shape_inference_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/shape_inference_test.cc16
1 files changed, 8 insertions, 8 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index cc92e58ef8..864ed43118 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -419,8 +419,8 @@ TEST_F(ShapeInferenceTest, Convolve) {
dim1->set_padding_high(0);
dim1->set_window_dilation(1);
dim1->set_base_dilation(1);
- auto inferred_status =
- ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums);
+ auto inferred_status = ShapeInference::InferConvolveShape(
+ lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums);
ASSERT_IS_OK(inferred_status.status());
Shape inferred_shape = inferred_status.ValueOrDie();
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 2, 3}),
@@ -464,8 +464,8 @@ TEST_F(ShapeInferenceTest, ConvolveWithWindowDilation) {
dim1->set_padding_high(1);
dim1->set_window_dilation(2);
dim1->set_base_dilation(1);
- auto inferred_status =
- ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums);
+ auto inferred_status = ShapeInference::InferConvolveShape(
+ lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums);
ASSERT_IS_OK(inferred_status.status());
Shape inferred_shape = inferred_status.ValueOrDie();
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 31, 5}),
@@ -509,8 +509,8 @@ TEST_F(ShapeInferenceTest, ConvolveWithBaseDilation) {
dim1->set_padding_high(1);
dim1->set_window_dilation(1);
dim1->set_base_dilation(2);
- auto inferred_status =
- ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums);
+ auto inferred_status = ShapeInference::InferConvolveShape(
+ lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums);
ASSERT_IS_OK(inferred_status.status());
Shape inferred_shape = inferred_status.ValueOrDie();
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 4, 9}),
@@ -547,8 +547,8 @@ TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) {
dim1->set_stride(2);
dim1->set_padding_low(1);
dim1->set_padding_high(1);
- auto inferred_status =
- ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums);
+ auto inferred_status = ShapeInference::InferConvolveShape(
+ lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums);
ASSERT_FALSE(inferred_status.ok());
ASSERT_THAT(inferred_status.status().error_message(),
HasSubstr("each dimension exactly once"));