diff options
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r-- | tensorflow/compiler/xla/tests/convolution_test.cc | 9 |
1 files changed, 8 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index 070b092d18..b851db14ec 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -91,7 +91,14 @@ class ForwardPassConvolution_3x3x256_256_OutputZ_Iota : public ConvolutionTest { XlaBuilder builder(TestName()); auto lhs = ConstantR4FromArray4D<T>(&builder, *alhs); auto rhs = ConstantR4FromArray4D<T>(&builder, *arhs); - Conv(lhs, rhs, {1, 1}, Padding::kValid); + PrecisionConfig precision; + // The left hand side of the convolution is numbers between 0 and 2304 which + // requires at least 11 mantissa bits and the DEFAULT precision config is + // allowed to round to bfloat16 which only has 7 mantissa bits. + precision.add_operand_precision(PrecisionConfig::HIGHEST); + precision.add_operand_precision(PrecisionConfig::DEFAULT); + Conv(lhs, rhs, {1, 1}, Padding::kValid, /*feature_group_count=*/1, + &precision); ComputeAndCompare(&builder, {}, error_spec_); } |