about summary refs log tree commit diff homepage
path: root/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/bfloat16_normalization_test.cc')
-rw-r--r-- tensorflow/compiler/xla/service/bfloat16_normalization_test.cc | 22
1 file changed, 13 insertions(+), 9 deletions(-)
diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
index 933cf873e0..cef0eba14e 100644
--- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
@@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
@@ -68,8 +68,12 @@ class TestBFloat16Support : public BFloat16Support {
}
};
-class BFloat16NormalizationTest : public HloTestBase {
+class BFloat16NormalizationTest : public HloVerifiedTestBase {
protected:
+ BFloat16NormalizationTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/true) {}
+
bool Normalize(HloModule* module) {
TestBFloat16Support bfloat16_support_;
BFloat16Normalization normalization(&bfloat16_support_);
@@ -105,7 +109,7 @@ TEST_F(BFloat16NormalizationTest, NoopIfSupported) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_FALSE(Normalize(module.get()));
+ EXPECT_FALSE(Normalize(module));
EXPECT_EQ(computation->root_instruction(), add1);
EXPECT_EQ(add0->shape().element_type(), BF16);
@@ -133,7 +137,7 @@ TEST_F(BFloat16NormalizationTest, ResolveIfUnsupportedBF16) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(Normalize(module.get()));
+ EXPECT_TRUE(Normalize(module));
EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert);
EXPECT_EQ(computation->root_instruction()->operand(0), mul1);
@@ -163,7 +167,7 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionSubtraction) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(Normalize(module.get()));
+ EXPECT_TRUE(Normalize(module));
EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert);
EXPECT_EQ(computation->root_instruction()->operand(0), sub1);
@@ -201,7 +205,7 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) {
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(Normalize(module.get()));
+ EXPECT_TRUE(Normalize(module));
EXPECT_EQ(computation->root_instruction(), reduce);
EXPECT_EQ(reduce->called_computations().size(), 1);
@@ -259,7 +263,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) {
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(Normalize(module.get()));
+ EXPECT_TRUE(Normalize(module));
EXPECT_EQ(computation->root_instruction(), gte);
EXPECT_EQ(gte->shape().element_type(), BF16);
@@ -286,7 +290,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) {
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(Normalize(module.get()));
+ EXPECT_TRUE(Normalize(module));
EXPECT_EQ(computation->root_instruction(), gte);
EXPECT_EQ(gte->shape().element_type(), BF16);
@@ -317,7 +321,7 @@ TEST_F(BFloat16NormalizationTest, DoNotAddUnsupportedMixedPrecision) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(Normalize(module.get()));
+ EXPECT_TRUE(Normalize(module));
EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert);
EXPECT_EQ(dot->shape().element_type(), F32);