diff options
Diffstat (limited to 'tensorflow/compiler/xla/tests/bfloat16_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/tests/bfloat16_test.cc | 26 |
1 files changed, 10 insertions, 16 deletions
diff --git a/tensorflow/compiler/xla/tests/bfloat16_test.cc b/tensorflow/compiler/xla/tests/bfloat16_test.cc index 65589b0d6a..e9728e636f 100644 --- a/tensorflow/compiler/xla/tests/bfloat16_test.cc +++ b/tensorflow/compiler/xla/tests/bfloat16_test.cc @@ -95,22 +95,19 @@ XLA_TEST_F(Bfloat16Test, BatchNormTraining) { BatchNormTraining(operand, scale, offset, /*epsilon=*/0.001, kFeatureIndex); - auto expected = LiteralUtil::MakeTuple( + auto expected = LiteralUtil::MakeTupleFromSlices( {LiteralUtil::CreateR4<bfloat16>( {{{{static_cast<bfloat16>(-1.6875f)}, {static_cast<bfloat16>(-2.04f)}}, {{static_cast<bfloat16>(0.105f)}, {static_cast<bfloat16>(0.66f)}}}, {{{static_cast<bfloat16>(1.89f)}, {static_cast<bfloat16>(3.35f)}}, - {{static_cast<bfloat16>(3.7f)}, {static_cast<bfloat16>(6.04f)}}}}) - .get(), + {{static_cast<bfloat16>(3.7f)}, {static_cast<bfloat16>(6.04f)}}}}), LiteralUtil::CreateR1<bfloat16>( - {static_cast<bfloat16>(4), static_cast<bfloat16>(5)}) - .get(), + {static_cast<bfloat16>(4), static_cast<bfloat16>(5)}), LiteralUtil::CreateR1<bfloat16>( - {static_cast<bfloat16>(5), static_cast<bfloat16>(5)}) - .get()}); + {static_cast<bfloat16>(5), static_cast<bfloat16>(5)})}); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.01, 0.02)); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01, 0.02)); } XLA_TEST_F(Bfloat16Test, BatchNormGrad) { @@ -139,21 +136,18 @@ XLA_TEST_F(Bfloat16Test, BatchNormGrad) { BatchNormGrad(operand, scale, mean, var, grad_output, /*epsilon=*/0.0, kFeatureIndex); - auto expected = LiteralUtil::MakeTuple( + auto expected = LiteralUtil::MakeTupleFromSlices( {LiteralUtil::CreateR4<bfloat16>( {{{{static_cast<bfloat16>(-3.f)}, {static_cast<bfloat16>(-3.f)}}, {{static_cast<bfloat16>(-1.f)}, {static_cast<bfloat16>(-1.f)}}}, {{{static_cast<bfloat16>(1.f)}, {static_cast<bfloat16>(1.f)}}, - {{static_cast<bfloat16>(3.f)}, {static_cast<bfloat16>(3.f)}}}}) - .get(), + {{static_cast<bfloat16>(3.f)}, {static_cast<bfloat16>(3.f)}}}}), LiteralUtil::CreateR1<bfloat16>( - {static_cast<bfloat16>(0), static_cast<bfloat16>(0)}) - .get(), + {static_cast<bfloat16>(0), static_cast<bfloat16>(0)}), LiteralUtil::CreateR1<bfloat16>( - {static_cast<bfloat16>(16), static_cast<bfloat16>(20)}) - .get()}); + {static_cast<bfloat16>(16), static_cast<bfloat16>(20)})}); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.01)); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01)); } } // namespace |