Diffstat (limited to 'tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc')
-rw-r--r--  tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc  18
1 file changed, 11 insertions(+), 7 deletions(-)
diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc
index 6363a21c3b..5f93740887 100644
--- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
@@ -65,8 +65,12 @@ class TestBFloat16Support : public BFloat16Support {
}
};

-class BFloat16ConversionFoldingTest : public HloTestBase {
+class BFloat16ConversionFoldingTest : public HloVerifiedTestBase {
protected:
+ BFloat16ConversionFoldingTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/true) {}
+
bool FoldConversions(HloModule* module) {
TestBFloat16Support bfloat16_support_;
BFloat16ConversionFolding fold(&bfloat16_support_);
@@ -102,7 +106,7 @@ TEST_F(BFloat16ConversionFoldingTest, FoldIfSupported) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());

- EXPECT_TRUE(FoldConversions(module.get()));
+ EXPECT_TRUE(FoldConversions(module));

EXPECT_EQ(computation->root_instruction(), add1);
EXPECT_EQ(add0->shape().element_type(), BF16);
@@ -137,7 +141,7 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldIfUnsupported) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());

- EXPECT_FALSE(FoldConversions(module.get()));
+ EXPECT_FALSE(FoldConversions(module));

EXPECT_EQ(computation->root_instruction(), convert2);
EXPECT_EQ(mul0->shape().element_type(), F32);
@@ -172,7 +176,7 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldUnsupportedMixedPrecision) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());

- EXPECT_FALSE(FoldConversions(module.get()));
+ EXPECT_FALSE(FoldConversions(module));

EXPECT_EQ(computation->root_instruction(), convert2);
EXPECT_EQ(sub0->shape().element_type(), F32);
@@ -202,7 +206,7 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldTuple) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());

- EXPECT_FALSE(FoldConversions(module.get()));
+ EXPECT_FALSE(FoldConversions(module));

EXPECT_EQ(computation->root_instruction(), convert1);
EXPECT_EQ(gte->shape().element_type(), F32);
@@ -248,7 +252,7 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) {

auto computation = module->AddEntryComputation(builder.Build());

- EXPECT_TRUE(FoldConversions(module.get()));
+ EXPECT_TRUE(FoldConversions(module));

EXPECT_EQ(computation->root_instruction(), tuple);
EXPECT_EQ(tuple->operand(0), gte_a);
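
Note: taken together, the hunks swap the fixture's base class from HloTestBase to HloVerifiedTestBase, which verifies every module a test builds; because BFloat16ConversionFolding deliberately leaves mixed BF16/F32 HLO behind, the fixture passes allow_mixed_precision=true so the verifier accepts that output. The module.get() -> module changes fall out of the same swap. Below is a minimal sketch of the resulting usage, assuming (as the hunks suggest) that HloVerifiedTestBase::CreateNewModule() returns a fixture-owned raw HloModule* rather than a std::unique_ptr<HloModule>; the class and test names here are hypothetical, not part of the patch:

    // Sketch only, not part of the patch above.
    class ExampleFoldingTest : public HloVerifiedTestBase {
     protected:
      // Non-layout-sensitive verification that tolerates mixed BF16/F32.
      ExampleFoldingTest()
          : HloVerifiedTestBase(/*layout_sensitive=*/false,
                                /*allow_mixed_precision=*/true) {}
    };

    TEST_F(ExampleFoldingTest, RawModulePointer) {
      // Assumption: the module is owned by the fixture, so no .get() is
      // needed, and the verifier re-checks it when the test tears down.
      HloModule* module = CreateNewModule();
      TestBFloat16Support bfloat16_support;
      BFloat16ConversionFolding fold(&bfloat16_support);
      EXPECT_IS_OK(fold.Run(module).status());
    }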