aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/bfloat16_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/tests/bfloat16_test.cc')
-rw-r--r--tensorflow/compiler/xla/tests/bfloat16_test.cc26
1 files changed, 10 insertions, 16 deletions
diff --git a/tensorflow/compiler/xla/tests/bfloat16_test.cc b/tensorflow/compiler/xla/tests/bfloat16_test.cc
index 65589b0d6a..e9728e636f 100644
--- a/tensorflow/compiler/xla/tests/bfloat16_test.cc
+++ b/tensorflow/compiler/xla/tests/bfloat16_test.cc
@@ -95,22 +95,19 @@ XLA_TEST_F(Bfloat16Test, BatchNormTraining) {
BatchNormTraining(operand, scale, offset, /*epsilon=*/0.001, kFeatureIndex);
- auto expected = LiteralUtil::MakeTuple(
+ auto expected = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR4<bfloat16>(
{{{{static_cast<bfloat16>(-1.6875f)},
{static_cast<bfloat16>(-2.04f)}},
{{static_cast<bfloat16>(0.105f)}, {static_cast<bfloat16>(0.66f)}}},
{{{static_cast<bfloat16>(1.89f)}, {static_cast<bfloat16>(3.35f)}},
- {{static_cast<bfloat16>(3.7f)}, {static_cast<bfloat16>(6.04f)}}}})
- .get(),
+ {{static_cast<bfloat16>(3.7f)}, {static_cast<bfloat16>(6.04f)}}}}),
LiteralUtil::CreateR1<bfloat16>(
- {static_cast<bfloat16>(4), static_cast<bfloat16>(5)})
- .get(),
+ {static_cast<bfloat16>(4), static_cast<bfloat16>(5)}),
LiteralUtil::CreateR1<bfloat16>(
- {static_cast<bfloat16>(5), static_cast<bfloat16>(5)})
- .get()});
+ {static_cast<bfloat16>(5), static_cast<bfloat16>(5)})});
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.01, 0.02));
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01, 0.02));
}
XLA_TEST_F(Bfloat16Test, BatchNormGrad) {
@@ -139,21 +136,18 @@ XLA_TEST_F(Bfloat16Test, BatchNormGrad) {
BatchNormGrad(operand, scale, mean, var, grad_output,
/*epsilon=*/0.0, kFeatureIndex);
- auto expected = LiteralUtil::MakeTuple(
+ auto expected = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR4<bfloat16>(
{{{{static_cast<bfloat16>(-3.f)}, {static_cast<bfloat16>(-3.f)}},
{{static_cast<bfloat16>(-1.f)}, {static_cast<bfloat16>(-1.f)}}},
{{{static_cast<bfloat16>(1.f)}, {static_cast<bfloat16>(1.f)}},
- {{static_cast<bfloat16>(3.f)}, {static_cast<bfloat16>(3.f)}}}})
- .get(),
+ {{static_cast<bfloat16>(3.f)}, {static_cast<bfloat16>(3.f)}}}}),
LiteralUtil::CreateR1<bfloat16>(
- {static_cast<bfloat16>(0), static_cast<bfloat16>(0)})
- .get(),
+ {static_cast<bfloat16>(0), static_cast<bfloat16>(0)}),
LiteralUtil::CreateR1<bfloat16>(
- {static_cast<bfloat16>(16), static_cast<bfloat16>(20)})
- .get()});
+ {static_cast<bfloat16>(16), static_cast<bfloat16>(20)})});
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.01));
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01));
}
} // namespace