aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service
diff options
context:
space:
mode:
authorGravatar Dimitris Vardoulakis <dimvar@google.com>2018-08-29 23:11:50 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-29 23:16:29 -0700
commit5452e5a49f99b9ea9f095d71a10e794b8d70fc04 (patch)
treed5a09da125dceb3ec802a38db754e06d06e7b042 /tensorflow/compiler/xla/service
parent786e469ed0b74e2175e9f4b3d1ac7531c65017b0 (diff)
Convert a couple more test files to HloVerifiedTestBase, and add default arguments to the constructor to remove some boilerplate.
PiperOrigin-RevId: 210855509
Diffstat (limited to 'tensorflow/compiler/xla/service')
-rw-r--r--tensorflow/compiler/xla/service/BUILD2
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier_test.cc21
-rw-r--r--tensorflow/compiler/xla/service/batch_dot_simplification_test.cc7
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment_test.cc7
-rw-r--r--tensorflow/compiler/xla/service/conditional_simplifier_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/defuser_test.cc5
-rw-r--r--tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc7
-rw-r--r--tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc24
-rw-r--r--tensorflow/compiler/xla/service/hlo_creation_utils_test.cc67
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_test.cc5
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_test.cc7
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc5
-rw-r--r--tensorflow/compiler/xla/service/indexed_array_analysis_test.cc5
-rw-r--r--tensorflow/compiler/xla/service/reshape_mover_test.cc7
-rw-r--r--tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/while_loop_simplifier_test.cc5
18 files changed, 58 insertions, 130 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index f8e0ed440d..cd8817ac16 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -1270,6 +1270,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
"@com_google_absl//absl/memory",
@@ -2114,6 +2115,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 917ed86b69..cbce98ef13 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -54,12 +54,7 @@ AlgebraicSimplifier::ValidBitcastCallback non_bitcasting_callback() {
return [](const Shape&, const Shape&) { return false; };
}
-class AlgebraicSimplifierTest : public HloVerifiedTestBase {
- public:
- AlgebraicSimplifierTest()
- : HloVerifiedTestBase(/*layout_sensitive=*/false,
- /*allow_mixed_precision=*/false) {}
-};
+class AlgebraicSimplifierTest : public HloVerifiedTestBase {};
// Test that A + 0 is simplified to A
TEST_F(AlgebraicSimplifierTest, AddZero) {
@@ -3013,12 +3008,7 @@ struct DotOfConcatTestSpec {
class DotOfConcatSimplificationTest
: public HloVerifiedTestBase,
- public ::testing::WithParamInterface<DotOfConcatTestSpec> {
- public:
- DotOfConcatSimplificationTest()
- : HloVerifiedTestBase(/*layout_sensitive=*/false,
- /*allow_mixed_precision=*/false) {}
-};
+ public ::testing::WithParamInterface<DotOfConcatTestSpec> {};
// Test that we transform
// dot(const, concat(A, B, C))
@@ -3191,12 +3181,7 @@ struct DotOfGatherTestSpec {
class DotOfGatherSimplificationTest
: public HloVerifiedTestBase,
- public ::testing::WithParamInterface<DotOfGatherTestSpec> {
- public:
- DotOfGatherSimplificationTest()
- : HloVerifiedTestBase(/*layout_sensitive=*/false,
- /*allow_mixed_precision=*/false) {}
-};
+ public ::testing::WithParamInterface<DotOfGatherTestSpec> {};
// input: dot(DS(ctA), ctB))
// where DS(ctA) = DS({M x K}, {s, 0}, {1, K}) and ctB = {K x N}.
diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc
index b342acb025..38f1a5d3a6 100644
--- a/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc
+++ b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc
@@ -24,12 +24,7 @@ namespace {
namespace op = xla::testing::opcode_matchers;
-class BatchDotSimplificationTest : public HloVerifiedTestBase {
- public:
- BatchDotSimplificationTest()
- : HloVerifiedTestBase(/*layout_sensitive=*/false,
- /*allow_mixed_precision=*/false) {}
-};
+class BatchDotSimplificationTest : public HloVerifiedTestBase {};
TEST_F(BatchDotSimplificationTest,
ElideSingleDegenerateBatchDotDim_VectorVector) {
diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
index 25e85d4747..e9751cc269 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
@@ -81,9 +81,6 @@ const std::vector<const HloInstruction*> GetInstructions(HloInstruction* root) {
class BufferAssignmentTest : public HloVerifiedTestBase {
protected:
- BufferAssignmentTest()
- : HloVerifiedTestBase(/*layout_sensitive=*/false,
- /*allow_mixed_precision=*/false) {}
~BufferAssignmentTest() override {}
std::unique_ptr<BufferAssignment> RunBufferAssignment(HloModule* module,
@@ -1753,10 +1750,6 @@ ENTRY main {
class WhileBufferAssignmentTest : public HloVerifiedTestBase {
protected:
- WhileBufferAssignmentTest()
- : HloVerifiedTestBase(/*layout_sensitive=*/false,
- /*allow_mixed_precision=*/false) {}
-
std::unique_ptr<HloComputation> BuildWhileConditionComputation(
const string& name) {
auto builder = HloComputation::Builder(name);
diff --git a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc
index 6c477da038..c43a31b167 100644
--- a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc
@@ -39,10 +39,6 @@ namespace op = xla::testing::opcode_matchers;
class ConditionalSimplifierTest : public HloVerifiedTestBase {
public:
- ConditionalSimplifierTest()
- : HloVerifiedTestBase(/*layout_sensitive=*/false,
- /*allow_mixed_precision=*/false) {}
-
// Makes a computation that contains a conditional with constant predicate.
HloComputation* MakeConditional(HloModule* module);
};
diff --git a/tensorflow/compiler/xla/service/defuser_test.cc b/tensorflow/compiler/xla/service/defuser_test.cc
index 37d1895d41..e727ba49cb 100644
--- a/tensorflow/compiler/xla/service/defuser_test.cc
+++ b/tensorflow/compiler/xla/service/defuser_test.cc
@@ -26,11 +26,6 @@ namespace xla {
namespace {
class DefuserTest : public HloVerifiedTestBase {
- public:
- DefuserTest()
- : HloVerifiedTestBase(/*layout_sensitive=*/false,
- /*allow_mixed_precision=*/false) {}
-
protected:
// Returns the number of fusion instructions in the module.
int FusionCount() {
diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc
index 104af48c82..5c92b0dcb8 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc
@@ -29,12 +29,7 @@ namespace {
namespace op = xla::testing::opcode_matchers;
using ::testing::_;
-class PadForTensorCoresTest : public HloVerifiedTestBase {
- public:
- PadForTensorCoresTest()
- : HloVerifiedTestBase(/*layout_sensitive=*/false,
- /*allow_mixed_precision=*/false) {}
-};
+class PadForTensorCoresTest : public HloVerifiedTestBase {};
TEST_F(PadForTensorCoresTest, PadF16ForwardConvInputChannels) {
ParseAndVerifyModule(R"(
diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc
index da94ab5346..54abe3345d 100644
--- a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc
@@ -28,7 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/logging.h"
@@ -39,15 +39,17 @@ namespace {
using ::testing::UnorderedElementsAre;
-class HloAliasAnalysisTest : public HloTestBase {
+class HloAliasAnalysisTest : public HloVerifiedTestBase {
protected:
- HloAliasAnalysisTest() : module_(CreateNewModule()) {}
+ HloAliasAnalysisTest() : HloVerifiedTestBase() {
+ module_ = CreateNewModule();
+ }
// Run alias analysis on the member module. For convenience returns a
// reference to the generated analysis stored in analysis_.
HloAliasAnalysis& RunAnalysis() {
hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before alias analysis");
- analysis_ = HloAliasAnalysis::Run(module_.get(),
+ analysis_ = HloAliasAnalysis::Run(module_,
/*fusion_can_share_buffer=*/nullptr)
.ConsumeValueOrDie();
return *analysis_;
@@ -91,7 +93,7 @@ class HloAliasAnalysisTest : public HloTestBase {
// never occurs, but HLO graphs with interference can be explicitly
// constructed.
bool AnyValuesInSameBufferInterfere() {
- DependencyHloOrdering ordering(module_.get());
+ DependencyHloOrdering ordering(module_);
for (const HloBuffer& buffer : analysis_->buffers()) {
for (const HloValue* value_a : buffer.values()) {
for (const HloValue* value_b : buffer.values()) {
@@ -108,7 +110,7 @@ class HloAliasAnalysisTest : public HloTestBase {
return false;
}
- std::unique_ptr<HloModule> module_;
+ HloModule* module_;
std::unique_ptr<HloAliasAnalysis> analysis_;
const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {});
@@ -461,7 +463,7 @@ TEST_F(HloAliasAnalysisTest, SequentialWhiles) {
module_->AddEntryComputation(builder.Build());
FlattenCallGraph flattener;
- TF_ASSERT_OK(flattener.Run(module_.get()).status());
+ TF_ASSERT_OK(flattener.Run(module_).status());
const HloAliasAnalysis& analysis = RunAnalysis();
@@ -835,7 +837,7 @@ TEST_F(HloAliasAnalysisTest, BitcastInterference) {
const HloAliasAnalysis& analysis = RunAnalysis();
- DependencyHloOrdering ordering(module_.get());
+ DependencyHloOrdering ordering(module_);
EXPECT_FALSE(analysis.HasLiveRangeInterference(ordering));
}
@@ -877,7 +879,7 @@ TEST_F(HloAliasAnalysisTest, WhileInterference) {
{
// Dependency ordering should interfere because the negate and while are
// unordered.
- DependencyHloOrdering ordering(module_.get());
+ DependencyHloOrdering ordering(module_);
EXPECT_TRUE(analysis.HasLiveRangeInterference(ordering));
}
@@ -888,13 +890,13 @@ TEST_F(HloAliasAnalysisTest, WhileInterference) {
sequence[condition] = {cond_param, cond_root};
{
sequence[entry] = {init, xla_while, negate, entry_root};
- SequentialHloOrdering ordering(module_.get(), sequence);
+ SequentialHloOrdering ordering(module_, sequence);
EXPECT_TRUE(analysis.HasLiveRangeInterference(ordering));
}
{
sequence[entry] = {init, negate, xla_while, entry_root};
- SequentialHloOrdering ordering(module_.get(), sequence);
+ SequentialHloOrdering ordering(module_, sequence);
EXPECT_FALSE(analysis.HasLiveRangeInterference(ordering));
}
}
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc
index a8de285d16..662f008205 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc
@@ -19,19 +19,20 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
namespace {
using tensorflow::gtl::ArraySlice;
-class HloCreationUtilsTest : public HloTestBase {
+class HloCreationUtilsTest : public HloVerifiedTestBase {
protected:
- std::unique_ptr<HloModule> CreateModuleWithProgramShape(
- PrimitiveType primitive_type, ArraySlice<int64> input_shape_dims,
- ArraySlice<int64> output_shape_dims, HloInstruction** param,
- HloComputation** entry_computation) {
+ HloModule* CreateModuleWithProgramShape(PrimitiveType primitive_type,
+ ArraySlice<int64> input_shape_dims,
+ ArraySlice<int64> output_shape_dims,
+ HloInstruction** param,
+ HloComputation** entry_computation) {
Shape input_shape = ShapeUtil::MakeShape(primitive_type, input_shape_dims);
Shape output_shape =
ShapeUtil::MakeShape(primitive_type, output_shape_dims);
@@ -48,10 +49,10 @@ TEST_F(HloCreationUtilsTest, CollapseFirst1Dim) {
HloInstruction* param;
HloComputation* entry_computation;
- std::unique_ptr<HloModule> module = CreateModuleWithProgramShape(
- S32,
- /*input_shape_dims=*/{2}, /*output_shape_dims=*/{2}, &param,
- &entry_computation);
+ HloModule* module = CreateModuleWithProgramShape(S32,
+ /*input_shape_dims=*/{2},
+ /*output_shape_dims=*/{2},
+ &param, &entry_computation);
TF_ASSERT_OK_AND_ASSIGN(HloInstruction * first_1_dims_collapsed,
CollapseFirstNDims(param, 1));
@@ -68,7 +69,7 @@ TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) {
HloInstruction* param;
HloComputation* entry_computation;
- std::unique_ptr<HloModule> module = CreateModuleWithProgramShape(
+ HloModule* module = CreateModuleWithProgramShape(
S32,
/*input_shape_dims=*/{2, 3, 2}, /*output_shape_dims=*/{6, 2}, &param,
&entry_computation);
@@ -93,10 +94,10 @@ TEST_F(HloCreationUtilsTest, Prepend1DegenerateDim) {
HloInstruction* param;
HloComputation* entry_computation;
- std::unique_ptr<HloModule> module = CreateModuleWithProgramShape(
- S32,
- /*input_shape_dims=*/{2}, /*output_shape_dims=*/{1, 2}, &param,
- &entry_computation);
+ HloModule* module = CreateModuleWithProgramShape(S32,
+ /*input_shape_dims=*/{2},
+ /*output_shape_dims=*/{1, 2},
+ &param, &entry_computation);
TF_ASSERT_OK_AND_ASSIGN(HloInstruction * with_1_degenerate_dim_prepended,
PrependDegenerateDims(param, 1));
@@ -114,7 +115,7 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) {
HloInstruction* param;
HloComputation* entry_computation;
- std::unique_ptr<HloModule> module = CreateModuleWithProgramShape(
+ HloModule* module = CreateModuleWithProgramShape(
S32,
/*input_shape_dims=*/{2}, /*output_shape_dims=*/{1, 1, 2}, &param,
&entry_computation);
@@ -135,10 +136,10 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) {
HloInstruction* param;
HloComputation* entry_computation;
- std::unique_ptr<HloModule> module = CreateModuleWithProgramShape(
- S32,
- /*input_shape_dims=*/{}, /*output_shape_dims=*/{1, 1}, &param,
- &entry_computation);
+ HloModule* module = CreateModuleWithProgramShape(S32,
+ /*input_shape_dims=*/{},
+ /*output_shape_dims=*/{1, 1},
+ &param, &entry_computation);
TF_ASSERT_OK_AND_ASSIGN(HloInstruction * with_2_degenerate_dims_prepended,
PrependDegenerateDims(param, 2));
@@ -155,7 +156,7 @@ TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) {
HloInstruction* param;
HloComputation* entry_computation;
- std::unique_ptr<HloModule> module = CreateModuleWithProgramShape(
+ HloModule* module = CreateModuleWithProgramShape(
S32,
/*input_shape_dims=*/{6}, /*output_shape_dims=*/{3, 1, 2}, &param,
&entry_computation);
@@ -177,10 +178,10 @@ TEST_F(HloCreationUtilsTest, PadVectorWithZeros) {
HloInstruction* param;
HloComputation* entry_computation;
- std::unique_ptr<HloModule> module = CreateModuleWithProgramShape(
- S32,
- /*input_shape_dims=*/{2}, /*output_shape_dims=*/{6}, &param,
- &entry_computation);
+ HloModule* module = CreateModuleWithProgramShape(S32,
+ /*input_shape_dims=*/{2},
+ /*output_shape_dims=*/{6},
+ &param, &entry_computation);
TF_ASSERT_OK_AND_ASSIGN(
HloInstruction * zero_padded_param,
@@ -198,10 +199,10 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) {
HloInstruction* param;
HloComputation* entry_computation;
- std::unique_ptr<HloModule> module = CreateModuleWithProgramShape(
- S32,
- /*input_shape_dims=*/{}, /*output_shape_dims=*/{2, 2}, &param,
- &entry_computation);
+ HloModule* module = CreateModuleWithProgramShape(S32,
+ /*input_shape_dims=*/{},
+ /*output_shape_dims=*/{2, 2},
+ &param, &entry_computation);
TF_ASSERT_OK_AND_ASSIGN(
HloInstruction * zeros,
@@ -219,10 +220,10 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) {
HloInstruction* param;
HloComputation* entry_computation;
- std::unique_ptr<HloModule> module = CreateModuleWithProgramShape(
- F32,
- /*input_shape_dims=*/{}, /*output_shape_dims=*/{2, 2}, &param,
- &entry_computation);
+ HloModule* module = CreateModuleWithProgramShape(F32,
+ /*input_shape_dims=*/{},
+ /*output_shape_dims=*/{2, 2},
+ &param, &entry_computation);
TF_ASSERT_OK_AND_ASSIGN(
HloInstruction * zeros,
diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc
index c8e0a9e289..974ab94467 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc
@@ -29,11 +29,6 @@ namespace xla {
namespace {
class HloDomainTest : public HloVerifiedTestBase {
- public:
- HloDomainTest()
- : HloVerifiedTestBase(/*layout_sensitive=*/false,
- /*allow_mixed_precision=*/false) {}
-
protected:
bool FindUserViaDomainPath(HloInstruction* instruction,
HloInstruction* operand) const {
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
index c3af15c6a8..e3eb60a851 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
@@ -1219,12 +1219,7 @@ TEST_P(HloEvaluatorTest,
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
-class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {
- public:
- HloEvaluatorPreciseReduceTest()
- : HloVerifiedTestBase(/*layout_sensitive=*/false,
- /*allow_mixed_precision=*/false) {}
-};
+class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {};
// Tests that Reduce doesn't lose precision when adding many numbers (because
// it accumulates its result in a double).
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
index 50c04b055b..81290ccd63 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
@@ -41,10 +41,6 @@ using ::testing::UnorderedElementsAre;
class HloInstructionTest : public HloVerifiedTestBase {
protected:
- HloInstructionTest()
- : HloVerifiedTestBase(/*layout_sensitive=*/false,
- /*allow_mixed_precision=*/false) {}
-
Shape r0f32_ = ShapeUtil::MakeShape(F32, {});
};
diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
index fc1f81bdd2..0cac210c24 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
@@ -34,6 +34,8 @@ namespace {
using ::testing::HasSubstr;
+// This class cannot be converted to use HloVerifiedTestBase. It explicitly
+// uses HloTestBase to create and test malformed HLOs.
class HloVerifierTest : public HloTestBase {
public:
HloVerifierTest()
diff --git a/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc b/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc
index df88587492..f85d31d522 100644
--- a/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc
+++ b/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc
@@ -26,11 +26,6 @@ namespace xla {
namespace {
class ImplicitBroadcastRemoverTest : public HloVerifiedTestBase {
- public:
- ImplicitBroadcastRemoverTest()
- : HloVerifiedTestBase(/*layout_sensitive=*/false,
- /*allow_mixed_precision=*/false) {}
-
protected:
ImplicitBroadcastRemover remover_;
};
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
index c34c32f7d3..2d03aebc1a 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
@@ -22,11 +22,6 @@ limitations under the License.
namespace xla {
namespace {
class IndexedArrayAnalysisTest : public HloVerifiedTestBase {
- public:
- IndexedArrayAnalysisTest()
- : HloVerifiedTestBase(/*layout_sensitive=*/false,
- /*allow_mixed_precision=*/false) {}
-
protected:
void AssertArrayForRootExpressionIs(const string& hlo_text,
const string& root_expression) {
diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc
index a395dd5333..fcf269eee9 100644
--- a/tensorflow/compiler/xla/service/reshape_mover_test.cc
+++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc
@@ -34,12 +34,7 @@ namespace {
namespace op = xla::testing::opcode_matchers;
-class ReshapeMoverTest : public HloVerifiedTestBase {
- public:
- ReshapeMoverTest()
- : HloVerifiedTestBase(/*layout_sensitive=*/false,
- /*allow_mixed_precision=*/false) {}
-};
+class ReshapeMoverTest : public HloVerifiedTestBase {};
TEST_F(ReshapeMoverTest, ReshapesWithDifferentInputShapesNotMoved) {
HloComputation::Builder builder(TestName());
diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc
index e14014b961..32e69c335b 100644
--- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc
+++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc
@@ -28,10 +28,6 @@ namespace op = xla::testing::opcode_matchers;
class WhileLoopInvariantCodeMotionTest : public HloVerifiedTestBase {
public:
- WhileLoopInvariantCodeMotionTest()
- : HloVerifiedTestBase(/*layout_sensitive=*/false,
- /*allow_mixed_precision=*/false) {}
-
// Makes a computation which has one parameter, of the given shape, and always
// returns PRED[]{true}. This is useful as a dummy loop condition.
HloComputation* MakeAlwaysTrueComputation(const Shape& param_shape,
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
index cfe4104f6d..1c892ba179 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
@@ -28,11 +28,6 @@ namespace {
namespace op = xla::testing::opcode_matchers;
class WhileLoopSimplifierTest : public HloVerifiedTestBase {
- public:
- WhileLoopSimplifierTest()
- : HloVerifiedTestBase(/*layout_sensitive=*/false,
- /*allow_mixed_precision=*/false) {}
-
protected:
// Makes an HloModule that contains a loop with `num_iters` iteration.
void MakeModuleWithSimpleLoop(int num_iters);