aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--tensorflow/compiler/xla/tests/convolution_test.cc9
1 files changed, 8 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc
index 070b092d18..b851db14ec 100644
--- a/tensorflow/compiler/xla/tests/convolution_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_test.cc
@@ -91,7 +91,14 @@ class ForwardPassConvolution_3x3x256_256_OutputZ_Iota : public ConvolutionTest {
XlaBuilder builder(TestName());
auto lhs = ConstantR4FromArray4D<T>(&builder, *alhs);
auto rhs = ConstantR4FromArray4D<T>(&builder, *arhs);
- Conv(lhs, rhs, {1, 1}, Padding::kValid);
+ PrecisionConfig precision;
+ // The left hand side of the convolution is numbers between 0 and 2304 which
+ // requires at least 11 mantissa bits and the DEFAULT precision config is
+ // allowed to round to bfloat16 which only has 7 mantissa bits.
+ precision.add_operand_precision(PrecisionConfig::HIGHEST);
+ precision.add_operand_precision(PrecisionConfig::DEFAULT);
+ Conv(lhs, rhs, {1, 1}, Padding::kValid, /*feature_group_count=*/1,
+ &precision);
ComputeAndCompare(&builder, {}, error_spec_);
}