author     Eli Bendersky <eliben@google.com>                2017-06-09 08:46:44 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2017-06-09 08:50:40 -0700
commit     9641b8edab3113f5bb83b5491de747dc9a43fe01 (patch)
tree       052fad13d00142dc420d8bb38f46732d9912d6c2
parent     cade141580c76b41ba71bdc4b019722e674ab954 (diff)
[XLA] Switch HloTestBase-based tests to use new debug options flag.
PiperOrigin-RevId: 158522608
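
Each test affected here previously carried its own main() that appended the XLA
debug-options flags, parsed the command line, and only then ran gtest. That
boilerplate is replaced by a single shared entry point,
xla::ParseDebugOptionsFlagsAndRunTests (declared in
tensorflow/compiler/xla/tests/hlo_test_base.h). Its definition lives in the
hlo_test_base.cc change, which this excerpt does not show; the following is a
sketch of what it plausibly contains, assuming it simply centralizes the
per-test main() deleted throughout this diff:

    // Sketch only -- reconstructed from the removed per-test boilerplate;
    // the real definition is in tensorflow/compiler/xla/tests/hlo_test_base.cc.
    #include <vector>
    #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
    #include "tensorflow/core/platform/logging.h"
    #include "tensorflow/core/platform/test.h"
    #include "tensorflow/core/util/command_line_flags.h"

    namespace xla {
    int ParseDebugOptionsFlagsAndRunTests(int argc, char** argv) {
      // Register the --xla_* debug-option flags, then parse argv.
      std::vector<tensorflow::Flag> flag_list;
      legacy_flags::AppendDebugOptionsFlags(&flag_list);
      string usage = tensorflow::Flags::Usage(argv[0], flag_list);
      if (!tensorflow::Flags::Parse(&argc, argv, flag_list)) {
        LOG(ERROR) << "\n" << usage;
        return 2;
      }
      testing::InitGoogleTest(&argc, argv);
      if (argc > 1) {
        LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
        return 2;
      }
      return RUN_ALL_TESTS();
    }
    }  // namespace xla

Each test's own entry point then shrinks to the one-liner seen repeatedly
below:

    int main(int argc, char** argv) {
      return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv);
    }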
-rw-r--r--  tensorflow/compiler/xla/service/BUILD | 19
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier_test.cc | 17
-rw-r--r--  tensorflow/compiler/xla/service/buffer_assignment_test.cc | 56
-rw-r--r--  tensorflow/compiler/xla/service/buffer_liveness_test.cc | 28
-rw-r--r--  tensorflow/compiler/xla/service/call_graph_test.cc | 76
-rw-r--r--  tensorflow/compiler/xla/service/copy_insertion_test.cc | 210
-rw-r--r--  tensorflow/compiler/xla/service/cpu/BUILD | 1
-rw-r--r--  tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc | 8
-rw-r--r--  tensorflow/compiler/xla/service/flatten_call_graph_test.cc | 46
-rw-r--r--  tensorflow/compiler/xla/service/gpu/BUILD | 8
-rw-r--r--  tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc | 76
-rw-r--r--  tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc | 62
-rw-r--r--  tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc | 36
-rw-r--r--  tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc | 13
-rw-r--r--  tensorflow/compiler/xla/service/gpu/layout_assignment_test.cc | 10
-rw-r--r--  tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc | 22
-rw-r--r--  tensorflow/compiler/xla/service/gpu/while_transformer_test.cc | 32
-rw-r--r--  tensorflow/compiler/xla/service/hlo_computation_test.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/hlo_constant_folding_test.cc | 16
-rw-r--r--  tensorflow/compiler/xla/service/hlo_cse_test.cc | 50
-rw-r--r--  tensorflow/compiler/xla/service/hlo_dce_test.cc | 10
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction_test.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module_test.cc | 10
-rw-r--r--  tensorflow/compiler/xla/service/hlo_ordering_test.cc | 40
-rw-r--r--  tensorflow/compiler/xla/service/hlo_rematerialization_test.cc | 99
-rw-r--r--  tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc | 77
-rw-r--r--  tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/inliner_test.cc | 8
-rw-r--r--  tensorflow/compiler/xla/service/instruction_fusion_test.cc | 24
-rw-r--r--  tensorflow/compiler/xla/service/layout_assignment_test.cc | 81
-rw-r--r--  tensorflow/compiler/xla/service/liveness_util_test.cc | 8
-rw-r--r--  tensorflow/compiler/xla/service/reshape_mover_test.cc | 30
-rw-r--r--  tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc | 6
-rw-r--r--  tensorflow/compiler/xla/tests/BUILD | 1
-rw-r--r--  tensorflow/compiler/xla/tests/broadcast_test.cc | 26
-rw-r--r--  tensorflow/compiler/xla/tests/copy_test.cc | 47
-rw-r--r--  tensorflow/compiler/xla/tests/custom_call_test.cc | 23
-rw-r--r--  tensorflow/compiler/xla/tests/fusion_test.cc | 37
-rw-r--r--  tensorflow/compiler/xla/tests/hlo_test_base.cc | 18
-rw-r--r--  tensorflow/compiler/xla/tests/hlo_test_base.h | 5
40 files changed, 732 insertions, 616 deletions
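
The other mechanical change running through the diff: tests that built modules
directly, via MakeUnique<HloModule>(TestName()) or a stack-allocated
HloModule(TestName()), now go through HloTestBase::CreateNewModule().
Presumably the helper threads the freshly parsed DebugOptions into each
module's config so the flags actually take effect; its real definition is in
the hlo_test_base.cc change not shown here. A sketch under that assumption
(the config plumbing below is hypothetical):

    // Hypothetical sketch -- assumes HloModuleConfig carries DebugOptions and
    // that legacy_flags::GetDebugOptionsFromFlags() returns the parsed values.
    std::unique_ptr<HloModule> HloTestBase::CreateNewModule() {
      HloModuleConfig config;
      config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
      return MakeUnique<HloModule>(TestName(), VersionedComputationHandle(),
                                   config);
    }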
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 71629763da..1fc9e46f7a 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -198,7 +198,6 @@ cc_test(
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/tests:hlo_test_base",
- "//tensorflow/core:test_main",
],
)
@@ -229,7 +228,6 @@ cc_test(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/core:test",
- "//tensorflow/core:test_main",
],
)
@@ -263,7 +261,6 @@ cc_test(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/core:test",
- "//tensorflow/core:test_main",
],
)
@@ -303,7 +300,6 @@ cc_test(
"//tensorflow/compiler/xla/legacy_flags:user_computation_flags",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:test",
- "//tensorflow/core:test_main",
],
)
@@ -651,7 +647,6 @@ cc_test(
":liveness_util",
":tuple_points_to_analysis",
"//tensorflow/compiler/xla/tests:hlo_test_base",
- "//tensorflow/core:test_main",
],
)
@@ -740,7 +735,6 @@ cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/core:lib",
- "//tensorflow/core:test_main",
],
)
@@ -798,7 +792,6 @@ cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
- "//tensorflow/core:test_main",
],
)
@@ -832,7 +825,6 @@ cc_test(
":hlo_matchers",
":instruction_fusion",
"//tensorflow/compiler/xla/tests:hlo_test_base",
- "//tensorflow/core:test_main",
],
)
@@ -870,10 +862,8 @@ cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/core:lib",
- "//tensorflow/core:test",
],
)
@@ -907,7 +897,6 @@ cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/core:lib",
- "//tensorflow/core:test_main",
],
)
@@ -1123,7 +1112,6 @@ cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/core:lib",
- "//tensorflow/core:test_main",
],
)
@@ -1308,7 +1296,6 @@ cc_test(
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
- "//tensorflow/core:test_main",
],
)
@@ -1391,7 +1378,6 @@ cc_test(
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:lib",
"//tensorflow/core:test",
- "//tensorflow/core:test_main",
],
)
@@ -1414,7 +1400,6 @@ cc_test(
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:lib",
- "//tensorflow/core:test_main",
],
)
@@ -1484,7 +1469,6 @@ cc_test(
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:lib",
- "//tensorflow/core:test_main",
],
)
@@ -1521,7 +1505,6 @@ cc_test(
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/core:lib",
- "//tensorflow/core:test_main",
],
)
@@ -1608,7 +1591,6 @@ cc_test(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
- "//tensorflow/core:test_main",
],
)
@@ -1635,7 +1617,6 @@ cc_test(
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/core:protos_all_cc",
- "//tensorflow/core:test_main",
],
)
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 19583433db..aac7a6a6b1 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -19,7 +19,6 @@ limitations under the License.
#include <utility>
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -33,7 +32,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/platform/test.h"
namespace op = xla::testing::opcode_matchers;
@@ -1711,18 +1709,5 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) {
} // namespace xla
int main(int argc, char** argv) {
- std::vector<tensorflow::Flag> flag_list;
- xla::legacy_flags::AppendDebugOptionsFlags(&flag_list);
- xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
- const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
- if (!parse_result) {
- LOG(ERROR) << "\n" << usage;
- return 2;
- }
- testing::InitGoogleTest(&argc, argv);
- if (argc > 1) {
- LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
- return 2;
- }
- return RUN_ALL_TESTS();
+ return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv);
}
diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
index 8463ece124..a3b057a257 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
@@ -254,7 +254,7 @@ TEST_F(BufferAssignmentTest, ScalarConstant) {
auto builder = HloComputation::Builder(TestName());
auto const0 = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
auto buffers = RunBufferAssignment(module.get());
@@ -272,7 +272,7 @@ TEST_F(BufferAssignmentTest, BufferForConst) {
LiteralUtil::CreateR1<float>({4.1f, 4.2f, 4.3f, 4.4f})));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, const0, const1));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
auto buffers = RunBufferAssignment(module.get());
@@ -290,7 +290,7 @@ TEST_F(BufferAssignmentTest, BufferForOutputConst) {
LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
auto copy = builder.AddInstruction(
HloInstruction::CreateUnary(const0->shape(), HloOpcode::kCopy, const0));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
auto buffers = RunBufferAssignment(module.get());
@@ -317,7 +317,7 @@ TEST_F(BufferAssignmentTest, Basic) {
HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
f32vec100_, HloOpcode::kSubtract, add, param1));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
auto buffers = RunBufferAssignment(module.get());
@@ -367,7 +367,7 @@ TEST_F(BufferAssignmentTest, MultipleUsersForNode) {
HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
auto sub = builder.AddInstruction(
HloInstruction::CreateBinary(f32vec100_, HloOpcode::kSubtract, add, mul));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
auto buffers = RunBufferAssignment(module.get());
@@ -402,7 +402,7 @@ TEST_F(BufferAssignmentTest, TrivialMap) {
// param0[100x10] ---> (map x+1)
//
// Builds the map function.
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto map_computation =
module->AddEmbeddedComputation(BuildMapComputationPlus1("f32+1"));
auto inner_last = map_computation->root_instruction();
@@ -457,7 +457,7 @@ TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) {
// out-of-order reductions could overwrite an element before a use.)
//
// param0[100] --- (exp1) --- (exp2) --- (reduce x+1) --- (exp3)
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto reduce_computation =
module->AddEmbeddedComputation(BuildMapComputationPlus1("f32+1"));
@@ -508,7 +508,7 @@ TEST_F(BufferAssignmentTest, ExampleWhile) {
// const4[f32[4]] --- tuple --- while[condition, body]
//
// Builds the nested condition and body.
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto condition_computation =
module->AddEmbeddedComputation(BuildWhileConditionComputation("if<4"));
auto body_computation =
@@ -588,7 +588,7 @@ TEST_F(BufferAssignmentTest, UnaryOpReuseChain) {
auto neg = builder.AddInstruction(
HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, exp2));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
auto assignment = RunBufferAssignment(module.get());
@@ -617,7 +617,7 @@ TEST_F(BufferAssignmentTest, ReuseNonOperandBuffer) {
auto broadcast = builder.AddInstruction(
HloInstruction::CreateBroadcast(f32a100x10_, slice, {1}));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
auto assignment = RunBufferAssignment(module.get());
@@ -650,7 +650,7 @@ TEST_F(BufferAssignmentTest, NoReuseLiveBuffer) {
HloInstruction::CreateBroadcast(f32a100x10_, slice, {1}));
builder.AddInstruction(HloInstruction::CreateTuple({negate, broadcast}));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
auto assignment = RunBufferAssignment(module.get());
@@ -687,7 +687,7 @@ TEST_F(BufferAssignmentTest, NoReuseAliasedBuffer) {
HloInstruction::CreateBroadcast(f32a100x10_, slice, {1}));
builder.AddInstruction(HloInstruction::CreateTuple({tuple, broadcast}));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
auto assignment = RunBufferAssignment(module.get());
@@ -722,7 +722,7 @@ TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBuffer) {
auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
ShapeUtil::MakeShape(F32, {10, 4}), slice, {0}));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
auto assignment = RunBufferAssignment(module.get());
@@ -754,7 +754,7 @@ TEST_F(BufferAssignmentTest, ReuseOutputBufferIfExactlySized) {
auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
ShapeUtil::MakeShape(F32, {10, 10}), slice, {0}));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
auto assignment = RunBufferAssignment(module.get());
@@ -792,7 +792,7 @@ TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBufferInTuple) {
ShapeUtil::MakeShape(F32, {10, 4}), slice, {0}));
builder.AddInstruction(HloInstruction::CreateTuple({broadcast}));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
auto assignment = RunBufferAssignment(module.get());
@@ -807,7 +807,7 @@ TEST_F(BufferAssignmentTest, EmbeddedComputationBuffers) {
// Verify that buffers for embedded computations are properly marked as
// thread-local and that embedded parameters are not marked as
// is_entry_computation_parameter.
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto vec_shape = ShapeUtil::MakeShape(F32, {42});
auto scalar_shape = ShapeUtil::MakeShape(F32, {});
@@ -884,7 +884,7 @@ TEST_F(BufferAssignmentTest, TupleParameterAsOutput) {
ShapeUtil::MakeShape(S32, {42})}),
"param0"));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
auto assignment = RunBufferAssignment(module.get());
@@ -919,7 +919,7 @@ TEST_F(BufferAssignmentTest, ElementOfNestedTupleParameterAsOutput) {
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
ShapeUtil::GetSubshape(tuple_param->shape(), {1}), tuple_param, 1));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
auto assignment = RunBufferAssignment(module.get());
@@ -962,7 +962,7 @@ TEST_F(BufferAssignmentTest, DISABLED_TupleConstantAsOutput) {
LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int64>(0).get(),
LiteralUtil::CreateR0<int64>(1).get()})));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
auto assignment = RunBufferAssignment(module.get());
@@ -976,7 +976,7 @@ TEST_F(BufferAssignmentTest, TupleCustomCallAsOutput) {
ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}),
ShapeUtil::MakeShape(S32, {101})}),
/*operands=*/{}, /*custom_call_target=*/"foo_function"));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
auto assignment = RunBufferAssignment(module.get());
@@ -991,7 +991,7 @@ TEST_F(BufferAssignmentTest, TupleCustomCallAsOutput) {
TEST_F(BufferAssignmentTest, TupleCallAsOutput) {
// Test a computation which returns a tuple call value.
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto elem_shape = f32vec4_;
auto tuple_shape = ShapeUtil::MakeTupleShape({elem_shape});
@@ -1030,7 +1030,7 @@ TEST_F(BufferAssignmentTest, TupleChainedCallAsOutput) {
// B: call(C, param)
// C: call(D, param)
// D: param
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto elem_shape = f32vec4_;
auto tuple_shape = ShapeUtil::MakeTupleShape({elem_shape});
@@ -1101,7 +1101,7 @@ TEST_F(BufferAssignmentTest, BitcastAsOutput) {
auto bitcast = builder.AddInstruction(
HloInstruction::CreateUnary(param->shape(), HloOpcode::kBitcast, param));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
auto assignment = RunBufferAssignment(module.get());
@@ -1127,7 +1127,7 @@ TEST_F(BufferAssignmentTest, AmbiguousBufferAsOutput) {
auto select = builder.AddInstruction(HloInstruction::CreateTernary(
tuple_shape, HloOpcode::kSelect, pred_param, tuple_param0, tuple_param1));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
auto assignment = RunBufferAssignment(module.get());
@@ -1165,7 +1165,7 @@ TEST_F(BufferAssignmentTest, TupleBufferNotReused) {
auto copy = builder.AddInstruction(HloInstruction::CreateUnary(
scalar_shape, HloOpcode::kCopy, tuple_element));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
auto assignment = RunBufferAssignment(module.get());
@@ -1201,7 +1201,7 @@ TEST_F(BufferAssignmentTest, OneTempAllocation) {
HloInstruction::CreateConcatenate(shape_5x4, {dot_ab, dot_bc}, 1));
// Run buffer assignment with alignment=1.
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
auto assignment = RunBufferAssignment(module.get(), /*alignment=*/1);
@@ -1498,3 +1498,7 @@ TEST_F(WhileBufferAssignmentTest, DISABLED_TwoWhiles) {
} // namespace
} // namespace xla
+
+int main(int argc, char** argv) {
+ return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv);
+}
diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc
index c2184cc680..427e4e492c 100644
--- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc
@@ -116,7 +116,7 @@ TEST_F(BufferLivenessTest, ElementwiseChain) {
auto log = builder.AddInstruction(
HloInstruction::CreateUnary(vec_, HloOpcode::kLog, exp));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
auto liveness =
@@ -163,7 +163,7 @@ TEST_F(BufferLivenessTest, MultipleEntryParameters_Sequential) {
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, negate, exp));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
HloComputation* entry = module->AddEntryComputation(builder.Build());
SequentialHloOrdering::HloModuleSequence sequence;
@@ -212,7 +212,7 @@ TEST_F(BufferLivenessTest, NonElementwiseOperand) {
auto reverse =
builder.AddInstruction(HloInstruction::CreateReverse(vec_, negate, {0}));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
auto liveness =
@@ -246,7 +246,7 @@ TEST_F(BufferLivenessTest, OverlappedBuffers) {
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, negate, exp));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
auto liveness =
@@ -288,7 +288,7 @@ TEST_F(BufferLivenessTest, OverlappedBuffersSequentialOrder) {
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, negate, exp));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
SequentialHloOrdering::HloModuleSequence module_sequence;
@@ -330,7 +330,7 @@ TEST_F(BufferLivenessTest, TupleLiveOut) {
auto outer_tuple =
builder.AddInstruction(HloInstruction::CreateTuple({inner_tuple, exp}));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
auto liveness =
@@ -350,7 +350,7 @@ TEST_F(BufferLivenessTest, TupleLiveOut) {
TEST_F(BufferLivenessTest, EmbeddedComputation) {
// Test MaybeLiveOut and MayInterfere for embedded computation.
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto embedded_builder = HloComputation::Builder(TestName() + "_embedded");
auto embedded_param = embedded_builder.AddInstruction(
@@ -407,7 +407,7 @@ TEST_F(BufferLivenessTest, TupleConstantLiveOut) {
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
inner_tuple0->shape(), tuple_constant, 0));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
auto liveness =
@@ -470,7 +470,7 @@ TEST_F(BufferLivenessTest, IndependentTupleElements) {
auto tuple_root =
builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
module->AddEntryComputation(BuildDummyComputation());
module->AddEmbeddedComputation(builder.Build());
@@ -531,7 +531,7 @@ TEST_F(BufferLivenessTest, DependentTupleElements) {
auto tuple_root =
builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
module->AddEntryComputation(BuildDummyComputation());
module->AddEmbeddedComputation(builder.Build());
@@ -604,7 +604,7 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest {
auto tuple_root = builder.AddInstruction(
HloInstruction::CreateTuple({gte0, dynamic_update_slice}));
// Build module and get reference to entry computation.
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
module->AddEntryComputation(BuildDummyComputation());
auto* computation = module->AddEmbeddedComputation(builder.Build());
// Create fusion instruction based on number of tuple element 1 users.
@@ -732,7 +732,7 @@ class DynamicUpdateSliceLivenessTest : public BufferLivenessTest {
auto tuple_root = builder.AddInstruction(
HloInstruction::CreateTuple({gte0, dynamic_update_slice}));
// Build module and get reference to entry computation.
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
module->AddEntryComputation(BuildDummyComputation());
module->AddEmbeddedComputation(builder.Build());
// Run BufferLiveness on 'module'.
@@ -785,3 +785,7 @@ TEST_F(DynamicUpdateSliceLivenessTest, WithInterference) {
} // namespace
} // namespace xla
+
+int main(int argc, char** argv) {
+ return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv);
+}
diff --git a/tensorflow/compiler/xla/service/call_graph_test.cc b/tensorflow/compiler/xla/service/call_graph_test.cc
index ab0ea47d02..e276473c90 100644
--- a/tensorflow/compiler/xla/service/call_graph_test.cc
+++ b/tensorflow/compiler/xla/service/call_graph_test.cc
@@ -92,10 +92,10 @@ class CallGraphTest : public HloTestBase {
TEST_F(CallGraphTest, SingletonComputation) {
// Test the call graph of a module with a single computation.
- HloModule module(TestName());
+ auto module = CreateNewModule();
HloComputation* computation =
- module.AddEntryComputation(MakeScalarComputation());
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(&module);
+ module->AddEntryComputation(MakeScalarComputation());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
EXPECT_EQ(1, call_graph->nodes().size());
const CallGraphNode& node = call_graph->GetNode(computation);
EXPECT_EQ(computation, node.computation());
@@ -109,13 +109,13 @@ TEST_F(CallGraphTest, SingletonComputation) {
TEST_F(CallGraphTest, UnreachableComputation) {
// Test the call graph of a module with an entry computation and an
// unreachable computation.
- HloModule module(TestName());
+ auto module = CreateNewModule();
HloComputation* entry_computation =
- module.AddEntryComputation(MakeScalarComputation());
+ module->AddEntryComputation(MakeScalarComputation());
HloComputation* unreachable_computation =
- module.AddEmbeddedComputation(MakeScalarComputation());
+ module->AddEmbeddedComputation(MakeScalarComputation());
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(&module);
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
EXPECT_EQ(2, call_graph->nodes().size());
const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
@@ -131,13 +131,13 @@ TEST_F(CallGraphTest, UnreachableComputation) {
TEST_F(CallGraphTest, ParallelComputation) {
// Test a call graph of a module with an entry computation which calls another
// computation in a parallel context via kMap.
- HloModule module(TestName());
+ auto module = CreateNewModule();
HloComputation* map_computation =
- module.AddEmbeddedComputation(MakeScalarComputation());
- HloComputation* entry_computation = module.AddEntryComputation(
+ module->AddEmbeddedComputation(MakeScalarComputation());
+ HloComputation* entry_computation = module->AddEntryComputation(
MakeMappingComputation(map_computation, /*callsites=*/5));
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(&module);
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
EXPECT_EQ(2, call_graph->nodes().size());
const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
@@ -160,13 +160,13 @@ TEST_F(CallGraphTest, ParallelComputation) {
TEST_F(CallGraphTest, SequentialComputations) {
// Test a call graph of a module with an entry computation which calls another
// computation in a sequential context via kCall.
- HloModule module(TestName());
+ auto module = CreateNewModule();
HloComputation* called_computation =
- module.AddEmbeddedComputation(MakeScalarComputation());
- HloComputation* entry_computation = module.AddEntryComputation(
+ module->AddEmbeddedComputation(MakeScalarComputation());
+ HloComputation* entry_computation = module->AddEntryComputation(
MakeCallingComputation(called_computation, /*callsites=*/3));
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(&module);
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
EXPECT_EQ(2, call_graph->nodes().size());
const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
@@ -189,9 +189,9 @@ TEST_F(CallGraphTest, SequentialComputations) {
TEST_F(CallGraphTest, ContextBothComputations) {
// Test a call graph of a module with an entry computation which calls another
// computation in both a parallel and sequential context.
- HloModule module(TestName());
+ auto module = CreateNewModule();
HloComputation* subcomputation =
- module.AddEmbeddedComputation(MakeScalarComputation());
+ module->AddEmbeddedComputation(MakeScalarComputation());
HloComputation::Builder builder(TestName());
HloInstruction* param0 = builder.AddInstruction(
@@ -201,9 +201,9 @@ TEST_F(CallGraphTest, ContextBothComputations) {
HloInstruction* map = builder.AddInstruction(
HloInstruction::CreateMap(kScalarShape, {call}, subcomputation));
HloComputation* entry_computation =
- module.AddEntryComputation(builder.Build());
+ module->AddEntryComputation(builder.Build());
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(&module);
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
EXPECT_EQ(2, call_graph->nodes().size());
const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
@@ -241,12 +241,12 @@ TEST_F(CallGraphTest, ComplexGraph) {
// c
//
// Calls are made via kCall, kWhile, and kMap instructions.
- HloModule module(TestName());
+ auto module = CreateNewModule();
HloComputation* cond_computation =
- module.AddEmbeddedComputation(MakeConditionComputation());
+ module->AddEmbeddedComputation(MakeConditionComputation());
HloComputation* c_computation =
- module.AddEmbeddedComputation(MakeScalarComputation());
- HloComputation* b_computation = module.AddEmbeddedComputation(
+ module->AddEmbeddedComputation(MakeScalarComputation());
+ HloComputation* b_computation = module->AddEmbeddedComputation(
MakeMappingComputation(c_computation, /*callsites=*/1));
HloComputation* a_computation;
@@ -258,7 +258,7 @@ TEST_F(CallGraphTest, ComplexGraph) {
HloInstruction::CreateCall(kScalarShape, {param0}, c_computation));
builder.AddInstruction(HloInstruction::CreateWhile(
kScalarShape, cond_computation, b_computation, call));
- a_computation = module.AddEmbeddedComputation(builder.Build());
+ a_computation = module->AddEmbeddedComputation(builder.Build());
}
HloComputation* entry_computation;
@@ -268,10 +268,10 @@ TEST_F(CallGraphTest, ComplexGraph) {
HloInstruction::CreateParameter(0, kScalarShape, "param0"));
builder.AddInstruction(HloInstruction::CreateWhile(
kScalarShape, cond_computation, a_computation, param0));
- entry_computation = module.AddEntryComputation(builder.Build());
+ entry_computation = module->AddEntryComputation(builder.Build());
}
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(&module);
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
EXPECT_EQ(5, call_graph->nodes().size());
// Entry computation has one while instruction calling two computations
@@ -318,10 +318,10 @@ TEST_F(CallGraphTest, ComplexGraph) {
TEST_F(CallGraphTest, VisitSingletonComputation) {
// Test the call graph visitor with a call graph with a single node.
- HloModule module(TestName());
+ auto module = CreateNewModule();
HloComputation* computation =
- module.AddEntryComputation(MakeScalarComputation());
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(&module);
+ module->AddEntryComputation(MakeScalarComputation());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
std::vector<HloComputation*> visited;
TF_ASSERT_OK(call_graph->VisitNodes([&visited](const CallGraphNode& node) {
@@ -333,12 +333,12 @@ TEST_F(CallGraphTest, VisitSingletonComputation) {
TEST_F(CallGraphTest, VisitUnreachableComputation) {
// Test the call graph visitor with a call graph with an unreachable node.
- HloModule module(TestName());
+ auto module = CreateNewModule();
HloComputation* entry_computation =
- module.AddEntryComputation(MakeScalarComputation());
+ module->AddEntryComputation(MakeScalarComputation());
HloComputation* unreachable_computation =
- module.AddEmbeddedComputation(MakeScalarComputation());
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(&module);
+ module->AddEmbeddedComputation(MakeScalarComputation());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
// Test visitation of only reachable nodes.
{
@@ -370,9 +370,9 @@ TEST_F(CallGraphTest, VisitUnreachableComputation) {
TEST_F(CallGraphTest, VisitWithError) {
// Test that the call graph visitor properly propagates errors.
- HloModule module(TestName());
- module.AddEntryComputation(MakeScalarComputation());
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(&module);
+ auto module = CreateNewModule();
+ module->AddEntryComputation(MakeScalarComputation());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
Status status = call_graph->VisitNodes(
[](const CallGraphNode&) { return InternalError("Visitation failed"); });
@@ -385,3 +385,7 @@ TEST_F(CallGraphTest, VisitWithError) {
} // namespace
} // namespace xla
+
+int main(int argc, char** argv) {
+ return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv);
+}
diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc
index cb9682392e..0a990dc13f 100644
--- a/tensorflow/compiler/xla/service/copy_insertion_test.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc
@@ -74,13 +74,13 @@ TEST_F(CopyInsertionTest, SingleParameter) {
EXPECT_THAT(x->users(), UnorderedElementsAre(tuple));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
- HloInstruction* old_root = module.entry_computation()->root_instruction();
- InsertCopies(&module);
+ HloInstruction* old_root = module->entry_computation()->root_instruction();
+ InsertCopies(module.get());
- EXPECT_THAT(module.entry_computation()->root_instruction(),
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
op::Tuple(op::Copy(old_root->operand(0))));
}
@@ -93,13 +93,13 @@ TEST_F(CopyInsertionTest, SingleConstant) {
EXPECT_THAT(constant->users(), UnorderedElementsAre(tuple));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
- HloInstruction* old_root = module.entry_computation()->root_instruction();
- InsertCopies(&module);
+ HloInstruction* old_root = module->entry_computation()->root_instruction();
+ InsertCopies(module.get());
- EXPECT_THAT(module.entry_computation()->root_instruction(),
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
op::Tuple(op::Copy(old_root->operand(0))));
}
@@ -124,13 +124,13 @@ TEST_F(CopyInsertionTest, MultipleConstantsAndParameters) {
builder.AddInstruction(HloInstruction::CreateTuple({constant2, x, add}));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
- HloInstruction* old_root = module.entry_computation()->root_instruction();
- InsertCopies(&module);
+ HloInstruction* old_root = module->entry_computation()->root_instruction();
+ InsertCopies(module.get());
- EXPECT_THAT(module.entry_computation()->root_instruction(),
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
op::Tuple(op::Copy(old_root->operand(0)),
op::Copy(old_root->operand(1)), old_root->operand(2)));
}
@@ -160,13 +160,13 @@ TEST_F(CopyInsertionTest, AmbiguousPointsToSet) {
EXPECT_THAT(constant2->users(), UnorderedElementsAre(tuple1, tuple2));
EXPECT_THAT(constant3->users(), UnorderedElementsAre(tuple2));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
- HloInstruction* old_root = module.entry_computation()->root_instruction();
- InsertCopies(&module);
+ HloInstruction* old_root = module->entry_computation()->root_instruction();
+ InsertCopies(module.get());
- EXPECT_THAT(module.entry_computation()->root_instruction(),
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
op::Tuple(op::Copy(op::GetTupleElement(old_root)),
op::Copy(op::GetTupleElement(old_root))));
}
@@ -180,15 +180,15 @@ TEST_F(CopyInsertionTest, BitcastParameter) {
HloInstruction* bitcast = builder.AddInstruction(HloInstruction::CreateUnary(
ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, x));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
EXPECT_THAT(x->users(), UnorderedElementsAre(bitcast));
- HloInstruction* old_root = module.entry_computation()->root_instruction();
- InsertCopies(&module);
+ HloInstruction* old_root = module->entry_computation()->root_instruction();
+ InsertCopies(module.get());
- EXPECT_THAT(module.entry_computation()->root_instruction(),
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
op::Copy(old_root));
}
@@ -202,15 +202,15 @@ TEST_F(CopyInsertionTest, BitcastConstant) {
HloInstruction* bitcast = builder.AddInstruction(HloInstruction::CreateUnary(
ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, constant));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
EXPECT_THAT(constant->users(), UnorderedElementsAre(bitcast));
- HloInstruction* old_root = module.entry_computation()->root_instruction();
- InsertCopies(&module);
+ HloInstruction* old_root = module->entry_computation()->root_instruction();
+ InsertCopies(module.get());
- EXPECT_THAT(module.entry_computation()->root_instruction(),
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
op::Copy(old_root));
}
@@ -223,15 +223,15 @@ TEST_F(CopyInsertionTest, BitcastTupleElementParameter) {
ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, x));
builder.AddInstruction(HloInstruction::CreateTuple({bitcast}));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
EXPECT_THAT(x->users(), UnorderedElementsAre(bitcast));
- HloInstruction* old_root = module.entry_computation()->root_instruction();
- InsertCopies(&module);
+ HloInstruction* old_root = module->entry_computation()->root_instruction();
+ InsertCopies(module.get());
- EXPECT_THAT(module.entry_computation()->root_instruction(),
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
op::Tuple(op::Copy(old_root->operand(0))));
}
@@ -250,15 +250,15 @@ TEST_F(CopyInsertionTest, NestedTupleParameter) {
ShapeUtil::MakeShape(F32, {42})}),
"param0"));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
EXPECT_EQ(HloOpcode::kParameter,
- module.entry_computation()->root_instruction()->opcode());
+ module->entry_computation()->root_instruction()->opcode());
- HloInstruction* old_root = module.entry_computation()->root_instruction();
- InsertCopies(&module);
- HloInstruction* new_root = module.entry_computation()->root_instruction();
+ HloInstruction* old_root = module->entry_computation()->root_instruction();
+ InsertCopies(module.get());
+ HloInstruction* new_root = module->entry_computation()->root_instruction();
EXPECT_NE(old_root, new_root);
EXPECT_THAT(
@@ -289,15 +289,15 @@ TEST_F(CopyInsertionTest, ElementOfNestedTupleParameter) {
auto gte = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
ShapeUtil::GetSubshape(param->shape(), {0}), param, 0));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
- EXPECT_EQ(gte, module.entry_computation()->root_instruction());
+ EXPECT_EQ(gte, module->entry_computation()->root_instruction());
- HloInstruction* old_root = module.entry_computation()->root_instruction();
- InsertCopies(&module);
+ HloInstruction* old_root = module->entry_computation()->root_instruction();
+ InsertCopies(module.get());
- EXPECT_THAT(module.entry_computation()->root_instruction(),
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
op::Tuple(op::Copy(op::GetTupleElement(old_root)),
op::Copy(op::GetTupleElement(old_root))));
}
@@ -325,21 +325,21 @@ TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) {
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
ShapeUtil::GetSubshape(select->shape(), {0}), select, 0));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
- EXPECT_EQ(gte, module.entry_computation()->root_instruction());
+ EXPECT_EQ(gte, module->entry_computation()->root_instruction());
- HloInstruction* old_root = module.entry_computation()->root_instruction();
- InsertCopies(&module);
+ HloInstruction* old_root = module->entry_computation()->root_instruction();
+ InsertCopies(module.get());
- EXPECT_THAT(module.entry_computation()->root_instruction(),
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
op::Copy(old_root));
}
class WhileCopyInsertionTest : public CopyInsertionTest {
protected:
- WhileCopyInsertionTest() : module_(TestName()) {}
+ WhileCopyInsertionTest() : module_(CreateNewModule()) {}
// Builds a While condition computation which reads the induction variable
// from the tuple parameter, and returns a predicate indicating whether this
@@ -587,7 +587,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
HloInstruction::CreateTuple({induction_var_init, inner_init}));
auto while_hlo = builder.AddInstruction(HloInstruction::CreateWhile(
loop_state_shape_, condition, body, loop_state_init));
- module_.AddEntryComputation(builder.Build());
+ module_->AddEntryComputation(builder.Build());
return while_hlo;
}
@@ -595,7 +595,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
HloInstruction::CreateTuple({induction_var_init, data_init}));
auto while_hlo = builder.AddInstruction(HloInstruction::CreateWhile(
loop_state_shape_, condition, body, loop_state_init));
- module_.AddEntryComputation(builder.Build());
+ module_->AddEntryComputation(builder.Build());
return while_hlo;
}
@@ -679,18 +679,18 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
auto induction_var_init = builder->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
auto condition =
- module_.AddEmbeddedComputation(BuildConditionComputation(nested));
- auto body =
- module_.AddEmbeddedComputation(BuildIndependentBodyComputation(nested));
+ module_->AddEmbeddedComputation(BuildConditionComputation(nested));
+ auto body = module_->AddEmbeddedComputation(
+ BuildIndependentBodyComputation(nested));
auto loop_state_init = builder->AddInstruction(
HloInstruction::CreateTuple({induction_var_init, data_init}));
auto while_hlo = builder->AddInstruction(HloInstruction::CreateWhile(
loop_state_shape, condition, body, loop_state_init));
- module_.AddEntryComputation(builder->Build());
+ module_->AddEntryComputation(builder->Build());
return while_hlo;
}
- HloModule module_;
+ std::unique_ptr<HloModule> module_;
Shape induction_variable_shape_ = ShapeUtil::MakeShape(S32, {});
Shape data_shape_ = ShapeUtil::MakeShape(F32, {8});
Shape loop_state_shape_ =
@@ -712,13 +712,14 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
// CopyInsertion pass should not generate any copies.
//
TEST_F(WhileCopyInsertionTest, IndependentTupleElements) {
- auto condition = module_.AddEmbeddedComputation(BuildConditionComputation());
- auto body = module_.AddEmbeddedComputation(BuildIndependentBodyComputation());
+ auto condition = module_->AddEmbeddedComputation(BuildConditionComputation());
+ auto body =
+ module_->AddEmbeddedComputation(BuildIndependentBodyComputation());
auto while_hlo = BuildWhileInstruction(condition, body);
const HloInstruction* old_init = while_hlo->operand(0);
HloInstruction* old_root = body->root_instruction();
- InsertCopies(&module_);
+ InsertCopies(module_.get());
HloInstruction* new_root = body->root_instruction();
const HloInstruction* new_init = while_hlo->operand(0);
@@ -742,13 +743,13 @@ TEST_F(WhileCopyInsertionTest, IndependentTupleElements) {
// Tuple(Copy(out0), out1)
//
TEST_F(WhileCopyInsertionTest, DependentTupleElements) {
- auto condition = module_.AddEmbeddedComputation(BuildConditionComputation());
- auto body = module_.AddEmbeddedComputation(BuildDependentBodyComputation());
+ auto condition = module_->AddEmbeddedComputation(BuildConditionComputation());
+ auto body = module_->AddEmbeddedComputation(BuildDependentBodyComputation());
auto while_hlo = BuildWhileInstruction(condition, body);
const HloInstruction* old_init = while_hlo->operand(0);
HloInstruction* old_root = body->root_instruction();
- InsertCopies(&module_);
+ InsertCopies(module_.get());
HloInstruction* new_root = body->root_instruction();
const HloInstruction* new_init = while_hlo->operand(0);
@@ -773,14 +774,14 @@ TEST_F(WhileCopyInsertionTest, DependentTupleElements) {
//
// CopyInsertion pass should not generate any copies for the while body.
TEST_F(WhileCopyInsertionTest, DependentTupleElements_OneReadOnly) {
- auto condition = module_.AddEmbeddedComputation(BuildConditionComputation());
- auto body = module_.AddEmbeddedComputation(
+ auto condition = module_->AddEmbeddedComputation(BuildConditionComputation());
+ auto body = module_->AddEmbeddedComputation(
BuildDependentBodyOneReadOnlyComputation());
auto while_hlo = BuildWhileInstruction(condition, body);
const HloInstruction* old_init = while_hlo->operand(0);
HloInstruction* old_root = body->root_instruction();
- InsertCopies(&module_);
+ InsertCopies(module_.get());
HloInstruction* new_root = body->root_instruction();
const HloInstruction* new_init = while_hlo->operand(0);
@@ -796,11 +797,13 @@ TEST_F(WhileCopyInsertionTest, DependentTupleElements_OneReadOnly) {
// Same as above, but with two while loops, sharing entry parameters.
TEST_F(WhileCopyInsertionTest,
DependentTupleElements_OneReadOnly_TwoLoops_EntryParams) {
- auto condition1 = module_.AddEmbeddedComputation(BuildConditionComputation());
- auto condition2 = module_.AddEmbeddedComputation(BuildConditionComputation());
- auto body1 = module_.AddEmbeddedComputation(
+ auto condition1 =
+ module_->AddEmbeddedComputation(BuildConditionComputation());
+ auto condition2 =
+ module_->AddEmbeddedComputation(BuildConditionComputation());
+ auto body1 = module_->AddEmbeddedComputation(
BuildDependentBodyOneReadOnlyComputation());
- auto body2 = module_.AddEmbeddedComputation(
+ auto body2 = module_->AddEmbeddedComputation(
BuildDependentBodyOneReadOnlyComputation());
auto builder = HloComputation::Builder(TestName() + ".While");
@@ -815,9 +818,9 @@ TEST_F(WhileCopyInsertionTest,
loop_state_shape_, condition1, body1, loop_init));
auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile(
loop_state_shape_, condition2, body2, loop_init));
- module_.AddEntryComputation(builder.Build());
+ module_->AddEntryComputation(builder.Build());
- InsertCopies(&module_);
+ InsertCopies(module_.get());
// Both while loops share a single copy of iter_param, since index 0 is
// read-only in the body.
@@ -836,11 +839,13 @@ TEST_F(WhileCopyInsertionTest,
// Same as above, but with two while loops, sharing non-parameters.
TEST_F(WhileCopyInsertionTest,
DependentTupleElements_OneReadOnly_TwoLoops_NonParams) {
- auto condition1 = module_.AddEmbeddedComputation(BuildConditionComputation());
- auto condition2 = module_.AddEmbeddedComputation(BuildConditionComputation());
- auto body1 = module_.AddEmbeddedComputation(
+ auto condition1 =
+ module_->AddEmbeddedComputation(BuildConditionComputation());
+ auto condition2 =
+ module_->AddEmbeddedComputation(BuildConditionComputation());
+ auto body1 = module_->AddEmbeddedComputation(
BuildDependentBodyOneReadOnlyComputation());
- auto body2 = module_.AddEmbeddedComputation(
+ auto body2 = module_->AddEmbeddedComputation(
BuildDependentBodyOneReadOnlyComputation());
auto builder = HloComputation::Builder(TestName() + ".While");
@@ -860,9 +865,9 @@ TEST_F(WhileCopyInsertionTest,
loop_state_shape_, condition1, body1, loop_init));
auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile(
loop_state_shape_, condition2, body2, loop_init));
- module_.AddEntryComputation(builder.Build());
+ module_->AddEntryComputation(builder.Build());
- InsertCopies(&module_);
+ InsertCopies(module_.get());
// No copies of iter_value are necessary, since index 0 is read-only in both
// while bodies.
@@ -908,12 +913,12 @@ TEST_F(WhileCopyInsertionTest,
//
TEST_F(WhileCopyInsertionTest, NestedTupleElements) {
auto condition =
- module_.AddEmbeddedComputation(BuildConditionComputation(true));
- auto body = module_.AddEmbeddedComputation(BuildNestedBodyComputation());
+ module_->AddEmbeddedComputation(BuildConditionComputation(true));
+ auto body = module_->AddEmbeddedComputation(BuildNestedBodyComputation());
BuildWhileInstruction(condition, body, true);
HloInstruction* old_root = body->root_instruction();
- InsertCopies(&module_);
+ InsertCopies(module_.get());
EXPECT_THAT(body->root_instruction(),
op::Tuple(old_root->operand(0),
@@ -930,7 +935,7 @@ TEST_F(WhileCopyInsertionTest, NestedTupleElements) {
TEST_F(WhileCopyInsertionTest, InitPointsToConstant) {
auto while_hlo = BuildWhileInstruction_InitPointsToConstant();
auto old_init = while_hlo->operand(0);
- InsertCopies(&module_);
+ InsertCopies(module_.get());
EXPECT_THAT(while_hlo->operand(0), op::Tuple(op::Copy(old_init->operand(0)),
op::Copy(old_init->operand(1))));
@@ -945,7 +950,7 @@ TEST_F(WhileCopyInsertionTest, InitPointsToConstant) {
TEST_F(WhileCopyInsertionTest, InitPointsToParameter) {
auto while_hlo = BuildWhileInstruction_InitPointsToParameter();
auto old_init = while_hlo->operand(0);
- InsertCopies(&module_);
+ InsertCopies(module_.get());
EXPECT_THAT(while_hlo->operand(0), op::Tuple(op::Copy(old_init->operand(0)),
op::Copy(old_init->operand(1))));
@@ -978,7 +983,7 @@ TEST_F(WhileCopyInsertionTest, InitPointsToParameter) {
TEST_F(WhileCopyInsertionTest, InitPointsToAmbiguous) {
auto while_hlo = BuildWhileInstruction_InitPointsToAmbiguous();
auto old_init = while_hlo->operand(0);
- InsertCopies(&module_);
+ InsertCopies(module_.get());
EXPECT_THAT(
while_hlo->operand(0),
@@ -1014,7 +1019,7 @@ TEST_F(WhileCopyInsertionTest, InitPointsToAmbiguous) {
TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinct) {
auto while_hlo = BuildWhileInstruction_InitPointsToNonDistinct();
auto old_init = while_hlo->operand(0);
- InsertCopies(&module_);
+ InsertCopies(module_.get());
EXPECT_THAT(while_hlo->operand(0),
op::Tuple(op::Copy(old_init->operand(0)),
@@ -1034,7 +1039,7 @@ TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinct) {
TEST_F(WhileCopyInsertionTest, InitPointsToInterfering) {
auto while_hlo = BuildWhileInstruction_InitPointsToInterfering();
auto old_init = while_hlo->operand(0);
- InsertCopies(&module_);
+ InsertCopies(module_.get());
EXPECT_THAT(while_hlo->operand(0), op::Tuple(op::Copy(old_init->operand(0)),
op::Copy(old_init->operand(1))));
@@ -1052,12 +1057,16 @@ TEST_F(WhileCopyInsertionTest, InitPointsToInterfering) {
// (non-identical Copys). In other words, verifies that copy sharing does not
// insert identical copies to the resulting tuple.
TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) {
- auto condition1 = module_.AddEmbeddedComputation(BuildConditionComputation());
- auto condition2 = module_.AddEmbeddedComputation(BuildConditionComputation());
+ auto condition1 =
+ module_->AddEmbeddedComputation(BuildConditionComputation());
+ auto condition2 =
+ module_->AddEmbeddedComputation(BuildConditionComputation());
// Loop body that outputs tuple comprises two elements dependent on the init
// tuple.
- auto body1 = module_.AddEmbeddedComputation(BuildDependentBodyComputation2());
- auto body2 = module_.AddEmbeddedComputation(BuildDependentBodyComputation2());
+ auto body1 =
+ module_->AddEmbeddedComputation(BuildDependentBodyComputation2());
+ auto body2 =
+ module_->AddEmbeddedComputation(BuildDependentBodyComputation2());
auto builder = HloComputation::Builder(TestName() + ".While");
@@ -1079,10 +1088,10 @@ TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) {
auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile(
loop_state_shape, condition2, body2, loop_init));
- module_.AddEntryComputation(builder.Build());
+ module_->AddEntryComputation(builder.Build());
auto points_to_analysis =
- TuplePointsToAnalysis::Run(&module_).ConsumeValueOrDie();
+ TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie();
// Asserts that the init tuples before copy insertion is non-distinct.
ASSERT_FALSE(
@@ -1093,7 +1102,7 @@ TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) {
auto old_init1 = while_hlo1->operand(0);
auto old_init2 = while_hlo2->operand(0);
- InsertCopies(&module_);
+ InsertCopies(module_.get());
EXPECT_THAT(while_hlo1->operand(0),
op::Tuple(op::Copy(old_init1->operand(0)),
@@ -1106,7 +1115,8 @@ TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) {
op::Copy(old_init2->operand(2))));
// Verifies the init tuples after copy insertion is distinct.
- points_to_analysis = TuplePointsToAnalysis::Run(&module_).ConsumeValueOrDie();
+ points_to_analysis =
+ TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie();
const auto& points_to1 =
points_to_analysis->GetPointsToSet(while_hlo1->operand(0));
EXPECT_TRUE(points_to1.IsDistinct());
@@ -1118,3 +1128,7 @@ TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) {
} // namespace
} // namespace xla
+
+int main(int argc, char** argv) {
+ return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv);
+}
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 19180dd243..8045d77a40 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -513,7 +513,6 @@ cc_test(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:hlo_test_base",
- "//tensorflow/core:test_main",
],
)
diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc
index b42702dbe1..f5ad431277 100644
--- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc
@@ -81,7 +81,7 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) {
F32, {kOutputFeatureCount, kBatchSize, output_size, output_size}),
input, kernel, conv_window_, dnums));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
@@ -135,7 +135,7 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) {
F32, {kBatchSize, output_size, output_size, kOutputFeatureCount}),
input, kernel, conv_window_, dnums));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
ConvCanonicalization conv_canonicalization;
@@ -144,3 +144,7 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) {
} // namespace cpu
} // namespace xla
+
+int main(int argc, char** argv) {
+ return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv);
+}
diff --git a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc
index 4e03a96fb3..bb4712c86f 100644
--- a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc
+++ b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc
@@ -108,12 +108,12 @@ TEST_F(FlattenCallGraphTest, ComplexGraph) {
// c
//
// Calls are made via kCall, kWhile, and kMap instructions.
- HloModule module(TestName());
+ auto module = CreateNewModule();
HloComputation* cond_computation =
- module.AddEmbeddedComputation(MakeConditionComputation());
+ module->AddEmbeddedComputation(MakeConditionComputation());
HloComputation* c_computation =
- module.AddEmbeddedComputation(MakeScalarComputation());
- HloComputation* b_computation = module.AddEmbeddedComputation(
+ module->AddEmbeddedComputation(MakeScalarComputation());
+ HloComputation* b_computation = module->AddEmbeddedComputation(
MakeMappingComputation(c_computation, /*callsites=*/1));
HloComputation* a_computation;
@@ -125,7 +125,7 @@ TEST_F(FlattenCallGraphTest, ComplexGraph) {
HloInstruction::CreateCall(kScalarShape, {param0}, c_computation));
builder.AddInstruction(HloInstruction::CreateWhile(
kScalarShape, cond_computation, b_computation, call));
- a_computation = module.AddEmbeddedComputation(builder.Build());
+ a_computation = module->AddEmbeddedComputation(builder.Build());
}
HloComputation* entry_computation;
@@ -135,13 +135,13 @@ TEST_F(FlattenCallGraphTest, ComplexGraph) {
HloInstruction::CreateParameter(0, kScalarShape, "param0"));
builder.AddInstruction(HloInstruction::CreateWhile(
kScalarShape, cond_computation, a_computation, param0));
- entry_computation = module.AddEntryComputation(builder.Build());
+ entry_computation = module->AddEntryComputation(builder.Build());
}
{
- TF_ASSIGN_OR_ASSERT_OK(bool result, RunFlattenCallGraph(&module));
+ TF_ASSIGN_OR_ASSERT_OK(bool result, RunFlattenCallGraph(module.get()));
EXPECT_TRUE(result);
- std::unique_ptr<CallGraph> flat_call_graph = CallGraph::Build(&module);
+ std::unique_ptr<CallGraph> flat_call_graph = CallGraph::Build(module.get());
const CallGraphNode& c_node = flat_call_graph->GetNode(c_computation);
EXPECT_EQ(1, c_node.caller_callsites().size());
}
@@ -149,7 +149,7 @@ TEST_F(FlattenCallGraphTest, ComplexGraph) {
// Test corner case of a computation used as a body and a loop condition.
TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) {
- HloModule module(TestName());
+ auto module = CreateNewModule();
HloComputation* cond_computation;
{
HloComputation::Builder builder(TestName() + ".cond");
@@ -161,7 +161,7 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) {
builder.AddInstruction(
HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}),
HloOpcode::kEq, param0, false_constant));
- cond_computation = module.AddEmbeddedComputation(builder.Build());
+ cond_computation = module->AddEmbeddedComputation(builder.Build());
}
HloComputation* entry_computation;
@@ -172,19 +172,19 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) {
builder.AddInstruction(HloInstruction::CreateWhile(
ShapeUtil::MakeShape(PRED, {}), cond_computation, cond_computation,
false_constant));
- entry_computation = module.AddEntryComputation(builder.Build());
+ entry_computation = module->AddEntryComputation(builder.Build());
}
{
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(&module);
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
const CallGraphNode& cond_node = call_graph->GetNode(cond_computation);
EXPECT_EQ(2, cond_node.caller_callsites().size());
}
{
- TF_ASSIGN_OR_ASSERT_OK(bool result, RunFlattenCallGraph(&module));
+ TF_ASSIGN_OR_ASSERT_OK(bool result, RunFlattenCallGraph(module.get()));
EXPECT_TRUE(result);
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(&module);
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
const CallGraphNode& cond_node = call_graph->GetNode(cond_computation);
EXPECT_EQ(1, cond_node.caller_callsites().size());
}
@@ -201,20 +201,20 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) {
// C
//
TEST_F(FlattenCallGraphTest, FlattenCalls) {
- HloModule module(TestName());
+ auto module = CreateNewModule();
HloComputation* c_computation =
- module.AddEmbeddedComputation(MakeScalarComputation());
+ module->AddEmbeddedComputation(MakeScalarComputation());
- HloComputation* b_computation = module.AddEmbeddedComputation(
+ HloComputation* b_computation = module->AddEmbeddedComputation(
MakeCallingComputation(c_computation, /*callsites=*/2, ".B"));
- module.AddEntryComputation(
+ module->AddEntryComputation(
MakeCallingComputation(b_computation, /*callsites=*/2, ".Entry"));
- TF_ASSIGN_OR_ASSERT_OK(bool result, RunFlattenCallGraph(&module));
+ TF_ASSIGN_OR_ASSERT_OK(bool result, RunFlattenCallGraph(module.get()));
EXPECT_TRUE(result);
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(&module);
- EXPECT_EQ(7, module.computations().size());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ EXPECT_EQ(7, module->computations().size());
const CallGraphNode& c_node = call_graph->GetNode(c_computation);
EXPECT_EQ(1, c_node.caller_callsites().size());
@@ -225,3 +225,7 @@ TEST_F(FlattenCallGraphTest, FlattenCalls) {
} // namespace
} // namespace xla
+
+int main(int argc, char** argv) {
+ return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv);
+}
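
A consequence visible throughout the remaining hunks: CreateNewModule() returns std::unique_ptr<HloModule> where the old code held an HloModule by value, so member access flips from '.' to '->', and every pass that takes a raw HloModule* now receives module.get() instead of &module. The rewritten call-site shape, with SomePass standing in for FlattenCallGraph, HloCSE, and the other passes exercised in these tests:

auto module = CreateNewModule();                  // std::unique_ptr<HloModule>
module->AddEntryComputation(builder.Build());     // '->' instead of '.'
SomePass pass;
EXPECT_TRUE(pass.Run(module.get()).ValueOrDie()); // HloModule*, not &module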
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 1254f9cdeb..3fcdb086af 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -86,7 +86,6 @@ cc_test(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/core:lib",
- "//tensorflow/core:test_main",
],
)
@@ -313,7 +312,6 @@ cc_test(
"//tensorflow/compiler/xla/service:shape_inference",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/core:test",
- "//tensorflow/core:test_main",
],
)
@@ -335,8 +333,6 @@ cc_test(
deps = [
":instruction_fusion",
"//tensorflow/compiler/xla/tests:hlo_test_base",
- "//tensorflow/core:test",
- "//tensorflow/core:test_main",
],
)
@@ -378,7 +374,6 @@ cc_test(
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
- "//tensorflow/core:test_main",
],
)
@@ -488,7 +483,6 @@ cc_test(
"//tensorflow/compiler/xla/service:computation_layout",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:hlo_test_base",
- "//tensorflow/core:test_main",
],
)
@@ -518,7 +512,6 @@ cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:hlo_test_base",
- "//tensorflow/core:test_main",
],
)
@@ -547,7 +540,6 @@ cc_test(
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla/service:copy_insertion",
"//tensorflow/compiler/xla/tests:hlo_test_base",
- "//tensorflow/core:test_main",
],
)
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc b/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc
index 83922cbe14..ba9c70ded3 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc
@@ -97,10 +97,10 @@ TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithoutTranspose) {
activations, gradients, conv_window,
tf_default_dnums_for_backward_filter_));
- HloModule module(TestName());
+ auto module = CreateNewModule();
HloComputation* entry_computation =
- module.AddEntryComputation(builder.Build());
- EXPECT_TRUE(FoldConvolution(&module));
+ module->AddEntryComputation(builder.Build());
+ EXPECT_TRUE(FoldConvolution(module.get()));
EXPECT_EQ(HloOpcode::kFusion,
entry_computation->root_instruction()->opcode());
EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardFilter ==
@@ -126,9 +126,9 @@ TEST_F(ConvolutionFoldingTest,
activations, gradients, conv_window,
tf_default_dnums_for_backward_filter_));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build());
- EXPECT_FALSE(FoldConvolution(&module));
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
+ EXPECT_FALSE(FoldConvolution(module.get()));
}
// Extracted from block35 training.
@@ -155,10 +155,10 @@ TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithPaddedActivations) {
builder.AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShape(F32, {3, 3, 32, 32}), convolution, {1, 2, 3, 0}));
- HloModule module(TestName());
+ auto module = CreateNewModule();
HloComputation* entry_computation =
- module.AddEntryComputation(builder.Build());
- EXPECT_TRUE(FoldConvolution(&module));
+ module->AddEntryComputation(builder.Build());
+ EXPECT_TRUE(FoldConvolution(module.get()));
EXPECT_EQ(HloOpcode::kFusion,
entry_computation->root_instruction()->opcode());
EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardFilter ==
@@ -189,10 +189,10 @@ TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithPaddedGradients) {
builder.AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShape(F32, {3, 3, 192, 320}), convolution, {1, 2, 3, 0}));
- HloModule module(TestName());
+ auto module = CreateNewModule();
HloComputation* entry_computation =
- module.AddEntryComputation(builder.Build());
- EXPECT_TRUE(FoldConvolution(&module));
+ module->AddEntryComputation(builder.Build());
+ EXPECT_TRUE(FoldConvolution(module.get()));
EXPECT_EQ(HloOpcode::kFusion,
entry_computation->root_instruction()->opcode());
EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardFilter ==
@@ -222,10 +222,10 @@ TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithUnevenPadding) {
builder.AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShape(F32, {2, 2, 32, 32}), convolution, {1, 2, 3, 0}));
- HloModule module(TestName());
+ auto module = CreateNewModule();
HloComputation* entry_computation =
- module.AddEntryComputation(builder.Build());
- EXPECT_TRUE(FoldConvolution(&module));
+ module->AddEntryComputation(builder.Build());
+ EXPECT_TRUE(FoldConvolution(module.get()));
EXPECT_EQ(HloOpcode::kFusion,
entry_computation->root_instruction()->opcode());
EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardFilter ==
@@ -269,10 +269,10 @@ TEST_F(ConvolutionFoldingTest, BackwardInputConvolveEvenPadding) {
output->shape(), reverse_kernel->shape(), conv_window, conv_dnums)
.ValueOrDie()));
- HloModule module(TestName());
+ auto module = CreateNewModule();
HloComputation* entry_computation =
- module.AddEntryComputation(builder.Build());
- EXPECT_TRUE(FoldConvolution(&module));
+ module->AddEntryComputation(builder.Build());
+ EXPECT_TRUE(FoldConvolution(module.get()));
EXPECT_EQ(HloOpcode::kFusion,
entry_computation->root_instruction()->opcode());
EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardInput ==
@@ -313,10 +313,10 @@ TEST_F(ConvolutionFoldingTest, BackwardInputConvolve1x1Filter) {
/*lhs=*/output, /*rhs=*/kernel, conv_window,
tf_default_dnums_for_backward_input_));
- HloModule module(TestName());
+ auto module = CreateNewModule();
HloComputation* entry_computation =
- module.AddEntryComputation(builder.Build());
- EXPECT_TRUE(FoldConvolution(&module));
+ module->AddEntryComputation(builder.Build());
+ EXPECT_TRUE(FoldConvolution(module.get()));
EXPECT_EQ(HloOpcode::kFusion,
entry_computation->root_instruction()->opcode());
EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardInput ==
@@ -346,9 +346,9 @@ TEST_F(ConvolutionFoldingTest,
/*lhs=*/output, /*rhs=*/kernel, default_conv_window_,
tf_default_dnums_for_backward_input_));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build());
- EXPECT_FALSE(FoldConvolution(&module));
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
+ EXPECT_FALSE(FoldConvolution(module.get()));
}
// Extracted from Inception V3 training.
@@ -394,10 +394,10 @@ TEST_F(ConvolutionFoldingTest, BackwardInputConvolveUnevenPaddingOnGradients) {
tf_default_dnums_for_backward_input_)
.ValueOrDie()));
- HloModule module(TestName());
+ auto module = CreateNewModule();
HloComputation* entry_computation =
- module.AddEntryComputation(builder.Build());
- EXPECT_TRUE(FoldConvolution(&module));
+ module->AddEntryComputation(builder.Build());
+ EXPECT_TRUE(FoldConvolution(module.get()));
EXPECT_EQ(HloOpcode::kFusion,
entry_computation->root_instruction()->opcode());
EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardInput ==
@@ -441,9 +441,9 @@ TEST_F(ConvolutionFoldingTest, BackwardInputConvolveLowPaddingTooLarge) {
tf_default_dnums_for_backward_input_)
.ValueOrDie()));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build());
- EXPECT_FALSE(FoldConvolution(&module));
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
+ EXPECT_FALSE(FoldConvolution(module.get()));
}
// Extracted from //learning/brain/google/xla/benchmarks/resnet.py
@@ -490,10 +490,10 @@ TEST_F(ConvolutionFoldingTest,
tf_default_dnums_for_backward_input_)
.ValueOrDie()));
- HloModule module(TestName());
+ auto module = CreateNewModule();
const HloComputation* entry_computation =
- module.AddEntryComputation(builder.Build());
- EXPECT_TRUE(FoldConvolution(&module));
+ module->AddEntryComputation(builder.Build());
+ EXPECT_TRUE(FoldConvolution(module.get()));
const HloInstruction* backward_conv = entry_computation->root_instruction();
EXPECT_EQ(HloOpcode::kFusion, backward_conv->opcode());
EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardInput ==
@@ -543,10 +543,14 @@ TEST_F(ConvolutionFoldingTest,
tf_default_dnums_for_backward_input_)
.ValueOrDie()));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build());
- EXPECT_FALSE(FoldConvolution(&module));
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
+ EXPECT_FALSE(FoldConvolution(module.get()));
}
} // namespace gpu
} // namespace xla
+
+int main(int argc, char** argv) {
+ return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv);
+}
diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc
index a87e66ca86..8afc32dea9 100644
--- a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc
@@ -25,7 +25,7 @@ namespace {
class FusionMergerTest : public HloTestBase {
protected:
- FusionMergerTest() : module_(TestName()) {}
+ FusionMergerTest() : module_(CreateNewModule()) {}
// Builds the following computation:
//
@@ -86,7 +86,7 @@ class FusionMergerTest : public HloTestBase {
// Create output Tuple.
builder.AddInstruction(HloInstruction::CreateTuple({out0, out1, out2}));
- return module_.AddEntryComputation(builder.Build());
+ return module_->AddEntryComputation(builder.Build());
}
// Builds the following computation:
@@ -154,7 +154,7 @@ class FusionMergerTest : public HloTestBase {
// Create output Tuple.
builder.AddInstruction(HloInstruction::CreateTuple({out0, out1}));
- return module_.AddEntryComputation(builder.Build());
+ return module_->AddEntryComputation(builder.Build());
}
// Builds the following computation:
@@ -225,7 +225,7 @@ class FusionMergerTest : public HloTestBase {
// Create output Tuple.
builder.AddInstruction(HloInstruction::CreateTuple({out0, out1}));
- return module_.AddEntryComputation(builder.Build());
+ return module_->AddEntryComputation(builder.Build());
}
Shape data_shape_ = ShapeUtil::MakeShape(F32, {4});
@@ -235,7 +235,7 @@ class FusionMergerTest : public HloTestBase {
Shape tuple_shape4_ = ShapeUtil::MakeTupleShape(
{data_shape_, data_shape_, data_shape_, data_shape_});
- HloModule module_;
+ std::unique_ptr<HloModule> module_;
};
// Tests that we can merge a fusion instruction that is below threshold.
@@ -278,13 +278,15 @@ class FusionMergerTest : public HloTestBase {
TEST_F(FusionMergerTest, MergeSharedFusionInstruction) {
auto computation = BuildComputation0();
// Run standard fusion passes.
- EXPECT_TRUE(
- GpuInstructionFusion(/*may_duplicate=*/false).Run(&module_).ValueOrDie());
- EXPECT_FALSE(
- GpuInstructionFusion(/*may_duplicate=*/true).Run(&module_).ValueOrDie());
+ EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/false)
+ .Run(module_.get())
+ .ValueOrDie());
+ EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true)
+ .Run(module_.get())
+ .ValueOrDie());
// Run fusion merger pass, which should merge the shared fusion instruction
// into its two users.
- EXPECT_TRUE(FusionMerger().Run(&module_).ValueOrDie());
+ EXPECT_TRUE(FusionMerger().Run(module_.get()).ValueOrDie());
auto* root = computation->root_instruction();
EXPECT_EQ(HloOpcode::kTuple, root->opcode());
@@ -338,14 +340,16 @@ TEST_F(FusionMergerTest, MergeSharedFusionInstruction) {
TEST_F(FusionMergerTest, FlopsToBytesRatioThresholdExceeded) {
BuildComputation1();
// Run standard fusion passes.
- EXPECT_TRUE(
- GpuInstructionFusion(/*may_duplicate=*/false).Run(&module_).ValueOrDie());
- EXPECT_FALSE(
- GpuInstructionFusion(/*may_duplicate=*/true).Run(&module_).ValueOrDie());
+ EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/false)
+ .Run(module_.get())
+ .ValueOrDie());
+ EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true)
+ .Run(module_.get())
+ .ValueOrDie());
// Run fusion merger pass, which should detect that the flops/bytes of the
// shared fusion instruction exceeds the threshold ratio, and therefore
// cannot be merged with other fusion instructions.
- EXPECT_FALSE(FusionMerger().Run(&module_).ValueOrDie());
+ EXPECT_FALSE(FusionMerger().Run(module_.get()).ValueOrDie());
}
// Tests that threshold for bytes transferred if merged is exceeded.
@@ -388,13 +392,15 @@ TEST_F(FusionMergerTest, FlopsToBytesRatioThresholdExceeded) {
TEST_F(FusionMergerTest, BytesTransferredThresholdExeceeded) {
BuildComputation2(/*add_extra_input=*/true);
// Run standard fusion passes.
- EXPECT_TRUE(
- GpuInstructionFusion(/*may_duplicate=*/false).Run(&module_).ValueOrDie());
- EXPECT_FALSE(
- GpuInstructionFusion(/*may_duplicate=*/true).Run(&module_).ValueOrDie());
+ EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/false)
+ .Run(module_.get())
+ .ValueOrDie());
+ EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true)
+ .Run(module_.get())
+ .ValueOrDie());
// Run fusion merger pass, which should detect that the net bytes transferred
// (if merged) would increase.
- EXPECT_FALSE(FusionMerger().Run(&module_).ValueOrDie());
+ EXPECT_FALSE(FusionMerger().Run(module_.get()).ValueOrDie());
}
// Tests that threshold for bytes transferred if merged is not exceeded.
@@ -442,15 +448,21 @@ TEST_F(FusionMergerTest, BytesTransferredThresholdExeceeded) {
TEST_F(FusionMergerTest, BytesTransferredThresholdNotExeceeded) {
BuildComputation2(/*add_extra_input=*/false);
// Run standard fusion passes.
- EXPECT_TRUE(
- GpuInstructionFusion(/*may_duplicate=*/false).Run(&module_).ValueOrDie());
- EXPECT_FALSE(
- GpuInstructionFusion(/*may_duplicate=*/true).Run(&module_).ValueOrDie());
+ EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/false)
+ .Run(module_.get())
+ .ValueOrDie());
+ EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true)
+ .Run(module_.get())
+ .ValueOrDie());
// Run fusion merger pass, which should detect that the net bytes transferred
// (if merged) would not increase.
- EXPECT_TRUE(FusionMerger().Run(&module_).ValueOrDie());
+ EXPECT_TRUE(FusionMerger().Run(module_.get()).ValueOrDie());
}
} // namespace
} // namespace gpu
} // namespace xla
+
+int main(int argc, char** argv) {
+ return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv);
+}
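
Fixtures that store the module as a data member follow the same scheme: the member type changes to std::unique_ptr<HloModule> and is initialized from CreateNewModule() in the fixture constructor, which is available because HloTestBase is already the base class. The pattern, with a placeholder fixture name:

class SomeXlaTest : public HloTestBase {  // placeholder name
 protected:
  SomeXlaTest() : module_(CreateNewModule()) {}
  std::unique_ptr<HloModule> module_;     // was: HloModule module_;
};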
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc
index dc421695cb..118ef18c44 100644
--- a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc
@@ -69,14 +69,14 @@ TEST_F(HloScheduleTest, SequentialMatMul) {
HloInstruction* dot2 = builder.AddInstruction(
HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, dot1, z));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build(dot2));
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build(dot2));
- std::unique_ptr<StreamAssignment> streams = AssignStreams(module);
+ std::unique_ptr<StreamAssignment> streams = AssignStreams(*module);
EXPECT_EQ(streams->StreamNumberForHlo(*dot1),
streams->StreamNumberForHlo(*dot2));
- auto schedule = BuildHloSchedule(module, *streams);
+ auto schedule = BuildHloSchedule(*module, *streams);
// Remove parameters, which are unordered.
EXPECT_EQ(RemoveHlo(schedule->ThunkLaunchOrder(), {x, y, z}),
HloVec({dot1, dot2}));
@@ -129,16 +129,16 @@ TEST_F(HloScheduleTest, SequentialAdd) {
HloInstruction* add3 = builder.AddInstruction(
HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, add1, add2));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build(add3));
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build(add3));
- std::unique_ptr<StreamAssignment> streams = AssignStreams(module);
+ std::unique_ptr<StreamAssignment> streams = AssignStreams(*module);
EXPECT_EQ(streams->StreamNumberForHlo(*add1),
streams->StreamNumberForHlo(*add2));
EXPECT_EQ(streams->StreamNumberForHlo(*add1),
streams->StreamNumberForHlo(*add3));
- auto schedule = BuildHloSchedule(module, *streams);
+ auto schedule = BuildHloSchedule(*module, *streams);
// Remove parameters, which are unordered.
EXPECT_EQ(RemoveHlo(schedule->ThunkLaunchOrder(), {x, y, z}),
HloVec({add1, add2, add3}));
@@ -199,14 +199,14 @@ TEST_F(HloScheduleTest, ConcurrentMatMul) {
HloInstruction* add = builder.AddInstruction(
HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, dot1, dot2));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build(add));
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build(add));
- std::unique_ptr<StreamAssignment> streams = AssignStreams(module);
+ std::unique_ptr<StreamAssignment> streams = AssignStreams(*module);
EXPECT_NE(streams->StreamNumberForHlo(*dot1),
streams->StreamNumberForHlo(*dot2));
- auto schedule = BuildHloSchedule(module, *streams);
+ auto schedule = BuildHloSchedule(*module, *streams);
// Remove parameters, which are unordered.
HloVec thunk_launch_order = RemoveHlo(schedule->ThunkLaunchOrder(), {x, y});
EXPECT_TRUE(thunk_launch_order == HloVec({dot1, dot2, add}) ||
@@ -278,10 +278,10 @@ TEST_F(HloScheduleTest, LatticeMatMul) {
HloInstruction* d40 = builder.AddInstruction(
HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d30, d31));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build(d40));
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build(d40));
- std::unique_ptr<StreamAssignment> streams = AssignStreams(module);
+ std::unique_ptr<StreamAssignment> streams = AssignStreams(*module);
// The two dots on layer 1 are concurrent.
EXPECT_NE(streams->StreamNumberForHlo(*d10),
streams->StreamNumberForHlo(*d11));
@@ -298,7 +298,7 @@ TEST_F(HloScheduleTest, LatticeMatMul) {
// We don't check the thunk launch order, since there are many valid total
// orders, and it's annoying to express.
- auto schedule = BuildHloSchedule(module, *streams);
+ auto schedule = BuildHloSchedule(*module, *streams);
auto order = schedule->ConsumeHloOrdering();
const HloVec all_params(
@@ -393,3 +393,7 @@ TEST_F(HloScheduleTest, LatticeMatMul) {
} // namespace gpu
} // namespace xla
+
+int main(int argc, char** argv) {
+ return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv);
+}
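
APIs that take the module by reference rather than by pointer (AssignStreams, BuildHloSchedule) are handled by dereferencing the smart pointer instead of calling .get(), as these hunks show. A condensed version of the pattern, with `root` standing in for whichever instruction the test builds last:

auto module = CreateNewModule();
module->AddEntryComputation(builder.Build(root));  // root: placeholder HLO
std::unique_ptr<StreamAssignment> streams = AssignStreams(*module);
auto schedule = BuildHloSchedule(*module, *streams);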
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
index c58af04bad..896f6ea842 100644
--- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
@@ -16,7 +16,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
-#include "tensorflow/core/platform/test.h"
namespace xla {
namespace gpu {
@@ -32,7 +31,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfDotUnfused) {
auto reshape2 = builder.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(S32, {1, 1, 1}), dot1));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(reshape2, computation->root_instruction());
EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true)
@@ -49,7 +48,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) {
auto transpose2 = builder.AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShape(S32, {1, 1}), dot1, {0, 1}));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(transpose2, computation->root_instruction());
EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true)
@@ -89,7 +88,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfConvolutionUnfused) {
builder.AddInstruction(
HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {3}), transpose));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true)
.Run(module.get())
@@ -108,7 +107,7 @@ TEST_F(InstructionFusionTest, GetTupleElementFused) {
HloInstruction::CreateGetTupleElement(data_shape, param, 1));
builder.AddInstruction(
HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, gte0, gte1));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true)
.Run(module.get())
@@ -124,3 +123,7 @@ TEST_F(InstructionFusionTest, GetTupleElementFused) {
} // namespace gpu
} // namespace xla
+
+int main(int argc, char** argv) {
+ return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv);
+}
diff --git a/tensorflow/compiler/xla/service/gpu/layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/layout_assignment_test.cc
index 692ec8147d..fa258b6e56 100644
--- a/tensorflow/compiler/xla/service/gpu/layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/layout_assignment_test.cc
@@ -55,9 +55,9 @@ TEST_F(LayoutAssignmentTest, Elementwise) {
HloInstruction::CreateParameter(1, ashape, "y"));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, x, y));
- HloModule module(TestName());
+ auto module = CreateNewModule();
HloComputation* computation =
- module.AddEntryComputation(builder.Build(add));
+ module->AddEntryComputation(builder.Build(add));
ComputationLayout computation_layout(
computation->ComputeProgramShape());
@@ -69,7 +69,7 @@ TEST_F(LayoutAssignmentTest, Elementwise) {
ShapeLayout(result_shape_with_layout);
GpuLayoutAssignment layout_assignment(&computation_layout);
- EXPECT_TRUE(layout_assignment.Run(&module).ValueOrDie());
+ EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
for (const HloInstruction* operand : add->operands()) {
EXPECT_TRUE(LayoutUtil::Equal(add->shape().layout(),
@@ -83,3 +83,7 @@ TEST_F(LayoutAssignmentTest, Elementwise) {
} // namespace
} // namespace gpu
} // namespace xla
+
+int main(int argc, char** argv) {
+ return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv);
+}
diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc
index 56e3ff99fa..a5230b3e8e 100644
--- a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc
@@ -45,10 +45,10 @@ TEST_F(StreamAssignmentTest, SequentialMatMul) {
HloInstruction* dot2 = builder.AddInstruction(
HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, dot1, z));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build(dot2));
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build(dot2));
- std::unique_ptr<StreamAssignment> assignment = AssignStreams(module);
+ std::unique_ptr<StreamAssignment> assignment = AssignStreams(*module);
EXPECT_EQ(assignment->StreamNumberForHlo(*dot1),
assignment->StreamNumberForHlo(*dot2));
}
@@ -66,10 +66,10 @@ TEST_F(StreamAssignmentTest, ConcurrentMatMul) {
HloInstruction* add = builder.AddInstruction(
HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, dot1, dot2));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build(add));
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build(add));
- std::unique_ptr<StreamAssignment> assignment = AssignStreams(module);
+ std::unique_ptr<StreamAssignment> assignment = AssignStreams(*module);
EXPECT_NE(assignment->StreamNumberForHlo(*dot1),
assignment->StreamNumberForHlo(*dot2));
}
@@ -110,10 +110,10 @@ TEST_F(StreamAssignmentTest, LatticeMatMul) {
HloInstruction* d40 = builder.AddInstruction(
HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d30, d31));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build(d40));
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build(d40));
- std::unique_ptr<StreamAssignment> assignment = AssignStreams(module);
+ std::unique_ptr<StreamAssignment> assignment = AssignStreams(*module);
// The two dots on layer 1 are concurrent.
EXPECT_NE(assignment->StreamNumberForHlo(*d10),
assignment->StreamNumberForHlo(*d11));
@@ -131,3 +131,7 @@ TEST_F(StreamAssignmentTest, LatticeMatMul) {
} // namespace gpu
} // namespace xla
+
+int main(int argc, char** argv) {
+ return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv);
+}
diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc
index a315b9ad11..e82491fd6f 100644
--- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc
@@ -30,7 +30,7 @@ using ::testing::HasSubstr;
class WhileTransformerTest : public HloTestBase {
protected:
WhileTransformerTest()
- : module_(TestName()),
+ : module_(CreateNewModule()),
induction_variable_shape_(ShapeUtil::MakeShape(S32, {})),
data_shape_(ShapeUtil::MakeShape(F32, {8})),
loop_state_shape_(ShapeUtil::MakeTupleShape(
@@ -102,26 +102,26 @@ class WhileTransformerTest : public HloTestBase {
HloInstruction::CreateTuple({data_init, induction_var_init}));
auto while_hlo = builder.AddInstruction(HloInstruction::CreateWhile(
loop_state_shape_, condition, body, loop_state_init));
- module_.AddEntryComputation(builder.Build());
+ module_->AddEntryComputation(builder.Build());
return while_hlo;
}
void RunFusionPasses() {
// Run standard fusion passes.
EXPECT_TRUE(gpu::GpuInstructionFusion(/*may_duplicate=*/false)
- .Run(&module_)
+ .Run(module_.get())
.ValueOrDie());
EXPECT_TRUE(gpu::GpuInstructionFusion(/*may_duplicate=*/true)
- .Run(&module_)
+ .Run(module_.get())
.ValueOrDie());
}
void RunCopyInsertionPass() {
CopyInsertion copy_insertion;
- EXPECT_IS_OK(copy_insertion.Run(&module_).status());
+ EXPECT_IS_OK(copy_insertion.Run(module_.get()).status());
}
- HloModule module_;
+ std::unique_ptr<HloModule> module_;
Shape induction_variable_shape_;
Shape data_shape_;
Shape loop_state_shape_;
@@ -131,8 +131,8 @@ class WhileTransformerTest : public HloTestBase {
TEST_F(WhileTransformerTest, InductionVariableAtTupleElement0) {
// Build computation with induction variable at tuple element 0.
auto condition =
- module_.AddEmbeddedComputation(BuildConditionComputation(0, 10));
- auto body = module_.AddEmbeddedComputation(BuildBodyComputation(0, 1, 1));
+ module_->AddEmbeddedComputation(BuildConditionComputation(0, 10));
+ auto body = module_->AddEmbeddedComputation(BuildBodyComputation(0, 1, 1));
auto while_hlo = BuildWhileInstruction(condition, body, 0, 0);
// Run HLO Optimization passes.
RunFusionPasses();
@@ -148,8 +148,8 @@ TEST_F(WhileTransformerTest, InductionVariableAtTupleElement0) {
TEST_F(WhileTransformerTest, InductionVariableAtTupleElement1) {
// Build computation with induction variable at tuple element 1.
auto condition =
- module_.AddEmbeddedComputation(BuildConditionComputation(1, 10));
- auto body = module_.AddEmbeddedComputation(BuildBodyComputation(1, 0, 1));
+ module_->AddEmbeddedComputation(BuildConditionComputation(1, 10));
+ auto body = module_->AddEmbeddedComputation(BuildBodyComputation(1, 0, 1));
auto while_hlo = BuildWhileInstruction(condition, body, 1, 0);
// Run HLO Optimization passes.
RunFusionPasses();
@@ -165,8 +165,8 @@ TEST_F(WhileTransformerTest, InductionVariableAtTupleElement1) {
TEST_F(WhileTransformerTest, InvalidLoopLimit) {
// Build computation with invalid loop limit.
auto condition =
- module_.AddEmbeddedComputation(BuildConditionComputation(0, 5));
- auto body = module_.AddEmbeddedComputation(BuildBodyComputation(0, 1, 1));
+ module_->AddEmbeddedComputation(BuildConditionComputation(0, 5));
+ auto body = module_->AddEmbeddedComputation(BuildBodyComputation(0, 1, 1));
auto while_hlo = BuildWhileInstruction(condition, body, 0, 10);
// Run HLO Optimization passes.
RunFusionPasses();
@@ -181,8 +181,8 @@ TEST_F(WhileTransformerTest, InvalidLoopLimit) {
TEST_F(WhileTransformerTest, InvalidLoopIncrement) {
// Build computation with invalid loop increment.
auto condition =
- module_.AddEmbeddedComputation(BuildConditionComputation(0, 10));
- auto body = module_.AddEmbeddedComputation(BuildBodyComputation(0, 1, -1));
+ module_->AddEmbeddedComputation(BuildConditionComputation(0, 10));
+ auto body = module_->AddEmbeddedComputation(BuildBodyComputation(0, 1, -1));
auto while_hlo = BuildWhileInstruction(condition, body, 0, 0);
// Run HLO Optimization passes.
RunFusionPasses();
@@ -196,3 +196,7 @@ TEST_F(WhileTransformerTest, InvalidLoopIncrement) {
} // namespace
} // namespace xla
+
+int main(int argc, char** argv) {
+ return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv);
+}
diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc
index 3812653fe3..5d49c83e2d 100644
--- a/tensorflow/compiler/xla/service/hlo_computation_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc
@@ -355,3 +355,7 @@ TEST_F(HloComputationTest, CloneWithControlDependency) {
} // namespace
} // namespace xla
+
+int main(int argc, char** argv) {
+ return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv);
+}
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
index a56225da15..04ab02995b 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
@@ -45,7 +45,7 @@ TEST_F(HloConstantFoldingTest, ConvertF32ToS64) {
builder.AddInstruction(
HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {}), input));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Convert(input));
@@ -67,7 +67,7 @@ TEST_F(HloConstantFoldingTest, ConvertS64ToF32) {
builder.AddInstruction(
HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Convert(input));
@@ -89,7 +89,7 @@ TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) {
builder.AddInstruction(
HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {2}), input));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Convert(input));
@@ -135,7 +135,7 @@ TEST_F(HloConstantFoldingTest, Concatenate) {
Shape shape = ShapeUtil::MakeShape(F32, dimensions);
builder.AddInstruction(HloInstruction::CreateConcatenate(
shape, operands, test_config.concat_dimension));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
HloConstantFolding const_folder;
@@ -161,7 +161,7 @@ TEST_F(HloConstantFoldingTest, Slice) {
Shape shape = ShapeUtil::MakeShape(F32, {6, 6, 3, 4, 4});
builder.AddInstruction(HloInstruction::CreateSlice(
shape, literal_instruction, slice_start, slice_limits));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
HloConstantFolding const_folder;
@@ -186,7 +186,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
const int64 permutation[] = {1, 2, 0, 4, 3};
builder.AddInstruction(
HloInstruction::CreateTranspose(shape, literal_instruction, permutation));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
HloConstantFolding const_folder;
@@ -211,3 +211,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
} // namespace
} // namespace xla
+
+int main(int argc, char** argv) {
+ return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv);
+}
diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc
index 9444382b52..cc39c3ac20 100644
--- a/tensorflow/compiler/xla/service/hlo_cse_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc
@@ -57,7 +57,7 @@ TEST_F(HloCseTest, CombineTwoConstants) {
builder.AddInstruction(HloInstruction::CreateBinary(
constant1->shape(), HloOpcode::kAdd, constant1, constant2));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(3, computation->instruction_count());
@@ -87,7 +87,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) {
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
constant1->shape(), HloOpcode::kAdd, constant1, constant2));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(3, computation->instruction_count());
@@ -119,7 +119,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) {
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
constant1->shape(), HloOpcode::kAdd, constant1, constant2));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(3, computation->instruction_count());
@@ -156,13 +156,13 @@ TEST_F(HloCseTest, ConstantsSameValueDifferentType) {
builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
- HloModule module(TestName());
- auto computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(7, computation->instruction_count());
HloCSE cse(/*is_layout_sensitive=*/false);
- EXPECT_TRUE(cse.Run(&module).ValueOrDie());
+ EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
EXPECT_EQ(6, computation->instruction_count());
}
@@ -184,15 +184,15 @@ TEST_F(HloCseTest, NonscalarConstants) {
auto tuple = builder.AddInstruction(HloInstruction::CreateTuple(
{common_constant1, common_constant2, uncommon_constant}));
- HloModule module(TestName());
- auto computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(4, computation->instruction_count());
EXPECT_THAT(tuple,
op::Tuple(common_constant1, common_constant2, uncommon_constant));
HloCSE cse(/*is_layout_sensitive=*/false);
- EXPECT_TRUE(cse.Run(&module).ValueOrDie());
+ EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
EXPECT_EQ(3, computation->instruction_count());
auto first_operand = tuple->operand(0);
@@ -216,14 +216,14 @@ TEST_F(HloCseTest, IdenticalInstructions) {
auto tuple =
builder.AddInstruction(HloInstruction::CreateTuple({exp1, exp2, exp3}));
- HloModule module(TestName());
- auto computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(5, computation->instruction_count());
EXPECT_THAT(tuple, op::Tuple(exp1, exp2, exp3));
HloCSE cse(/*is_layout_sensitive=*/false);
- EXPECT_TRUE(cse.Run(&module).ValueOrDie());
+ EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
EXPECT_EQ(3, computation->instruction_count());
auto first_operand = tuple->operand(0);
@@ -249,14 +249,14 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsSensitive) {
auto tuple =
builder.AddInstruction(HloInstruction::CreateTuple({exp1, exp2}));
- HloModule module(TestName());
- auto computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(4, computation->instruction_count());
EXPECT_THAT(tuple, op::Tuple(exp1, exp2));
HloCSE cse(/*is_layout_sensitive=*/true);
- EXPECT_FALSE(cse.Run(&module).ValueOrDie());
+ EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());
EXPECT_EQ(4, computation->instruction_count());
EXPECT_THAT(tuple, op::Tuple(exp1, exp2));
@@ -280,14 +280,14 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsInsensitive) {
auto tuple =
builder.AddInstruction(HloInstruction::CreateTuple({exp1, exp2}));
- HloModule module(TestName());
- auto computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(4, computation->instruction_count());
EXPECT_THAT(tuple, op::Tuple(exp1, exp2));
HloCSE cse(/*is_layout_sensitive=*/false);
- EXPECT_TRUE(cse.Run(&module).ValueOrDie());
+ EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
EXPECT_EQ(3, computation->instruction_count());
auto first_operand = tuple->operand(0);
@@ -330,14 +330,14 @@ TEST_F(HloCseTest, IdenticalExpressions) {
auto tuple =
builder.AddInstruction(HloInstruction::CreateTuple({add1, add2}));
- HloModule module(TestName());
- auto computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(8, computation->instruction_count());
EXPECT_THAT(tuple, op::Tuple(op::Add(negate1, exp1), op::Add(negate2, exp2)));
HloCSE cse(/*is_layout_sensitive=*/false);
- EXPECT_TRUE(cse.Run(&module).ValueOrDie());
+ EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
EXPECT_EQ(5, computation->instruction_count());
auto operand = tuple->operand(0);
@@ -362,7 +362,7 @@ TEST_F(HloCseTest, DoNotCombineRng) {
builder.AddInstruction(HloInstruction::CreateBinary(
constant1->shape(), HloOpcode::kAdd, rng1, rng2));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
HloInstruction* root = computation->root_instruction();
@@ -384,7 +384,7 @@ TEST_F(HloCseTest, DISABLED_DoNotCombineCallsToImpureFunctions) {
// Test that two calls to an impure function are not commoned. RNG
// is the source of the impurity.
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
// rng_function is an impure function because it does RNG.
HloComputation* rng_function = nullptr;
@@ -435,3 +435,7 @@ TEST_F(HloCseTest, DISABLED_DoNotCombineCallsToImpureFunctions) {
} // namespace
} // namespace xla
+
+int main(int argc, char** argv) {
+ return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv);
+}
diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc
index 4191eaaad0..10cd7ca7c0 100644
--- a/tensorflow/compiler/xla/service/hlo_dce_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc
@@ -51,7 +51,7 @@ TEST_F(HloDceTest, NoDeadCode) {
builder.AddInstruction(HloInstruction::CreateBinary(
constant1->shape(), HloOpcode::kAdd, constant1, constant2));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(3, computation->instruction_count());
@@ -81,7 +81,7 @@ TEST_F(HloDceTest, DeadParameters) {
builder.AddInstruction(HloInstruction::CreateUnary(
live_param->shape(), HloOpcode::kNegate, live_param));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(5, computation->instruction_count());
@@ -121,7 +121,7 @@ TEST_F(HloDceTest, ControlDependencies) {
builder.AddInstruction(HloInstruction::CreateBinary(
constant1->shape(), HloOpcode::kAdd, constant1, constant2));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
// Add a control dependency between two instructions.
@@ -156,3 +156,7 @@ TEST_F(HloDceTest, ControlDependencies) {
} // namespace
} // namespace xla
+
+int main(int argc, char** argv) {
+ return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv);
+}
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
index a226ab0d0c..bcf81cd8dd 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
@@ -1003,3 +1003,7 @@ TEST_F(HloInstructionTest, CloneSuffixNames) {
} // namespace
} // namespace xla
+
+int main(int argc, char** argv) {
+ return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv);
+}
diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc
index 1175be4f50..870bc729ae 100644
--- a/tensorflow/compiler/xla/service/hlo_module_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_test.cc
@@ -58,7 +58,7 @@ class HloModuleTest : public HloTestBase {
TEST_F(HloModuleTest, OneComputationPostOrder) {
// Create a module with a single computation.
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(CreateConstantComputation());
EXPECT_THAT(module->MakeComputationPostOrder(),
@@ -67,7 +67,7 @@ TEST_F(HloModuleTest, OneComputationPostOrder) {
TEST_F(HloModuleTest, TwoComputationsPostOrder) {
// Create a module with two unconnected computations.
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation1 = module->AddEntryComputation(CreateConstantComputation());
auto computation2 =
module->AddEmbeddedComputation(CreateConstantComputation());
@@ -83,7 +83,7 @@ TEST_F(HloModuleTest, TwoComputationsPostOrder) {
TEST_F(HloModuleTest, DiamondComputationsPostOrder) {
// Create a module with a diamond call graph of computations.
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation1 =
module->AddEmbeddedComputation(CreateConstantComputation());
auto computation2 =
@@ -104,3 +104,7 @@ TEST_F(HloModuleTest, DiamondComputationsPostOrder) {
} // namespace
} // namespace xla
+
+int main(int argc, char** argv) {
+ return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv);
+}
diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
index c387fbb89b..21d852a51d 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
@@ -58,23 +58,23 @@ TEST_F(HloOrderingTest, LastUseScheduledFirst) {
auto sub = builder.AddInstruction(
HloInstruction::CreateBinary(vec, HloOpcode::kSubtract, add, negate));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
TF_ASSIGN_OR_ASSERT_OK(
SequentialHloOrdering::HloModuleSequence sequence,
- CreateMemoryMinimizingSequence(module, [](const LogicalBuffer& buffer) {
+ CreateMemoryMinimizingSequence(*module, [](const LogicalBuffer& buffer) {
return ShapeUtil::ByteSizeOf(buffer.shape());
}));
// Verify that all instructions are in the sequence.
- EXPECT_EQ(module.entry_computation()->instruction_count(),
- sequence.at(module.entry_computation()).size());
+ EXPECT_EQ(module->entry_computation()->instruction_count(),
+ sequence.at(module->entry_computation()).size());
// The first instruction should be the parameter and the last the root "sub".
- EXPECT_EQ(param, sequence.at(module.entry_computation()).front());
- EXPECT_EQ(sub, sequence.at(module.entry_computation()).back());
+ EXPECT_EQ(param, sequence.at(module->entry_computation()).front());
+ EXPECT_EQ(sub, sequence.at(module->entry_computation()).back());
- SequentialHloOrdering ordering(&module, sequence);
+ SequentialHloOrdering ordering(module.get(), sequence);
EXPECT_TRUE(ordering.ExecutesBefore(add, negate));
}
@@ -96,14 +96,14 @@ TEST_F(HloOrderingTest, InstructionsInDifferentComputations) {
// %c = Constant(42.0f)
//
// This results in a diamond-shaped callgraph.
- HloModule module(TestName());
+ auto module = CreateNewModule();
const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
auto builder_c = HloComputation::Builder("C");
HloInstruction* c = builder_c.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
HloComputation* computation_c =
- module.AddEmbeddedComputation(builder_c.Build());
+ module->AddEmbeddedComputation(builder_c.Build());
auto builder_b = HloComputation::Builder("B");
builder_b.AddInstruction(
@@ -111,22 +111,22 @@ TEST_F(HloOrderingTest, InstructionsInDifferentComputations) {
HloInstruction* b = builder_b.AddInstruction(
HloInstruction::CreateCall(scalar_shape, {}, computation_c));
HloComputation* computation_b =
- module.AddEmbeddedComputation(builder_b.Build());
+ module->AddEmbeddedComputation(builder_b.Build());
auto builder_a = HloComputation::Builder("A");
HloInstruction* a = builder_a.AddInstruction(
HloInstruction::CreateCall(scalar_shape, {}, computation_c));
HloComputation* computation_a =
- module.AddEmbeddedComputation(builder_a.Build());
+ module->AddEmbeddedComputation(builder_a.Build());
auto builder = HloComputation::Builder(TestName());
HloInstruction* x = builder.AddInstruction(
HloInstruction::CreateCall(scalar_shape, {}, computation_a));
HloInstruction* y = builder.AddInstruction(
HloInstruction::CreateCall(scalar_shape, {x}, computation_b));
- module.AddEntryComputation(builder.Build());
+ module->AddEntryComputation(builder.Build());
- DependencyHloOrdering ordering(&module);
+ DependencyHloOrdering ordering(module.get());
EXPECT_TRUE(ordering.ExecutesBefore(x, y));
EXPECT_FALSE(ordering.ExecutesBefore(y, x));
@@ -158,7 +158,7 @@ TEST_F(HloOrderingTest, InstructionsInDifferentComputations) {
class MinimumMemoryForSequenceTest : public HloTestBase {};
TEST_F(MinimumMemoryForSequenceTest, MultiComputation) {
- HloModule module(TestName());
+ auto module = CreateNewModule();
const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
const Shape tuple_shape =
ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape});
@@ -176,14 +176,14 @@ TEST_F(MinimumMemoryForSequenceTest, MultiComputation) {
HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}),
HloOpcode::kLt, cond_iter, cond_data));
HloComputation* cond_computation =
- module.AddEmbeddedComputation(cond_builder.Build());
+ module->AddEmbeddedComputation(cond_builder.Build());
auto body_builder = HloComputation::Builder("WhileBody");
// Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
HloInstruction* body_param = body_builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
HloComputation* body_computation =
- module.AddEmbeddedComputation(body_builder.Build());
+ module->AddEmbeddedComputation(body_builder.Build());
auto builder = HloComputation::Builder(TestName());
// Entry params: 8 bytes (4 bytes per param), TOTAL=8
@@ -199,7 +199,7 @@ TEST_F(MinimumMemoryForSequenceTest, MultiComputation) {
HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
tuple_shape, cond_computation, body_computation, tuple));
HloComputation* entry_computation =
- module.AddEntryComputation(builder.Build());
+ module->AddEntryComputation(builder.Build());
auto size_fn = [](const LogicalBuffer& buffer) {
return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
@@ -217,3 +217,7 @@ TEST_F(MinimumMemoryForSequenceTest, MultiComputation) {
} // namespace
} // namespace xla
+
+int main(int argc, char** argv) {
+ return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv);
+}
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
index 2a1d728bc8..d9c2a5f0ac 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
@@ -140,9 +140,9 @@ class HloRematerializationTest : public HloTestBase {
// Test rematerialization of a single computation produced by
// MakeRematerializableComputation.
TEST_F(HloRematerializationTest, SingleComputation) {
- HloModule module(TestName());
+ auto module = CreateNewModule();
HloComputation* computation =
- module.AddEntryComputation(MakeRematerializableComputation());
+ module->AddEntryComputation(MakeRematerializableComputation());
// Find and save the original broadcast instruction which should be
// rematerialized.
@@ -155,9 +155,10 @@ TEST_F(HloRematerializationTest, SingleComputation) {
// Computation requires 16KB without rematerialization, but uses only 12KB
// with rematerialization so pick a memory limit between these values (14KB).
TF_ASSIGN_OR_ASSERT_OK(
- bool changed, HloRematerialization::RematerializeAndSchedule(
- ByteSizeOf,
- /*memory_limit_bytes=*/14 * 1024, &module, &sequence));
+ bool changed,
+ HloRematerialization::RematerializeAndSchedule(
+ ByteSizeOf,
+ /*memory_limit_bytes=*/14 * 1024, module.get(), &sequence));
EXPECT_TRUE(changed);
// Root should not have changed.
@@ -179,17 +180,18 @@ TEST_F(HloRematerializationTest, SingleComputation) {
// MakeRematerializableComputation but with a sufficiently high memory limit
// such that no instructions are rematerialized.
TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) {
- HloModule module(TestName());
+ auto module = CreateNewModule();
HloComputation* computation =
- module.AddEntryComputation(MakeRematerializableComputation());
+ module->AddEntryComputation(MakeRematerializableComputation());
EXPECT_EQ(computation->instruction_count(), 7);
SequentialHloOrdering::HloModuleSequence sequence;
TF_ASSIGN_OR_ASSERT_OK(
- bool changed, HloRematerialization::RematerializeAndSchedule(
- ByteSizeOf,
- /*memory_limit_bytes=*/20 * 1024, &module, &sequence));
+ bool changed,
+ HloRematerialization::RematerializeAndSchedule(
+ ByteSizeOf,
+ /*memory_limit_bytes=*/20 * 1024, module.get(), &sequence));
// No instructions should have been materialized.
EXPECT_FALSE(changed);
@@ -203,7 +205,7 @@ TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) {
// computation should be the one chosen because rematerialization in the while
// will presumably be more expensive.
TEST_F(HloRematerializationTest, RematerializeAroundWhile) {
- HloModule module(TestName());
+ auto module = CreateNewModule();
auto cond_builder = HloComputation::Builder(TestName() + ".cond");
cond_builder.AddInstruction(
@@ -211,12 +213,12 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) {
cond_builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
HloComputation* while_cond =
- module.AddEmbeddedComputation(cond_builder.Build());
+ module->AddEmbeddedComputation(cond_builder.Build());
- HloComputation* body_computation = module.AddEmbeddedComputation(
+ HloComputation* body_computation = module->AddEmbeddedComputation(
MakeRematerializableComputation(/*suffix=*/".body"));
HloComputation* entry_computation =
- module.AddEntryComputation(MakeRematerializableWhileComputation(
+ module->AddEntryComputation(MakeRematerializableWhileComputation(
while_cond, /*while_body=*/body_computation));
EXPECT_EQ(entry_computation->instruction_count(), 6);
@@ -227,9 +229,10 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) {
// bit lower (17KB) to force rematerialization of the entry computation.
SequentialHloOrdering::HloModuleSequence sequence;
TF_ASSIGN_OR_ASSERT_OK(
- bool changed, HloRematerialization::RematerializeAndSchedule(
- ByteSizeOf,
- /*memory_limit_bytes=*/17 * 1024, &module, &sequence));
+ bool changed,
+ HloRematerialization::RematerializeAndSchedule(
+ ByteSizeOf,
+ /*memory_limit_bytes=*/17 * 1024, module.get(), &sequence));
EXPECT_TRUE(changed);
// Only the entry computation should have a rematerialized instruction added.
@@ -241,7 +244,7 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) {
// while. Both the entry computation and while body computation should have
// computations rematerialized.
TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) {
- HloModule module(TestName());
+ auto module = CreateNewModule();
auto cond_builder = HloComputation::Builder(TestName() + ".cond");
cond_builder.AddInstruction(
@@ -249,12 +252,12 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) {
cond_builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
HloComputation* while_cond =
- module.AddEmbeddedComputation(cond_builder.Build());
+ module->AddEmbeddedComputation(cond_builder.Build());
- HloComputation* body_computation = module.AddEmbeddedComputation(
+ HloComputation* body_computation = module->AddEmbeddedComputation(
MakeRematerializableComputation(/*suffix=*/".body"));
HloComputation* entry_computation =
- module.AddEntryComputation(MakeRematerializableWhileComputation(
+ module->AddEntryComputation(MakeRematerializableWhileComputation(
while_cond, /*while_body=*/body_computation));
EXPECT_EQ(entry_computation->instruction_count(), 6);
@@ -262,9 +265,10 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) {
SequentialHloOrdering::HloModuleSequence sequence;
TF_ASSIGN_OR_ASSERT_OK(
- bool changed, HloRematerialization::RematerializeAndSchedule(
- ByteSizeOf,
- /*memory_limit_bytes=*/15 * 1024, &module, &sequence));
+ bool changed,
+ HloRematerialization::RematerializeAndSchedule(
+ ByteSizeOf,
+ /*memory_limit_bytes=*/15 * 1024, module.get(), &sequence));
EXPECT_TRUE(changed);
// Both computations should have a rematerialized instruction added.
@@ -275,7 +279,7 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) {
// Test rematerialization of a doubly nested computation. All computations
// should have an instruction rematerialized.
TEST_F(HloRematerializationTest, RematerializeNestedComputations) {
- HloModule module(TestName());
+ auto module = CreateNewModule();
auto cond_builder = HloComputation::Builder(TestName() + ".cond");
cond_builder.AddInstruction(
@@ -283,16 +287,16 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) {
cond_builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
HloComputation* while_cond =
- module.AddEmbeddedComputation(cond_builder.Build());
+ module->AddEmbeddedComputation(cond_builder.Build());
- HloComputation* inner_computation = module.AddEmbeddedComputation(
+ HloComputation* inner_computation = module->AddEmbeddedComputation(
MakeRematerializableComputation(/*suffix=*/".inner"));
HloComputation* middle_computation =
- module.AddEmbeddedComputation(MakeRematerializableWhileComputation(
+ module->AddEmbeddedComputation(MakeRematerializableWhileComputation(
while_cond, /*while_body=*/inner_computation,
/*suffix=*/".middle"));
HloComputation* entry_computation =
- module.AddEntryComputation(MakeRematerializableWhileComputation(
+ module->AddEntryComputation(MakeRematerializableWhileComputation(
while_cond, /*while_body=*/middle_computation));
EXPECT_EQ(entry_computation->instruction_count(), 6);
@@ -303,9 +307,10 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) {
// ~12K so pick something slightly larger.
SequentialHloOrdering::HloModuleSequence sequence;
TF_ASSIGN_OR_ASSERT_OK(
- bool changed, HloRematerialization::RematerializeAndSchedule(
- ByteSizeOf,
- /*memory_limit_bytes=*/13 * 1024, &module, &sequence));
+ bool changed,
+ HloRematerialization::RematerializeAndSchedule(
+ ByteSizeOf,
+ /*memory_limit_bytes=*/13 * 1024, module.get(), &sequence));
EXPECT_TRUE(changed);
// All computations should have a rematerialized instruction added.
@@ -336,7 +341,7 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) {
// The value %bcast is live across each call of Subcomputation (which requires
// 8KB) though the value is not used in the calls. Rematerializing %bcast
// across these calls reduces peak memory use from ~20KB down to ~16KB.
- HloModule module(TestName());
+ auto module = CreateNewModule();
HloComputation* subcomputation = nullptr;
{
@@ -349,7 +354,7 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) {
builder.AddInstruction(HloInstruction::CreateSlice(
vec1024_shape_, concat, /*start_indices=*/{0},
/*limit_indices=*/{1024}));
- subcomputation = module.AddEmbeddedComputation(builder.Build());
+ subcomputation = module->AddEmbeddedComputation(builder.Build());
}
auto builder = HloComputation::Builder(TestName());
@@ -372,7 +377,7 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) {
auto add_4 = builder.AddInstruction(HloInstruction::CreateBinary(
vec1024_shape_, HloOpcode::kAdd, bcast, call_3));
HloComputation* entry_computation =
- module.AddEntryComputation(builder.Build());
+ module->AddEntryComputation(builder.Build());
auto count_broadcasts = [](const HloComputation* computation) {
int64 bcast_count = 0;
@@ -398,9 +403,10 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) {
// parameter and output) and 20KB (peak memory possible with
// rematerialization).
TF_ASSIGN_OR_ASSERT_OK(
- bool changed, HloRematerialization::RematerializeAndSchedule(
- ByteSizeOf,
- /*memory_limit_bytes=*/22 * 1024, &module, &sequence));
+ bool changed,
+ HloRematerialization::RematerializeAndSchedule(
+ ByteSizeOf,
+ /*memory_limit_bytes=*/22 * 1024, module.get(), &sequence));
EXPECT_TRUE(changed);
// The broadcast should have been rematerialized 3 times.
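
A quick check of the sizes these memory limits lean on, assuming ByteSizeOf charges 4 bytes per F32 element (the peak figures themselves are the test's own):

  // vec1024 F32 value (%bcast):   1024 * 4 bytes = 4 KB
  // concat of two vec1024s:       2048 * 4 bytes = 8 KB  (inside Subcomputation)
  // Keeping %bcast live across each call stacks its 4 KB on top of the
  // subcomputation's 8 KB; rematerializing %bcast is what recovers that
  // 4 KB, and the 22 KB limit is chosen to sit inside exactly that
  // recoverable band, forcing the pass to act.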
@@ -451,7 +457,7 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) {
  // (i.e., %bcast is used indirectly by %negate); otherwise the %negate operand
// aliases %add_2.
const bool indirectly_used = GetParam();
- HloModule module(TestName());
+ auto module = CreateNewModule();
HloComputation* subcomputation = nullptr;
{
@@ -464,7 +470,7 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) {
builder.AddInstruction(HloInstruction::CreateSlice(
vec1024_shape_, concat, /*start_indices=*/{0},
/*limit_indices=*/{1024}));
- subcomputation = module.AddEmbeddedComputation(builder.Build());
+ subcomputation = module->AddEmbeddedComputation(builder.Build());
}
auto builder = HloComputation::Builder(TestName());
@@ -485,7 +491,7 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) {
builder.AddInstruction(
HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kNegate, gte));
HloComputation* entry_computation =
- module.AddEntryComputation(builder.Build());
+ module->AddEntryComputation(builder.Build());
EXPECT_EQ(entry_computation->instruction_count(), 8);
@@ -494,9 +500,10 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) {
// parameter and output) and 20KB (peak memory possible with
// rematerialization).
TF_ASSIGN_OR_ASSERT_OK(
- bool changed, HloRematerialization::RematerializeAndSchedule(
- ByteSizeOf,
- /*memory_limit_bytes=*/22 * 1024, &module, &sequence));
+ bool changed,
+ HloRematerialization::RematerializeAndSchedule(
+ ByteSizeOf,
+ /*memory_limit_bytes=*/22 * 1024, module.get(), &sequence));
// Rematerialization should only occur if the rematerializable instruction has
// no indirect uses.
if (indirectly_used) {
@@ -514,3 +521,7 @@ INSTANTIATE_TEST_CASE_P(IndirectUseTestInstantiation, IndirectUseTest,
} // namespace
} // namespace xla
+
+int main(int argc, char** argv) {
+ return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv);
+}
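
The per-test main() added here replaces gtest's default entry point so that --xla_* debug-option flags on the command line reach the test binary. A minimal sketch of what ParseDebugOptionsFlagsAndRunTests plausibly does, assuming the usual TF flag-parsing flow (the real definition is in this commit's hlo_test_base changes; treat the body below as an illustration, not the actual code):

  // Sketch: register the debug-option flags, parse argv, then run gtest.
  int ParseDebugOptionsFlagsAndRunTests(int argc, char** argv) {
    std::vector<tensorflow::Flag> flag_list;
    xla::legacy_flags::AppendDebugOptionsFlags(&flag_list);
    xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
    if (!tensorflow::Flags::Parse(&argc, argv, flag_list)) {
      LOG(ERROR) << "\n" << usage;
      return 2;
    }
    testing::InitGoogleTest(&argc, argv);
    return RUN_ALL_TESTS();
  }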
diff --git a/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc b/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc
index 14800b5342..867ebc7f61 100644
--- a/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc
@@ -66,13 +66,13 @@ class HloSubcomputationUnificationTest : public HloTestBase {
};
TEST_F(HloSubcomputationUnificationTest, UnifyIdentities) {
- auto hlo_module = MakeUnique<HloModule>("test_module");
+ auto module = CreateNewModule();
auto builder = HloComputation::Builder(TestName());
auto callee1 =
- hlo_module->AddEmbeddedComputation(CreateR0S32IdentityComputation());
+ module->AddEmbeddedComputation(CreateR0S32IdentityComputation());
auto callee2 =
- hlo_module->AddEmbeddedComputation(CreateR0S32IdentityComputation());
+ module->AddEmbeddedComputation(CreateR0S32IdentityComputation());
auto constant = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(5)));
@@ -83,32 +83,31 @@ TEST_F(HloSubcomputationUnificationTest, UnifyIdentities) {
builder.AddInstruction(
HloInstruction::CreateBinary(r0s32_, HloOpcode::kAdd, x, y));
- hlo_module->AddEntryComputation(builder.Build());
+ module->AddEntryComputation(builder.Build());
- EXPECT_EQ(3, hlo_module->computations().size());
+ EXPECT_EQ(3, module->computations().size());
EXPECT_NE(x->to_apply(), y->to_apply());
if (VLOG_IS_ON(1)) {
- hlo_graph_dumper::DumpGraph(*hlo_module->entry_computation(),
+ hlo_graph_dumper::DumpGraph(*module->entry_computation(),
"before unification", false, false, nullptr);
}
- EXPECT_TRUE(
- HloSubcomputationUnification().Run(hlo_module.get()).ValueOrDie());
+ EXPECT_TRUE(HloSubcomputationUnification().Run(module.get()).ValueOrDie());
if (VLOG_IS_ON(1)) {
- hlo_graph_dumper::DumpGraph(*hlo_module->entry_computation(),
+ hlo_graph_dumper::DumpGraph(*module->entry_computation(),
"after unification", false, false, nullptr);
}
- EXPECT_EQ(2, hlo_module->computations().size());
+ EXPECT_EQ(2, module->computations().size());
EXPECT_EQ(x->to_apply(), y->to_apply());
}
TEST_F(HloSubcomputationUnificationTest, UnifyAdditions) {
- auto hlo_module = MakeUnique<HloModule>("test_module");
+ auto module = CreateNewModule();
auto builder = HloComputation::Builder(TestName());
auto callee1 =
- hlo_module->AddEmbeddedComputation(CreateR0S32AdditionComputation());
+ module->AddEmbeddedComputation(CreateR0S32AdditionComputation());
auto callee2 =
- hlo_module->AddEmbeddedComputation(CreateR0S32AdditionComputation());
+ module->AddEmbeddedComputation(CreateR0S32AdditionComputation());
auto constant1 = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(5)));
@@ -121,33 +120,32 @@ TEST_F(HloSubcomputationUnificationTest, UnifyAdditions) {
builder.AddInstruction(
HloInstruction::CreateBinary(r0s32_, HloOpcode::kAdd, x, y));
- hlo_module->AddEntryComputation(builder.Build());
+ module->AddEntryComputation(builder.Build());
- EXPECT_EQ(3, hlo_module->computations().size());
+ EXPECT_EQ(3, module->computations().size());
EXPECT_NE(x->to_apply(), y->to_apply());
if (VLOG_IS_ON(1)) {
- hlo_graph_dumper::DumpGraph(*hlo_module->entry_computation(),
+ hlo_graph_dumper::DumpGraph(*module->entry_computation(),
"before unification", false, false, nullptr);
}
- EXPECT_TRUE(
- HloSubcomputationUnification().Run(hlo_module.get()).ValueOrDie());
+ EXPECT_TRUE(HloSubcomputationUnification().Run(module.get()).ValueOrDie());
if (VLOG_IS_ON(1)) {
- hlo_graph_dumper::DumpGraph(*hlo_module->entry_computation(),
+ hlo_graph_dumper::DumpGraph(*module->entry_computation(),
"after unification", false, false, nullptr);
}
- EXPECT_EQ(2, hlo_module->computations().size());
+ EXPECT_EQ(2, module->computations().size());
EXPECT_EQ(x->to_apply(), y->to_apply());
}
// Do not unify subcomputations with different parameter shapes.
TEST_F(HloSubcomputationUnificationTest, DifferentParameterShapes) {
- auto hlo_module = MakeUnique<HloModule>("test_module");
+ auto module = CreateNewModule();
auto builder = HloComputation::Builder(TestName());
- auto callee1 = hlo_module->AddEmbeddedComputation(
- CreateR1S32AdditionComputation(r1s32_5_));
- auto callee2 = hlo_module->AddEmbeddedComputation(
- CreateR1S32AdditionComputation(r1s32_3_));
+ auto callee1 =
+ module->AddEmbeddedComputation(CreateR1S32AdditionComputation(r1s32_5_));
+ auto callee2 =
+ module->AddEmbeddedComputation(CreateR1S32AdditionComputation(r1s32_3_));
auto param1 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r1s32_5_, "param1"));
@@ -160,28 +158,27 @@ TEST_F(HloSubcomputationUnificationTest, DifferentParameterShapes) {
builder.AddInstruction(HloInstruction::CreateConcatenate(
ShapeUtil::MakeShape(S32, {8}), {x, y}, 0));
- hlo_module->AddEntryComputation(builder.Build());
+ module->AddEntryComputation(builder.Build());
- EXPECT_EQ(3, hlo_module->computations().size());
+ EXPECT_EQ(3, module->computations().size());
EXPECT_NE(x->to_apply(), y->to_apply());
if (VLOG_IS_ON(1)) {
- hlo_graph_dumper::DumpGraph(*hlo_module->entry_computation(),
+ hlo_graph_dumper::DumpGraph(*module->entry_computation(),
"before unification", false, false, nullptr);
}
- EXPECT_FALSE(
- HloSubcomputationUnification().Run(hlo_module.get()).ValueOrDie());
+ EXPECT_FALSE(HloSubcomputationUnification().Run(module.get()).ValueOrDie());
if (VLOG_IS_ON(1)) {
- hlo_graph_dumper::DumpGraph(*hlo_module->entry_computation(),
+ hlo_graph_dumper::DumpGraph(*module->entry_computation(),
"after unification", false, false, nullptr);
}
- EXPECT_EQ(3, hlo_module->computations().size());
+ EXPECT_EQ(3, module->computations().size());
EXPECT_NE(x->to_apply(), y->to_apply());
}
// Regression test for b/31466798. Checks that entry_computation is still valid
// after unification.
TEST_F(HloSubcomputationUnificationTest, TwoIdenticalComputations) {
- HloModule module(TestName());
+ auto module = CreateNewModule();
for (int i = 0; i < 2; ++i) {
HloComputation::Builder builder("pow");
auto x =
@@ -191,15 +188,19 @@ TEST_F(HloSubcomputationUnificationTest, TwoIdenticalComputations) {
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32_, HloOpcode::kPower, x, y));
if (i == 0) {
- module.AddEmbeddedComputation(builder.Build());
+ module->AddEmbeddedComputation(builder.Build());
} else {
- module.AddEntryComputation(builder.Build());
+ module->AddEntryComputation(builder.Build());
}
}
- EXPECT_TRUE(HloSubcomputationUnification().Run(&module).ValueOrDie());
- EXPECT_EQ(1, module.computations().size());
- EXPECT_EQ(module.computations().front().get(), module.entry_computation());
+ EXPECT_TRUE(HloSubcomputationUnification().Run(module.get()).ValueOrDie());
+ EXPECT_EQ(1, module->computations().size());
+ EXPECT_EQ(module->computations().front().get(), module->entry_computation());
}
} // namespace xla
+
+int main(int argc, char** argv) {
+ return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv);
+}
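
Across these files, hand-built modules (stack-allocated HloModule or MakeUnique<HloModule>("test_module")) give way to HloTestBase::CreateNewModule(), which returns a std::unique_ptr<HloModule>; call sites accordingly pass module.get() instead of &module. A hedged sketch of the helper, assuming it threads the parsed debug options into every module's config (the exact constructor and setter names are assumptions from this era of the codebase):

  std::unique_ptr<HloModule> HloTestBase::CreateNewModule() {
    HloModuleConfig config;
    // Assumption: debug options parsed by the new test main() are copied
    // into the config of every module the fixture creates.
    config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
    return MakeUnique<HloModule>(TestName(), VersionedComputationHandle(), config);
  }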
diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc
index 6041debc4a..c2718ea800 100644
--- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc
@@ -182,3 +182,7 @@ TEST_F(HloTfGraphBuilderTest, EmbeddedComputationsDiamond) {
} // namespace
} // namespace hlo_graph_dumper
} // namespace xla
+
+int main(int argc, char **argv) {
+ return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv);
+}
diff --git a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/inliner_test.cc
index a8d4ecf261..2887a8a0a0 100644
--- a/tensorflow/compiler/xla/service/inliner_test.cc
+++ b/tensorflow/compiler/xla/service/inliner_test.cc
@@ -59,7 +59,7 @@ TEST_F(InlinerTest, MapMax) {
HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, max_f32.get()));
auto computation = builder.Build();
- auto hlo_module = MakeUnique<HloModule>("test_module");
+ auto hlo_module = CreateNewModule();
hlo_module->AddEmbeddedComputation(std::move(max_f32));
hlo_module->AddEntryComputation(std::move(computation));
@@ -93,7 +93,7 @@ TEST_F(InlinerTest, MapConstant) {
HloInstruction::CreateMap(lhs->shape(), {lhs}, const2_f32.get()));
auto computation = builder.Build();
- auto hlo_module = MakeUnique<HloModule>("test_module");
+ auto hlo_module = CreateNewModule();
hlo_module->AddEmbeddedComputation(std::move(const2_f32));
hlo_module->AddEntryComputation(std::move(computation));
HloInstruction* root = hlo_module->entry_computation()->root_instruction();
@@ -110,3 +110,7 @@ TEST_F(InlinerTest, MapConstant) {
} // namespace
} // namespace xla
+
+int main(int argc, char** argv) {
+ return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv);
+}
diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc
index d2df0b699e..a2e6c2ae00 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc
@@ -35,7 +35,7 @@ TEST_F(InstructionFusionTest,
builder.AddInstruction(HloInstruction::CreateBroadcast(
ShapeUtil::MakeShape(S32, {1}), exp1, {0}));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(broadcast2, computation->root_instruction());
EXPECT_TRUE(
@@ -56,7 +56,7 @@ TEST_F(InstructionFusionTest,
builder.AddInstruction(HloInstruction::CreateBroadcast(
ShapeUtil::MakeShape(S32, {1}), negate1, {0}));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(broadcast2, computation->root_instruction());
EXPECT_TRUE(
@@ -76,7 +76,7 @@ TEST_F(InstructionFusionTest,
HloInstruction* reshape2 = builder.AddInstruction(
HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), exp1));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(reshape2, computation->root_instruction());
EXPECT_TRUE(
@@ -96,7 +96,7 @@ TEST_F(InstructionFusionTest,
HloInstruction* transpose2 = builder.AddInstruction(
HloInstruction::CreateTranspose(ShapeUtil::MakeShape(S32, {}), exp1, {}));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(transpose2, computation->root_instruction());
EXPECT_TRUE(
@@ -113,7 +113,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfParameterUnfused) {
auto reshape1 = builder.AddInstruction(
HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {1, 1}), param0));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(reshape1, computation->root_instruction());
EXPECT_FALSE(
@@ -129,7 +129,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastSimpleReshapeOfParameterUnfused) {
auto reshape1 = builder.AddInstruction(
HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {1, 1}), param0));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(reshape1, computation->root_instruction());
EXPECT_FALSE(
@@ -145,7 +145,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfParameterUnfused) {
auto transpose1 = builder.AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShape(S32, {}), param0, {}));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(transpose1, computation->root_instruction());
EXPECT_FALSE(
@@ -167,7 +167,7 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusable) {
HloInstruction* unary = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kAbs, binary1));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(unary, computation->root_instruction());
EXPECT_FALSE(
@@ -187,7 +187,7 @@ TEST_F(InstructionFusionTest, AllowUnaryDuplication) {
HloInstruction* unary2 = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kAbs, unary1));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(unary2, computation->root_instruction());
EXPECT_TRUE(
@@ -210,7 +210,7 @@ TEST_F(InstructionFusionTest, AllowEffectiveUnaryDuplication) {
HloInstruction* unary = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kAbs, binary1));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(unary, computation->root_instruction());
EXPECT_TRUE(
@@ -220,3 +220,7 @@ TEST_F(InstructionFusionTest, AllowEffectiveUnaryDuplication) {
}
} // namespace xla
+
+int main(int argc, char** argv) {
+ return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv);
+}
diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc
index bfb9e4ac2e..6d818cdea0 100644
--- a/tensorflow/compiler/xla/service/layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc
@@ -69,8 +69,8 @@ TEST_F(LayoutAssignmentTest, ComputationLayout) {
HloInstruction::CreateParameter(1, ashape, "param1"));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, param0, param1));
- HloModule module(TestName());
- HloComputation* computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ HloComputation* computation = module->AddEntryComputation(builder.Build());
Layout layout = LayoutUtil::MakeLayout(minor_to_major);
Shape shape(ashape);
@@ -81,7 +81,7 @@ TEST_F(LayoutAssignmentTest, ComputationLayout) {
*computation_layout.mutable_parameter_layout(0) = shape_layout;
*computation_layout.mutable_parameter_layout(1) = shape_layout;
*computation_layout.mutable_result_layout() = shape_layout;
- AssignLayouts(&module, &computation_layout);
+ AssignLayouts(module.get(), &computation_layout);
EXPECT_TRUE(LayoutUtil::Equal(layout, param0->shape().layout()));
EXPECT_TRUE(LayoutUtil::Equal(layout, param1->shape().layout()));
EXPECT_TRUE(LayoutUtil::Equal(layout, add->shape().layout()));
@@ -99,8 +99,8 @@ TEST_F(LayoutAssignmentTest, ComputationLayoutMixedLayout) {
HloInstruction::CreateParameter(1, ashape, "param1"));
builder.AddInstruction(
HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, param0, param1));
- HloModule module(TestName());
- HloComputation* computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ HloComputation* computation = module->AddEntryComputation(builder.Build());
Layout col_major_layout = LayoutUtil::MakeLayout({1, 0});
Shape col_major_shape(ashape);
@@ -117,7 +117,7 @@ TEST_F(LayoutAssignmentTest, ComputationLayoutMixedLayout) {
*computation_layout.mutable_parameter_layout(1) = row_major;
*computation_layout.mutable_result_layout() = col_major;
- AssignLayouts(&module, &computation_layout);
+ AssignLayouts(module.get(), &computation_layout);
EXPECT_TRUE(LayoutUtil::Equal(col_major_layout, param0->shape().layout()));
EXPECT_TRUE(LayoutUtil::Equal(row_major_layout, param1->shape().layout()));
EXPECT_TRUE(LayoutUtil::Equal(
@@ -148,8 +148,8 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) {
auto negate2 = builder.AddInstruction(
HloInstruction::CreateUnary(ashape, HloOpcode::kNegate, negate1));
- HloModule module(TestName());
- HloComputation* computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ HloComputation* computation = module->AddEntryComputation(builder.Build());
auto fusion = computation->CreateFusionInstruction(
{negate2, negate1, add}, HloInstruction::FusionKind::kLoop);
@@ -162,7 +162,7 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) {
ComputationLayout computation_layout(computation->ComputeProgramShape());
*computation_layout.mutable_result_layout() = shape_layout;
- AssignLayouts(&module, &computation_layout);
+ AssignLayouts(module.get(), &computation_layout);
EXPECT_TRUE(LayoutUtil::Equal(
layout, fusion->fused_parameter(0)->shape().layout()));
@@ -197,13 +197,13 @@ TEST_F(LayoutAssignmentTest, TupleLayout) {
auto negate = builder.AddInstruction(HloInstruction::CreateUnary(
constant0->shape(), HloOpcode::kNegate, get_element0));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
ComputationLayout computation_layout(
- module.entry_computation()->ComputeProgramShape());
+ module->entry_computation()->ComputeProgramShape());
- AssignLayouts(&module, &computation_layout);
+ AssignLayouts(module.get(), &computation_layout);
EXPECT_FALSE(
LayoutUtil::LayoutsInShapesEqual(constant0->shape(), constant1->shape()));
@@ -235,17 +235,17 @@ TEST_F(LayoutAssignmentTest, TupleSelect) {
auto select = builder.AddInstruction(HloInstruction::CreateTernary(
tuple0->shape(), HloOpcode::kSelect, pred, tuple0, tuple1));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
ComputationLayout computation_layout(
- module.entry_computation()->ComputeProgramShape());
+ module->entry_computation()->ComputeProgramShape());
Shape result_shape =
ShapeUtil::MakeTupleShape({constant0->shape(), constant1->shape()});
TF_CHECK_OK(computation_layout.mutable_result_layout()->CopyLayoutFromShape(
result_shape));
- AssignLayouts(&module, &computation_layout);
+ AssignLayouts(module.get(), &computation_layout);
EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(result_shape, select->shape()));
}
@@ -270,11 +270,11 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) {
auto nested_tuple = builder.AddInstruction(
HloInstruction::CreateTuple({inner_tuple, inner_tuple}));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
ComputationLayout computation_layout(
- module.entry_computation()->ComputeProgramShape());
+ module->entry_computation()->ComputeProgramShape());
Shape result_shape = nested_tuple->shape();
*ShapeUtil::GetMutableSubshape(&result_shape, /*index=*/{0, 0}) =
ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0});
@@ -284,7 +284,7 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) {
result_shape));
LayoutAssignment layout_assignment(&computation_layout);
- AssignLayouts(&module, &computation_layout);
+ AssignLayouts(module.get(), &computation_layout);
// Layout assignment should have deep copied the result of the computation to
// address the layout conflict. This results in several Tuple() and
@@ -300,9 +300,9 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) {
EXPECT_TRUE(
AlgebraicSimplifier(/*is_layout_sensitive=*/true,
[](const Shape&, const Shape&) { return false; })
- .Run(&module)
+ .Run(module.get())
.ValueOrDie());
- HloInstruction* root = module.entry_computation()->root_instruction();
+ HloInstruction* root = module->entry_computation()->root_instruction();
// Verify layout of the root and the root's operands.
EXPECT_TRUE(ShapeUtil::Equal(result_shape, root->shape()));
EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::GetSubshape(result_shape, {0}),
@@ -329,8 +329,9 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndReshape) {
auto tanh = builder.AddInstruction(
HloInstruction::CreateUnary(bshape, HloOpcode::kTanh, reshape));
- HloModule module(TestName());
- HloComputation* computation = module.AddEntryComputation(builder.Build(tanh));
+ auto module = CreateNewModule();
+ HloComputation* computation =
+ module->AddEntryComputation(builder.Build(tanh));
Shape ashape_with_layout(ashape);
Shape bshape_with_layout(bshape);
@@ -341,7 +342,7 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndReshape) {
*computation_layout.mutable_parameter_layout(0) =
ShapeLayout(ashape_with_layout);
*computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout);
- AssignLayouts(&module, &computation_layout);
+ AssignLayouts(module.get(), &computation_layout);
auto log_minor_to_major =
AsInt64Slice(log->shape().layout().minor_to_major());
@@ -370,8 +371,8 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndTranspose) {
HloInstruction::CreateTranspose(bshape, log, {1, 0}));
auto tanh = builder.AddInstruction(
HloInstruction::CreateUnary(bshape, HloOpcode::kTanh, transpose));
- HloModule module(TestName());
- auto computation = module.AddEntryComputation(builder.Build(tanh));
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build(tanh));
Shape ashape_with_layout(ashape);
Shape bshape_with_layout(bshape);
@@ -382,7 +383,7 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndTranspose) {
*computation_layout.mutable_parameter_layout(0) =
ShapeLayout(ashape_with_layout);
*computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout);
- AssignLayouts(&module, &computation_layout);
+ AssignLayouts(module.get(), &computation_layout);
EXPECT_TRUE(
LayoutUtil::Equal(ashape_with_layout.layout(), log->shape().layout()));
@@ -406,9 +407,9 @@ TEST_F(LayoutAssignmentTest, BroadcastAndTranspose) {
HloInstruction::CreateBroadcast(bshape, param, {1, 2}));
auto transpose = builder.AddInstruction(
HloInstruction::CreateTranspose(cshape, broadcast, {2, 1, 0}));
- HloModule module(TestName());
+ auto module = CreateNewModule();
HloComputation* computation =
- module.AddEntryComputation(builder.Build(transpose));
+ module->AddEntryComputation(builder.Build(transpose));
Shape input_shape_with_layout(ashape);
Shape output_shape_with_layout(cshape);
@@ -421,7 +422,7 @@ TEST_F(LayoutAssignmentTest, BroadcastAndTranspose) {
ShapeLayout(input_shape_with_layout);
*computation_layout.mutable_result_layout() =
ShapeLayout(output_shape_with_layout);
- AssignLayouts(&module, &computation_layout);
+ AssignLayouts(module.get(), &computation_layout);
EXPECT_THAT(broadcast->shape().layout().minor_to_major(),
ElementsAre(0, 1, 2));
@@ -455,9 +456,9 @@ TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) {
HloInstruction::CreateBroadcast(f32_234, tanh, {2}));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({transpose, broadcast2}));
- HloModule module(TestName());
+ auto module = CreateNewModule();
HloComputation* computation =
- module.AddEntryComputation(builder.Build(tuple));
+ module->AddEntryComputation(builder.Build(tuple));
ComputationLayout computation_layout(computation->ComputeProgramShape());
Shape param_shape_with_layout(f32_4);
@@ -474,7 +475,7 @@ TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) {
*computation_layout.mutable_result_layout() =
ShapeLayout(ShapeUtil::MakeTupleShape(
{transpose_shape_with_layout, broadcast2_shape_with_layout}));
- AssignLayouts(&module, &computation_layout);
+ AssignLayouts(module.get(), &computation_layout);
EXPECT_THAT(broadcast->shape().layout().minor_to_major(), ElementsAre(0, 1));
EXPECT_THAT(transpose->shape().layout().minor_to_major(), ElementsAre(1, 0));
@@ -525,9 +526,9 @@ TEST_F(LayoutAssignmentTest, MakeOperandsTheSame) {
HloInstruction::CreateConcatenate(bshape, {param0, param1}, 1));
auto reshape = builder.AddInstruction(
HloInstruction::CreateReshape(cshape, concatenate));
- HloModule module(TestName());
+ auto module = CreateNewModule();
HloComputation* computation =
- module.AddEntryComputation(builder.Build(reshape));
+ module->AddEntryComputation(builder.Build(reshape));
Shape param0_shape_with_layout(ashape);
Shape param1_shape_with_layout(ashape);
@@ -540,7 +541,7 @@ TEST_F(LayoutAssignmentTest, MakeOperandsTheSame) {
*computation_layout.mutable_parameter_layout(1) =
ShapeLayout(param1_shape_with_layout);
OperandsMustBeTheSameLayoutAssignment layout_assignment(&computation_layout);
- EXPECT_IS_OK(layout_assignment.Run(&module).status());
+ EXPECT_IS_OK(layout_assignment.Run(module.get()).status());
EXPECT_EQ(HloOpcode::kCopy, concatenate->operand(0)->opcode());
EXPECT_THAT(concatenate->operand(0)->shape().layout().minor_to_major(),
@@ -553,3 +554,7 @@ TEST_F(LayoutAssignmentTest, MakeOperandsTheSame) {
} // namespace
} // namespace xla
+
+int main(int argc, char** argv) {
+ return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv);
+}
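
The AssignLayouts helper these tests call is unchanged except for now receiving the raw pointer out of the owning unique_ptr. A sketch of the helper as the call sites use it, assuming it simply runs the pass and asserts success (the actual definition sits in this file's unmodified preamble):

  void AssignLayouts(HloModule* module, ComputationLayout* entry_computation_layout) {
    // Run layout assignment against the entry computation's layout and
    // require that the pass completes without error.
    LayoutAssignment layout_assignment(entry_computation_layout);
    EXPECT_IS_OK(layout_assignment.Run(module).status());
  }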
diff --git a/tensorflow/compiler/xla/service/liveness_util_test.cc b/tensorflow/compiler/xla/service/liveness_util_test.cc
index ac670069b4..bad4be149a 100644
--- a/tensorflow/compiler/xla/service/liveness_util_test.cc
+++ b/tensorflow/compiler/xla/service/liveness_util_test.cc
@@ -27,7 +27,7 @@ namespace {
class PointsToAnalysisTestBase : public HloTestBase {
protected:
void BuildModule(std::unique_ptr<HloComputation> computation) {
- module_ = MakeUnique<HloModule>(TestName());
+ module_ = CreateNewModule();
computation_ = module_->AddEntryComputation(std::move(computation));
}
@@ -344,7 +344,7 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) {
return builder.Build();
};
- module_ = MakeUnique<HloModule>(TestName());
+ module_ = CreateNewModule();
HloComputation* cond_computation =
module_->AddEmbeddedComputation(make_cond());
HloComputation* body_computation =
@@ -366,3 +366,7 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) {
} // namespace
} // namespace xla
+
+int main(int argc, char** argv) {
+ return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv);
+}
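
Fixtures that cache the module as a member follow the same ownership model: BuildModule simply overwrites the member with a fresh module. A sketch of the fixture shape this implies (the member declarations are assumptions; only BuildModule appears in the diff):

  class PointsToAnalysisTestBase : public HloTestBase {
   protected:
    void BuildModule(std::unique_ptr<HloComputation> computation) {
      // Each BuildModule call replaces any previously cached module.
      module_ = CreateNewModule();
      computation_ = module_->AddEntryComputation(std::move(computation));
    }
    std::unique_ptr<HloModule> module_;
    HloComputation* computation_ = nullptr;
  };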
diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc
index 7d8b462279..9becdb2bed 100644
--- a/tensorflow/compiler/xla/service/reshape_mover_test.cc
+++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc
@@ -50,7 +50,7 @@ TEST_F(ReshapeMoverTest, ReshapesWithDifferentInputShapesNotMoved) {
builder.AddInstruction(HloInstruction::CreateBinary(
root_shape, HloOpcode::kAdd, reshape0, reshape1));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
@@ -89,7 +89,7 @@ TEST_F(ReshapeMoverTest, 1ConstantAnd1ReshapesOnRngNotMoved) {
builder.AddInstruction(HloInstruction::CreateBinary(
root_shape, HloOpcode::kAdd, reshape0, const1));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
@@ -115,7 +115,7 @@ TEST_F(ReshapeMoverTest, ScalarReshapesNotMoved) {
builder.AddInstruction(HloInstruction::CreateBinary(
root_shape, HloOpcode::kAdd, reshape0, reshape1));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
@@ -142,7 +142,7 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMoved) {
builder.AddInstruction(HloInstruction::CreateBinary(
root_shape, HloOpcode::kAdd, reshape0, reshape1));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
@@ -197,7 +197,7 @@ TEST_F(ReshapeMoverTest, 1ConstantAnd2ReshapesMoved) {
ShapeUtil::MakeShape(PRED, {2, 3}), HloOpcode::kSelect, const0, reshape1,
reshape2));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
@@ -234,7 +234,7 @@ TEST_F(ReshapeMoverTest, 1ParameterAnd1ReshapeNotMoved) {
builder.AddInstruction(HloInstruction::CreateBinary(
root_shape, HloOpcode::kAdd, reshape0, param1));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
@@ -278,7 +278,7 @@ TEST_F(ReshapeMoverTest, 2TrivialConstantReshapeNotMoved) {
builder.AddInstruction(HloInstruction::CreateTernary(
root_shape, HloOpcode::kSelect, pred, reshape0, reshape1));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
@@ -324,7 +324,7 @@ TEST_F(ReshapeMoverTest, 1NonTrivialReshapeMoved) {
builder.AddInstruction(HloInstruction::CreateBinary(
root_shape, HloOpcode::kAdd, reshape0, const1));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
@@ -352,7 +352,7 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossFusion) {
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
root_shape, HloOpcode::kAdd, reshape0, reshape1));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
auto fusion = computation->AddInstruction(HloInstruction::CreateFusion(
add->shape(), HloInstruction::FusionKind::kLoop, add));
@@ -388,7 +388,7 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossSelect) {
builder.AddInstruction(HloInstruction::CreateTernary(
root_shape, HloOpcode::kSelect, reshape_pred, reshape0, reshape1));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(
@@ -418,7 +418,7 @@ TEST_F(ReshapeMoverTest, ScalarReshapeNotMovedAcrossSelect) {
auto select = builder.AddInstruction(HloInstruction::CreateTernary(
root_shape, HloOpcode::kSelect, reshape_pred, param0, param1));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
op::Select(op::Reshape(pred), param0, param1));
@@ -470,7 +470,7 @@ TEST_F(ReshapeMoverTest, ImplicitlyBroadcastReshapeIsNotMovedBug37787999) {
auto multiply = builder.AddInstruction(HloInstruction::CreateBinary(
constant->shape(), HloOpcode::kMultiply, constant, reshape));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
op::Multiply(op::Constant(), op::Reshape(param0)));
@@ -519,7 +519,7 @@ TEST_F(ReshapeMoverTest, MultiplePasses) {
builder.AddInstruction(HloInstruction::CreateBinary(shape3, HloOpcode::kAdd,
reshape2, reshape3));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(
@@ -536,3 +536,7 @@ TEST_F(ReshapeMoverTest, MultiplePasses) {
} // namespace
} // namespace xla
+
+int main(int argc, char** argv) {
+ return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv);
+}
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
index 2a761d70e5..f75487dd74 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
@@ -48,7 +48,7 @@ class TuplePointsToAnalysisTest : public HloTestBase {
}
void BuildModule(std::unique_ptr<HloComputation> computation) {
- module_.reset(new HloModule(TestName()));
+ module_ = CreateNewModule();
module_->AddEntryComputation(std::move(computation));
}
@@ -764,3 +764,7 @@ TEST_F(FusionPointsToAnalysisTest, FusionParam0TwoUsers) {
} // namespace
} // namespace xla
+
+int main(int argc, char** argv) {
+ return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv);
+}
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index e60d38d0c6..4a19a4bb06 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -937,7 +937,6 @@ xla_test(
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/core:lib",
- "//tensorflow/core:test",
],
)
diff --git a/tensorflow/compiler/xla/tests/broadcast_test.cc b/tensorflow/compiler/xla/tests/broadcast_test.cc
index 16d4282466..f3badca679 100644
--- a/tensorflow/compiler/xla/tests/broadcast_test.cc
+++ b/tensorflow/compiler/xla/tests/broadcast_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <utility>
#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -43,7 +44,7 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) {
ShapeUtil::MakeShape(F32, {}), input, {}));
// Create HLO module, compile, and execute.
- auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto hlo_module = CreateNewModule();
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
@@ -59,7 +60,7 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) {
ShapeUtil::MakeShape(F32, {2, 2}), input, {}));
// Create HLO module, compile, and execute.
- auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto hlo_module = CreateNewModule();
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
@@ -82,7 +83,7 @@ XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) {
builder.AddInstruction(HloInstruction::CreateTuple({element1, element2}));
// Create HLO module, compile, and execute.
- auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto hlo_module = CreateNewModule();
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
@@ -103,7 +104,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) {
ShapeUtil::MakeShape(F32, {2, 2}), input, {0, 1}));
// Create HLO module, compile, and execute.
- auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto hlo_module = CreateNewModule();
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
@@ -122,7 +123,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) {
ShapeUtil::MakeShape(F32, {2, 2}), input, {1, 0}));
// Create HLO module, compile, and execute.
- auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto hlo_module = CreateNewModule();
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
@@ -139,7 +140,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) {
ShapeUtil::MakeShape(F32, {2, 3, 2}), input, {0, 2}));
// Create HLO module, compile, and execute.
- auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto hlo_module = CreateNewModule();
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
@@ -159,7 +160,7 @@ TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) {
ShapeUtil::MakeShape(F32, {2, 2, 3, 3}), input, {1}));
// Create HLO module, compile, and execute.
- auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto hlo_module = CreateNewModule();
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
@@ -184,7 +185,7 @@ TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) {
ShapeUtil::MakeShape(F32, {3, 3, 3, r1_size}), input, {3}));
// Create HLO module, compile, and execute.
- auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto hlo_module = CreateNewModule();
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
@@ -215,7 +216,7 @@ XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) {
ShapeUtil::MakeShape(F32, {32, 64, 7, 7}), input, {1}));
// Create HLO module, compile, and execute.
- auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto hlo_module = CreateNewModule();
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
@@ -231,7 +232,7 @@ TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) {
ShapeUtil::MakeShape(F32, {64, 64, 3, 3}), input, {}));
// Create HLO module, compile, and execute.
- auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto hlo_module = CreateNewModule();
hlo_module->AddEntryComputation(builder.Build());
LOG(INFO) << hlo_module->ToString();
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
@@ -254,7 +255,7 @@ TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) {
ShapeUtil::MakeShape(F32, {3, 3, 2, 2}), input, {2, 3}));
// Create HLO module, compile, and execute.
- auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto hlo_module = CreateNewModule();
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
@@ -288,7 +289,7 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) {
ShapeUtil::MakeShape(F32, {2, 3, 4, 5}), input, {0, 1, 2}));
// Create HLO module, compile, and execute.
- auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto hlo_module = CreateNewModule();
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
@@ -302,6 +303,7 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) {
int main(int argc, char** argv) {
std::vector<tensorflow::Flag> flag_list;
xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::legacy_flags::AppendDebugOptionsFlags(&flag_list);
xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
if (!parse_result) {
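
Tests that already carry a bespoke main() for legacy flag parsing keep it and just append the debug-option flags to the existing flag list rather than switching to ParseDebugOptionsFlagsAndRunTests. Completed with the standard tail the hunk truncates (the lines after the parse check are the usual XLA test boilerplate and should be read as an assumption here):

  int main(int argc, char** argv) {
    std::vector<tensorflow::Flag> flag_list;
    xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
    xla::legacy_flags::AppendDebugOptionsFlags(&flag_list);
    xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
    const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
    if (!parse_result) {
      LOG(ERROR) << "\n" << usage;
      return 2;
    }
    testing::InitGoogleTest(&argc, argv);
    return RUN_ALL_TESTS();
  }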
diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc
index 6b28d9b032..4c2413d0fe 100644
--- a/tensorflow/compiler/xla/tests/copy_test.cc
+++ b/tensorflow/compiler/xla/tests/copy_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -45,11 +46,10 @@ class CopyOpTest : public HloTestBase {
builder.AddInstruction(HloInstruction::CreateUnary(
constant->shape(), HloOpcode::kCopy, constant));
auto computation = builder.Build();
- auto hlo_module = MakeUnique<HloModule>("test_module");
- hlo_module->AddEntryComputation(std::move(computation));
+ auto module = CreateNewModule();
+ module->AddEntryComputation(std::move(computation));
- std::unique_ptr<Literal> result =
- ExecuteAndTransfer(std::move(hlo_module), {});
+ std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
LiteralTestUtil::ExpectEqual(literal, *result);
}
@@ -101,11 +101,11 @@ TEST_F(CopyOpTest, CopyParameterScalar) {
auto computation = builder.Build();
- auto hlo_module = MakeUnique<HloModule>("test_module");
- hlo_module->AddEntryComputation(std::move(computation));
+ auto module = CreateNewModule();
+ module->AddEntryComputation(std::move(computation));
std::unique_ptr<Literal> result =
- ExecuteAndTransfer(std::move(hlo_module), {constant_device_base});
+ ExecuteAndTransfer(std::move(module), {constant_device_base});
LiteralTestUtil::ExpectR0Near<float>(42.0f, *result, error_spec_);
}
@@ -123,10 +123,9 @@ TEST_F(CopyOpTest, CopyConstantR2Twice) {
auto computation = builder.Build();
- auto hlo_module = MakeUnique<HloModule>("test_module");
- hlo_module->AddEntryComputation(std::move(computation));
- std::unique_ptr<Literal> result =
- ExecuteAndTransfer(std::move(hlo_module), {});
+ auto module = CreateNewModule();
+ module->AddEntryComputation(std::move(computation));
+ std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
LiteralTestUtil::ExpectR2Near<float>({{1.0, 2.0}, {3.0, 4.0}}, *result,
error_spec_);
}
@@ -149,10 +148,9 @@ TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) {
std::unique_ptr<HloComputation> computation = builder.Build();
- auto hlo_module = MakeUnique<HloModule>("test_module");
- hlo_module->AddEntryComputation(std::move(computation));
- std::unique_ptr<Literal> result =
- ExecuteAndTransfer(std::move(hlo_module), {});
+ auto module = CreateNewModule();
+ module->AddEntryComputation(std::move(computation));
+ std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
// The result of the computation has the default layout, which is the inverse
// of the layout of the source literal.
@@ -182,11 +180,10 @@ void CopyOpTest::TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3) {
std::unique_ptr<HloComputation> computation = builder.Build();
- auto hlo_module = MakeUnique<HloModule>("test_module");
- hlo_module->AddEntryComputation(std::move(computation));
- ForceResultLayout(hlo_module.get(), LayoutUtil::MakeLayout({1, 2, 0}));
- std::unique_ptr<Literal> result =
- ExecuteAndTransfer(std::move(hlo_module), {});
+ auto module = CreateNewModule();
+ module->AddEntryComputation(std::move(computation));
+ ForceResultLayout(module.get(), LayoutUtil::MakeLayout({1, 2, 0}));
+ std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
LiteralTestUtil::ExpectR3EqualArray3D(a, *result);
}
@@ -217,11 +214,10 @@ void CopyOpTest::TestCopyConstantLayoutR4(
std::unique_ptr<HloComputation> computation = builder.Build();
- auto hlo_module = MakeUnique<HloModule>("test_module");
- hlo_module->AddEntryComputation(std::move(computation));
- ForceResultLayout(hlo_module.get(), LayoutUtil::MakeLayout(permutation));
- std::unique_ptr<Literal> result =
- ExecuteAndTransfer(std::move(hlo_module), {});
+ auto module = CreateNewModule();
+ module->AddEntryComputation(std::move(computation));
+ ForceResultLayout(module.get(), LayoutUtil::MakeLayout(permutation));
+ std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
LiteralTestUtil::ExpectR4EqualArray4D(a, *result);
}
@@ -268,6 +264,7 @@ XLA_TEST_F(CopyOpClientTest, Copy0x0) {
int main(int argc, char** argv) {
std::vector<tensorflow::Flag> flag_list;
xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::legacy_flags::AppendDebugOptionsFlags(&flag_list);
xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
if (!parse_result) {
diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc
index f31b703b00..4b5c4ecdf7 100644
--- a/tensorflow/compiler/xla/tests/custom_call_test.cc
+++ b/tensorflow/compiler/xla/tests/custom_call_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <utility>
#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -62,7 +63,7 @@ class CustomCallTest : public HloTestBase {
};
XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR0F32Add2)) {
- auto hlo_module = MakeUnique<HloModule>("test_module");
+ auto module = CreateNewModule();
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
@@ -70,15 +71,14 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR0F32Add2)) {
builder.AddInstruction(
HloInstruction::CreateCustomCall(r0f32_, {constant}, "R0F32Add2"));
- hlo_module->AddEntryComputation(builder.Build());
+ module->AddEntryComputation(builder.Build());
- std::unique_ptr<Literal> result =
- ExecuteAndTransfer(std::move(hlo_module), {});
+ std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
LiteralTestUtil::ExpectR0Near<float>(44.0f, *result, error_spec_);
}
XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) {
- auto hlo_module = MakeUnique<HloModule>("test_module");
+ auto module = CreateNewModule();
auto builder = HloComputation::Builder(TestName());
Array2D<float> array(2, 2);
@@ -92,16 +92,15 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) {
builder.AddInstruction(
HloInstruction::CreateCustomCall(r0f32_, {constant}, "R2F32ReduceSum"));
- hlo_module->AddEntryComputation(builder.Build());
+ module->AddEntryComputation(builder.Build());
- std::unique_ptr<Literal> result =
- ExecuteAndTransfer(std::move(hlo_module), {});
+ std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
LiteralTestUtil::ExpectR0Near<float>(10.0f, *result, error_spec_);
}
XLA_TEST_F(CustomCallTest,
DISABLED_ON_GPU(CustomCall_UsedInOtherComputations)) {
- auto hlo_module = MakeUnique<HloModule>("test_module");
+ auto module = CreateNewModule();
auto b = HloComputation::Builder(TestName());
auto input = b.AddInstruction(
@@ -117,10 +116,9 @@ XLA_TEST_F(CustomCallTest,
HloInstruction::CreateConcatenate(ShapeUtil::MakeShape(F32, {2, 2, 2}),
{incremented, incremented_again}, 0));
- hlo_module->AddEntryComputation(b.Build());
+ module->AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result =
- ExecuteAndTransfer(std::move(hlo_module), {});
+ std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
LiteralTestUtil::ExpectR3EqualArray3D<float>(
Array3D<float>{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, *result);
}
@@ -131,6 +129,7 @@ XLA_TEST_F(CustomCallTest,
int main(int argc, char** argv) {
std::vector<tensorflow::Flag> flag_list;
xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::legacy_flags::AppendDebugOptionsFlags(&flag_list);
xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
if (!parse_result) {
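
ExecuteAndTransfer takes the module by value as a unique_ptr, so each test moves its freshly built module into the call and must not touch the local afterwards. The idiom, as used throughout these files:

  auto module = CreateNewModule();
  module->AddEntryComputation(builder.Build());
  std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
  // `module` is now null; all further checks go through `result`.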
diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc
index 7bddbfa894..4fa08e589c 100644
--- a/tensorflow/compiler/xla/tests/fusion_test.cc
+++ b/tensorflow/compiler/xla/tests/fusion_test.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
@@ -36,7 +37,6 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
-#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
using tensorflow::gtl::ArraySlice;
@@ -74,7 +74,7 @@ class FusionTest : public HloTestBase {
}
auto builder = HloComputation::Builder(TestName());
- auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto hlo_module = CreateNewModule();
auto prim_type = primitive_util::NativeToPrimitiveType<T>();
@@ -176,7 +176,7 @@ XLA_TEST_F(FusionTest, Test) {
// (-{{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}}),
// {{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}})) = {{0.5}, {2.72}}
auto builder = HloComputation::Builder(TestName());
- auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR2<float>({{1.0}, {2.0}, {3.0}})));
auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
@@ -224,7 +224,7 @@ XLA_TEST_F(FusionTest, Parameter) {
// Build a computation and fuse part of it so the fusion instruction has an
// operand parameter.
auto builder = HloComputation::Builder(TestName());
- auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}})));
auto copy1 = builder.AddInstruction(HloInstruction::CreateUnary(
@@ -247,7 +247,7 @@ XLA_TEST_F(FusionTest, Parameter) {
XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) {
auto builder = HloComputation::Builder(TestName());
- auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto hlo_module = CreateNewModule();
auto const_vector = builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0})));
auto const_array = builder.AddInstruction(HloInstruction::CreateConstant(
@@ -271,7 +271,7 @@ XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) {
XLA_TEST_F(FusionTest, ReshapeToScalar) {
auto builder = HloComputation::Builder(TestName());
- auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto hlo_module = CreateNewModule();
auto single_element_array = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR2<int32>({{5}})));
auto reshape = builder.AddInstruction(HloInstruction::CreateReshape(
@@ -285,7 +285,7 @@ XLA_TEST_F(FusionTest, ReshapeToScalar) {
XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) {
auto builder = HloComputation::Builder(TestName());
- auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}})));
auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape(
@@ -300,7 +300,7 @@ XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) {
XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) {
auto builder = HloComputation::Builder(TestName());
- auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}})));
auto reshape1 = builder.AddInstruction(
@@ -315,7 +315,7 @@ XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) {
XLA_TEST_F(FusionTest, Reshape_1by1by1_) {
auto builder = HloComputation::Builder(TestName());
- auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR3<int32>({{{7}}})));
auto reshape1 = builder.AddInstruction(
@@ -329,7 +329,7 @@ XLA_TEST_F(FusionTest, Reshape_1by1by1_) {
XLA_TEST_F(FusionTest, Reshape__1by1by1) {
auto builder = HloComputation::Builder(TestName());
- auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(7)));
auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape(
@@ -343,7 +343,7 @@ XLA_TEST_F(FusionTest, Reshape__1by1by1) {
XLA_TEST_F(FusionTest, Reshape__) {
auto builder = HloComputation::Builder(TestName());
- auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(7)));
auto reshape1 = builder.AddInstruction(
@@ -357,7 +357,7 @@ XLA_TEST_F(FusionTest, Reshape__) {
XLA_TEST_F(FusionTest, Reshape_3by3_3by3) {
auto builder = HloComputation::Builder(TestName());
- auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
auto reshape1 = builder.AddInstruction(
@@ -372,7 +372,7 @@ XLA_TEST_F(FusionTest, Reshape_3by3_3by3) {
XLA_TEST_F(FusionTest, Transpose_2by3) {
auto builder = HloComputation::Builder(TestName());
- auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}})));
auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose(
@@ -387,7 +387,7 @@ XLA_TEST_F(FusionTest, Transpose_2by3) {
XLA_TEST_F(FusionTest, Transpose_3by3) {
auto builder = HloComputation::Builder(TestName());
- auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose(
@@ -402,7 +402,7 @@ XLA_TEST_F(FusionTest, Transpose_3by3) {
XLA_TEST_F(FusionTest, Reverse) {
auto builder = HloComputation::Builder(TestName());
- auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({1, 2, 3})));
auto reverse1 = builder.AddInstruction(HloInstruction::CreateReverse(
@@ -427,7 +427,7 @@ std::unique_ptr<HloComputation> MakeReduceTestComputation() {
}
XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) {
- auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto hlo_module = CreateNewModule();
auto builder = HloComputation::Builder(TestName());
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
@@ -446,7 +446,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) {
}
XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) {
- auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto hlo_module = CreateNewModule();
auto builder = HloComputation::Builder(TestName());
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
@@ -468,7 +468,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) {
XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) {
auto builder = HloComputation::Builder(TestName());
- auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR2<int32>({{2, 3, 5}, {7, 11, 13}, {17, 19, 23}})));
auto const1 = builder.AddInstruction(
@@ -574,6 +574,7 @@ XLA_TEST_F(FusionTest, Clamp2D) {
int main(int argc, char** argv) {
std::vector<tensorflow::Flag> flag_list;
xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::legacy_flags::AppendDebugOptionsFlags(&flag_list);
xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
if (!parse_result) {
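fusion_test.cc keeps a hand-written main() because it must register the CPU-compiler flags in addition to the new debug-options flags. A test binary that needs only the debug-options flags can instead delegate to the ParseDebugOptionsFlagsAndRunTests() helper added to hlo_test_base below:

int main(int argc, char** argv) {
  // Registers and parses the --xla_* debug-options flags, initializes
  // googletest, and runs every test in the binary.
  return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv);
}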
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc
index fbbb101ce9..5f7b7aa434 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.cc
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc
@@ -40,6 +40,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
namespace se = ::perftools::gputools;
@@ -178,4 +179,21 @@ string HloTestBase::TestName() {
return ::testing::UnitTest::GetInstance()->current_test_info()->name();
}
+int ParseDebugOptionsFlagsAndRunTests(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendDebugOptionsFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ ::testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h
index 83c877b393..906551b530 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.h
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.h
@@ -115,6 +115,11 @@ class HloTestBase : public ::testing::Test {
std::unique_ptr<EigenThreadPoolWrapper> thread_pool_wrapper_;
};
+// Convenience function that parses XLA debug options flags from argc/argv,
+// calls InitGoogleTest and then calls and returns RUN_ALL_TESTS. Intended to be
+// invoked from a test main() function.
+int ParseDebugOptionsFlagsAndRunTests(int argc, char** argv);
+
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_TESTS_HLO_TEST_BASE_H_
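Taken together, the two additions give test files a uniform shape: the fixture obtains flag-aware modules from CreateNewModule(), and main() delegates to the new helper. The sketch below is illustrative only; ExampleTest and its body are hypothetical and not part of this commit.

#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/core/platform/test.h"

namespace xla {
namespace {

class ExampleTest : public HloTestBase {};

TEST_F(ExampleTest, ModuleCarriesDebugOptions) {
  // CreateNewModule(), rather than a bare HloModule, ensures the module is
  // configured with the DebugOptions parsed from the command line.
  auto hlo_module = CreateNewModule();
  ASSERT_NE(hlo_module.get(), nullptr);
}

}  // namespace
}  // namespace xla

int main(int argc, char** argv) {
  return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv);
}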