aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service
diff options
context:
space:
mode:
authorGravatar Dimitris Vardoulakis <dimvar@google.com>2018-04-28 22:19:22 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-28 22:21:47 -0700
commitd02745e20c02ba7506a920cc4c8b00415f82ee79 (patch)
tree1caeccbacfa521d13cf790537bfb24c2e7f7f081 /tensorflow/compiler/xla/service
parent3a9c513c3f4303e5194474d804367c1f4831e3ee (diff)
[TF:XLA]
- Require a module config when creating an HloModule. - All tests using HloTestBase create a module using CreateNewModule. PiperOrigin-RevId: 194684585
Diffstat (limited to 'tensorflow/compiler/xla/service')
-rw-r--r--tensorflow/compiler/xla/service/BUILD2
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier_test.cc101
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment_test.cc12
-rw-r--r--tensorflow/compiler/xla/service/graphviz_example.cc3
-rw-r--r--tensorflow/compiler/xla/service/heap_simulator_test.cc6
-rw-r--r--tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc8
-rw-r--r--tensorflow/compiler/xla/service/hlo_creation_utils_test.cc51
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.cc3
-rw-r--r--tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc18
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction_test.cc122
-rw-r--r--tensorflow/compiler/xla/service/hlo_module.cc6
-rw-r--r--tensorflow/compiler/xla/service/hlo_module.h1
-rw-r--r--tensorflow/compiler/xla/service/transpose_folding_test.cc50
-rw-r--r--tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc6
14 files changed, 199 insertions, 190 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index f39bfb8012..ed0da47681 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -1330,6 +1330,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
"//tensorflow/core:lib",
@@ -2420,6 +2421,7 @@ tf_cc_test(
":hlo_graph_dumper",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:xla_proto",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
"//tensorflow/core:lib",
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 20c549562d..d0c99bf818 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/window_util.h"
@@ -1699,14 +1700,14 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) {
builder.AddInstruction(HloInstruction::CreatePad(
ShapeUtil::MakeShape(F32, {2, 2}), param, zero, no_padding));
- HloModule module(TestName());
- HloComputation* computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ HloComputation* computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), param);
}
@@ -1732,8 +1733,8 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) {
HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
ShapeUtil::MakeShape(F32, {11, 5}), param, zero, padding));
- HloModule module(TestName());
- HloComputation* computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ HloComputation* computation = module->AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
@@ -1751,7 +1752,7 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) {
EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero));
EXPECT_TRUE(has_negative_padding(pad));
- ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), op::Slice(op::Pad(param, zero)));
EXPECT_FALSE(
@@ -1766,14 +1767,14 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopReshape) {
builder.AddInstruction(
HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {2, 3}), param));
- HloModule module(TestName());
- HloComputation* computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ HloComputation* computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Reshape(param));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), param);
}
@@ -1789,14 +1790,14 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopSlice) {
ShapeUtil::MakeShape(F32, {dim0, dim1}), param, /*start_indices=*/{0, 0},
/*limit_indices=*/{dim0, dim1}, /*strides=*/{1, 1}));
- HloModule module(TestName());
- HloComputation* computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ HloComputation* computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Slice(param));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), param);
}
@@ -1924,12 +1925,12 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
b.AddInstruction(HloInstruction::CreateConvolve(out_shape, input, filter,
window, dnums));
- HloModule module(TestName());
- auto* computation = module.AddEntryComputation(b.Build());
+ auto module = CreateNewModule();
+ auto* computation = module->AddEntryComputation(b.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
bitcasting_callback());
- if (!simplifier.Run(&module).ValueOrDie()) {
+ if (!simplifier.Run(module.get()).ValueOrDie()) {
return "NO_CHANGE";
}
auto* root = computation->root_instruction();
@@ -2044,15 +2045,15 @@ TEST_F(AlgebraicSimplifierTest, MaxMinToClamp) {
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kMaximum, min, max_value));
- HloModule module(TestName());
- auto computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
op::Maximum(op::Minimum(param0, min_value), max_value));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Clamp(max_value, param0, min_value));
@@ -2074,15 +2075,15 @@ TEST_F(AlgebraicSimplifierTest, MinMaxToClamp) {
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kMinimum, max, min_value));
- HloModule module(TestName());
- auto computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
op::Minimum(op::Maximum(param0, max_value), min_value));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Clamp(max_value, param0, min_value));
@@ -2105,15 +2106,15 @@ TEST_F(AlgebraicSimplifierTest, MinMaxWithBroadcastToClamp) {
builder.AddInstruction(
HloInstruction::CreateBinary(r1f32, HloOpcode::kMinimum, max, min_value));
- HloModule module(TestName());
- auto computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
op::Minimum(op::Maximum(param0, max_value), min_value));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Clamp(max_value, param0, min_value));
@@ -2135,15 +2136,15 @@ TEST_F(AlgebraicSimplifierTest, MinMaxNotToClamp) {
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kMinimum, max, min_value));
- HloModule module(TestName());
- auto computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
op::Minimum(op::Maximum(param0, max_value), min_value));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- EXPECT_FALSE(simplifier.Run(&module).ValueOrDie());
+ EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Minimum(op::Maximum(param0, max_value), min_value));
@@ -2167,8 +2168,8 @@ TEST_F(AlgebraicSimplifierTest, MinEquationWithMaxNotToClamp) {
builder.AddInstruction(HloInstruction::CreateBinary(
r0f32, HloOpcode::kMinimum, fmax, min_value));
- HloModule module(TestName());
- auto computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
op::Minimum(op::Add(op::Maximum(param0, max_value), max_value),
@@ -2176,7 +2177,7 @@ TEST_F(AlgebraicSimplifierTest, MinEquationWithMaxNotToClamp) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- EXPECT_FALSE(simplifier.Run(&module).ValueOrDie());
+ EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Minimum(op::Add(op::Maximum(param0, max_value), max_value),
@@ -2201,8 +2202,8 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) {
HloInstruction* slice = builder.AddInstruction(HloInstruction::CreateSlice(
slice_shape, broadcast, {0, 1, 2, 3}, {2, 3, 5, 6}, {1, 1, 1, 1}));
- HloModule module(TestName());
- auto computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root, slice);
@@ -2211,10 +2212,10 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
// Running simplification again should not result in any further changes.
- ASSERT_FALSE(simplifier.Run(&module).ValueOrDie());
+ ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
root = computation->root_instruction();
EXPECT_THAT(root, op::Broadcast(scalar_param));
@@ -2242,8 +2243,8 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) {
HloInstruction* reshape = builder.AddInstruction(
HloInstruction::CreateReshape(reshape_shape, transpose));
- HloModule module(TestName());
- auto computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root, reshape);
@@ -2251,7 +2252,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
root = computation->root_instruction();
EXPECT_THAT(root, op::Broadcast(forty_two));
@@ -2260,7 +2261,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) {
// Test that ReduceWindow(Pad(op, x), y) can simplify to ReduceWindow(op, x).
TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
- HloModule module(TestName());
+ auto module = CreateNewModule();
HloComputation::Builder builder(TestName());
// Create operand to the pad.
@@ -2289,7 +2290,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
HloInstruction::CreateParameter(1, scalar_shape, "p1"));
builder.AddInstruction(
HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
- add_computation = module.AddEmbeddedComputation(builder.Build());
+ add_computation = module->AddEmbeddedComputation(builder.Build());
}
// Create the reduce-window.
@@ -2312,15 +2313,15 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
add_computation));
// Build the computation and run the simplifier.
- auto computation = module.AddEntryComputation(builder.Build());
+ auto computation = module->AddEntryComputation(builder.Build());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root, reduce_window);
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
// Running simplification again should not result in any further changes.
- ASSERT_FALSE(simplifier.Run(&module).ValueOrDie());
+ ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
// Verify the result
root = computation->root_instruction();
@@ -2341,7 +2342,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
// Test that ReduceWindow(Convert(Pad(op, x)), y) can simplify to
// ReduceWindow(Convert(op), x).
TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) {
- HloModule module(TestName());
+ auto module = CreateNewModule();
HloComputation::Builder builder(TestName());
// Create operand to the pad.
@@ -2374,7 +2375,7 @@ TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) {
HloInstruction::CreateParameter(1, scalar_shape, "p1"));
builder.AddInstruction(
HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
- add_computation = module.AddEmbeddedComputation(builder.Build());
+ add_computation = module->AddEmbeddedComputation(builder.Build());
}
// Create the reduce-window.
@@ -2397,15 +2398,15 @@ TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) {
add_computation));
// Build the computation and run the simplifier.
- auto computation = module.AddEntryComputation(builder.Build());
+ auto computation = module->AddEntryComputation(builder.Build());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root, reduce_window);
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
// Running simplification again should not result in any further changes.
- ASSERT_FALSE(simplifier.Run(&module).ValueOrDie());
+ ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
// Verify the result
root = computation->root_instruction();
@@ -2431,12 +2432,12 @@ TEST_F(AlgebraicSimplifierTest, ReversalOfTrivialDimensionsToBitcast) {
builder.AddInstruction(
HloInstruction::CreateReverse(shape, a, /*dimensions=*/{2, 3}));
- HloModule module(TestName());
- auto computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(a, root);
diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
index 513a8785bb..3ec9795a65 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
@@ -1641,7 +1641,7 @@ static void RunCopyInsertion(HloModule* module) {
}
TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) {
- auto module = xla::MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto builder = HloComputation::Builder("entry");
auto input0 = builder.AddInstruction(
@@ -1816,7 +1816,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) {
};
// Build the entry computation as described in the comment above.
- auto module = xla::MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto builder = HloComputation::Builder("entry");
auto infeed = builder.AddInstruction(HloInstruction::CreateInfeed(r0s32, ""));
@@ -1884,7 +1884,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) {
}
TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) {
- auto module = xla::MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto builder = HloComputation::Builder("entry");
auto input0 = builder.AddInstruction(
@@ -1929,7 +1929,7 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) {
}
TEST_F(BufferAssignmentTest, TwoCalls) {
- auto module = xla::MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
Shape r0f32 = ShapeUtil::MakeShape(xla::F32, {});
HloComputation* sub_computation;
{
@@ -1994,7 +1994,7 @@ static bool IsPostOrderTraversal(
}
TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
- auto module = xla::MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto builder = HloComputation::Builder(TestName());
auto zero = builder.AddInstruction(
@@ -2073,7 +2073,7 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
}
TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) {
- auto module = xla::MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto builder = HloComputation::Builder("entry");
auto input0 = builder.AddInstruction(
diff --git a/tensorflow/compiler/xla/service/graphviz_example.cc b/tensorflow/compiler/xla/service/graphviz_example.cc
index 05017008e2..acf6611486 100644
--- a/tensorflow/compiler/xla/service/graphviz_example.cc
+++ b/tensorflow/compiler/xla/service/graphviz_example.cc
@@ -82,7 +82,8 @@ HloComputation* CallForwardingComputation(HloComputation* computation,
// instructions. Sets the computation as the entry to an HLO module and returns
// the module.
std::unique_ptr<HloModule> MakeBigGraph() {
- auto module = MakeUnique<HloModule>("BigGraph");
+ HloModuleConfig config;
+ auto module = MakeUnique<HloModule>("BigGraph", config);
auto builder = HloComputation::Builder("TestBigGraphvizGraph");
diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc
index 688a271712..e983fd11d4 100644
--- a/tensorflow/compiler/xla/service/heap_simulator_test.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc
@@ -76,7 +76,8 @@ class HeapSimulatorTracker {
HeapSimulatorTracker(
const string& name, std::unique_ptr<HloComputation> computation,
const std::vector<const HloInstruction*>& instruction_sequence) {
- module_ = MakeUnique<HloModule>(name);
+ HloModuleConfig config;
+ module_ = MakeUnique<HloModule>(name, config);
module_->AddEntryComputation(std::move(computation));
points_to_analysis_ =
TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie();
@@ -94,7 +95,8 @@ class HeapSimulatorTracker {
}
explicit HeapSimulatorTracker(const string& name) {
- module_ = MakeUnique<HloModule>(name);
+ HloModuleConfig config;
+ module_ = MakeUnique<HloModule>(name, config);
}
// Similar to the single entry computation constructor above, but runs the
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
index 3d055b327e..81cc7c4bdc 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
@@ -370,8 +370,8 @@ TEST_F(FusionCostAnalysis, LoopFusion) {
HloInstruction::CreateBinary(r2f32, HloOpcode::kSubtract, mul, clamp));
auto tuple = HloInstruction::CreateTuple({sub, sub, mul, c1});
- HloModule module(TestName());
- auto* computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto* computation = module->AddEntryComputation(builder.Build());
auto* fusion = computation->CreateFusionInstruction(
{sub, mul, exp, clamp, add}, HloInstruction::FusionKind::kLoop);
@@ -412,8 +412,8 @@ TEST_F(FusionCostAnalysis, NoLayout) {
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
shape_with_layout, HloOpcode::kAdd, c1, broadcast));
- HloModule module(TestName());
- auto* computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto* computation = module->AddEntryComputation(builder.Build());
auto* fusion = computation->CreateFusionInstruction(
{add, broadcast}, HloInstruction::FusionKind::kLoop);
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc
index 6b681a5bf6..7e7c4f95fe 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc
@@ -19,27 +19,32 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
namespace {
using tensorflow::gtl::ArraySlice;
-std::unique_ptr<HloModule> CreateModuleWithProgramShape(
- PrimitiveType primitive_type, ArraySlice<int64> input_shape_dims,
- ArraySlice<int64> output_shape_dims, HloInstruction** param,
- HloComputation** entry_computation) {
- Shape input_shape = ShapeUtil::MakeShape(primitive_type, input_shape_dims);
- Shape output_shape = ShapeUtil::MakeShape(primitive_type, output_shape_dims);
- std::unique_ptr<HloModule> module = MakeUnique<HloModule>("test");
- *entry_computation = module->AddEntryComputation(
- CreateComputationWithSignature({&input_shape}, output_shape, "entry")
- .ValueOrDie());
- *param = (*entry_computation)->parameter_instruction(0);
- return module;
-}
-
-TEST(HloCreationUtilsTest, CollapseFirst1Dim) {
+class HloCreationUtilsTest : public HloTestBase {
+ protected:
+ static std::unique_ptr<HloModule> CreateModuleWithProgramShape(
+ PrimitiveType primitive_type, ArraySlice<int64> input_shape_dims,
+ ArraySlice<int64> output_shape_dims, HloInstruction** param,
+ HloComputation** entry_computation) {
+ Shape input_shape = ShapeUtil::MakeShape(primitive_type, input_shape_dims);
+ Shape output_shape =
+ ShapeUtil::MakeShape(primitive_type, output_shape_dims);
+ auto module = CreateNewModule("test");
+ *entry_computation = module->AddEntryComputation(
+ CreateComputationWithSignature({&input_shape}, output_shape, "entry")
+ .ValueOrDie());
+ *param = (*entry_computation)->parameter_instruction(0);
+ return module;
+ }
+};
+
+TEST_F(HloCreationUtilsTest, CollapseFirst1Dim) {
HloInstruction* param;
HloComputation* entry_computation;
@@ -59,7 +64,7 @@ TEST(HloCreationUtilsTest, CollapseFirst1Dim) {
CHECK_EQ(*result_literal, *Literal::CreateR1<int32>({3, 4}));
}
-TEST(HloCreationUtilsTest, CollapseFirst2Dims) {
+TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) {
HloInstruction* param;
HloComputation* entry_computation;
@@ -84,7 +89,7 @@ TEST(HloCreationUtilsTest, CollapseFirst2Dims) {
{{1, 2}, {3, 4}, {5, 6}, {-1, -2}, {-3, -4}, {-5, -6}}));
}
-TEST(HloCreationUtilsTest, Prepend1DegenerateDim) {
+TEST_F(HloCreationUtilsTest, Prepend1DegenerateDim) {
HloInstruction* param;
HloComputation* entry_computation;
@@ -104,7 +109,7 @@ TEST(HloCreationUtilsTest, Prepend1DegenerateDim) {
CHECK_EQ(*result_literal, *Literal::CreateR2<int32>({{9, 10}}));
}
-TEST(HloCreationUtilsTest, Prepend2DegenerateDims) {
+TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) {
HloInstruction* param;
HloComputation* entry_computation;
@@ -124,7 +129,7 @@ TEST(HloCreationUtilsTest, Prepend2DegenerateDims) {
CHECK_EQ(*result_literal, *Literal::CreateR3<int32>({{{9, 10}}}));
}
-TEST(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) {
+TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) {
HloInstruction* param;
HloComputation* entry_computation;
@@ -144,7 +149,7 @@ TEST(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) {
CHECK_EQ(*result_literal, *Literal::CreateR2<int32>({{9}}));
}
-TEST(HloCreationUtilsTest, ExpandFirstDimInto3Dims) {
+TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) {
HloInstruction* param;
HloComputation* entry_computation;
@@ -166,7 +171,7 @@ TEST(HloCreationUtilsTest, ExpandFirstDimInto3Dims) {
*Literal::CreateR3<int32>({{{1, 2}}, {{3, 4}}, {{5, 6}}}));
}
-TEST(HloCreationUtilsTest, PadVectorWithZeros) {
+TEST_F(HloCreationUtilsTest, PadVectorWithZeros) {
HloInstruction* param;
HloComputation* entry_computation;
@@ -187,7 +192,7 @@ TEST(HloCreationUtilsTest, PadVectorWithZeros) {
CHECK_EQ(*result_literal, *Literal::CreateR1<int32>({0, 0, 0, 3, 4, 0}));
}
-TEST(HloCreationUtilsTest, BroadcastZeros_S32) {
+TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) {
HloInstruction* param;
HloComputation* entry_computation;
@@ -208,7 +213,7 @@ TEST(HloCreationUtilsTest, BroadcastZeros_S32) {
CHECK_EQ(*result_literal, *Literal::CreateR2<int32>({{0, 0}, {0, 0}}));
}
-TEST(HloCreationUtilsTest, BroadcastZeros_F32) {
+TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) {
HloInstruction* param;
HloComputation* entry_computation;
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index f1dcef1dfc..8cf94123b7 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -2968,9 +2968,10 @@ Status HloEvaluator::HandleCall(HloInstruction* call) {
}
Status HloEvaluator::HandleFusion(HloInstruction* fusion) {
+ HloModuleConfig config;
// Attach cloned computation to an empty HLO module so the existing ones are
// not modified.
- HloModule empty_hlo_module("EmptyModuleForFusion");
+ HloModule empty_hlo_module("EmptyModuleForFusion", config);
auto cloned_fused_computation =
fusion->fused_instructions_computation()->Clone(
/*suffix=*/"clone_with_layout", &empty_hlo_module);
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc
index 1f00aa41dc..b589cd573d 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/core/lib/strings/strcat.h"
@@ -47,7 +48,9 @@ class DotRenderer : public hlo_graph_dumper::GraphRendererInterface {
XLA_REGISTER_GRAPH_RENDERER(DotRenderer);
-TEST(HloGraphDumperTest, NestedFusion) {
+class HloGraphDumperTest : public HloTestBase {};
+
+TEST_F(HloGraphDumperTest, NestedFusion) {
HloComputation::Builder b("b");
// Build param0 + param1 + param2 + param3 + param4.
@@ -64,10 +67,9 @@ TEST(HloGraphDumperTest, NestedFusion) {
sums.push_back(b.AddInstruction(HloInstruction::CreateBinary(
shape, HloOpcode::kAdd, sums[i], params[i + 2])));
}
-
- HloModule m(TestName());
- m.AddEntryComputation(b.Build());
- HloComputation* root_computation = m.entry_computation();
+ auto m = CreateNewModule();
+ m->AddEntryComputation(b.Build());
+ HloComputation* root_computation = m->entry_computation();
// Fuse into fusion(param0 + param1 + param2 + param3 + param4).
auto* outer_fusion = root_computation->CreateFusionInstruction(
@@ -117,13 +119,13 @@ TEST(HloGraphDumperTest, NestedFusion) {
HasSubstr(inner_sum->name()));
}
-TEST(HloGraphDumperTest, Constant) {
+TEST_F(HloGraphDumperTest, Constant) {
HloComputation::Builder b("b");
auto instruction = b.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(-42)));
instruction->set_name("i_am_a_constant_root_instruction");
- HloModule m(TestName());
- HloComputation* root_computation = m.AddEntryComputation(b.Build());
+ auto m = CreateNewModule();
+ HloComputation* root_computation = m->AddEntryComputation(b.Build());
string graph = hlo_graph_dumper::DumpGraph(
*root_computation, /*label=*/"an_empty_graph", DebugOptions());
EXPECT_THAT(graph, HasSubstr("an_empty_graph"));
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
index f2980d309d..5b65b1152c 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
@@ -149,8 +149,8 @@ TEST_F(HloInstructionTest, UserWithTwoOperands) {
builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "bar"));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
EXPECT_THAT(add->operands(), UnorderedElementsAre(foo, bar));
EXPECT_THAT(foo->users(), UnorderedElementsAre(add));
@@ -186,8 +186,8 @@ TEST_F(HloInstructionTest, MultipleUsers) {
HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
EXPECT_EQ(3, foo->user_count());
EXPECT_EQ(1, bar->user_count());
@@ -219,8 +219,8 @@ TEST_F(HloInstructionTest, RepeatedUser) {
builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo"));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, foo));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
EXPECT_EQ(1, foo->user_count());
@@ -254,8 +254,8 @@ TEST_F(HloInstructionTest, MultipleUsersAndOperands) {
HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c0, param1));
auto addtotal = builder.AddInstruction(
HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, addleft, addright));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
OpAndUserCollectingVisitor visitor;
ASSERT_IS_OK(addtotal->Accept(&visitor));
@@ -303,8 +303,8 @@ TEST_F(HloInstructionTest, MultipleUsersAndOperandsWithUnaryOps) {
HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, addleft, addright));
auto neg2 = builder.AddInstruction(
HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, addtotal));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
OpAndUserCollectingVisitor visitor;
ASSERT_IS_OK(neg2->Accept(&visitor));
@@ -325,7 +325,7 @@ TEST_F(HloInstructionTest, TrivialMap) {
//
Shape r0f32 = ShapeUtil::MakeShape(F32, {});
Shape f32a100x10 = ShapeUtil::MakeShape(F32, {100, 10});
- HloModule module(TestName());
+ auto module = CreateNewModule();
// Builds an x+1.0 computation to use in a Map.
auto embedded_builder = HloComputation::Builder("f32+1");
@@ -335,7 +335,7 @@ TEST_F(HloInstructionTest, TrivialMap) {
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
embedded_builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param, value));
- auto add_f32 = module.AddEmbeddedComputation(embedded_builder.Build());
+ auto add_f32 = module->AddEmbeddedComputation(embedded_builder.Build());
// Builds a parameter and feeds it to the map.
HloComputation::Builder builder(TestName());
@@ -343,7 +343,7 @@ TEST_F(HloInstructionTest, TrivialMap) {
HloInstruction::CreateParameter(0, f32a100x10, ""));
auto map = builder.AddInstruction(
HloInstruction::CreateMap(f32a100x10, {param0}, add_f32));
- module.AddEntryComputation(builder.Build());
+ module->AddEntryComputation(builder.Build());
OpAndUserCollectingVisitor visitor;
ASSERT_IS_OK(map->Accept(&visitor));
@@ -373,8 +373,8 @@ TEST_F(HloInstructionTest, TrivialReduce) {
HloInstruction::CreateParameter(1, r0f32, "y"));
embedded_builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, paramx, paramy));
- HloModule module(TestName());
- auto add_f32 = module.AddEmbeddedComputation(embedded_builder.Build());
+ auto module = CreateNewModule();
+ auto add_f32 = module->AddEmbeddedComputation(embedded_builder.Build());
// Builds a parameter and an initial value and feeds them to the reduce.
HloComputation::Builder builder(TestName());
@@ -387,7 +387,7 @@ TEST_F(HloInstructionTest, TrivialReduce) {
auto reduce = builder.AddInstruction(
HloInstruction::CreateReduce(f32v100, param0, const0,
/*dimensions_to_reduce=*/{1}, add_f32));
- module.AddEntryComputation(builder.Build());
+ module->AddEntryComputation(builder.Build());
OpAndUserCollectingVisitor visitor;
ASSERT_IS_OK(reduce->Accept(&visitor));
@@ -414,8 +414,8 @@ TEST_F(HloInstructionTest, ReplaceUseInBinaryOps) {
HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, foo));
builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd,
add_foobar, add_foofoo));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
EXPECT_EQ(2, foo->user_count());
EXPECT_EQ(1, bar->user_count());
@@ -449,8 +449,8 @@ TEST_F(HloInstructionTest, ReplaceUseInVariadicOp) {
builder.AddInstruction(HloInstruction::CreateTuple({foo, bar, baz, foo}));
auto add_foobar = builder.AddInstruction(
HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
EXPECT_EQ(2, foo->user_count());
EXPECT_THAT(foo->users(), UnorderedElementsAre(tuple, add_foobar));
@@ -477,8 +477,8 @@ TEST_F(HloInstructionTest, ReplaceUseInUnaryOp) {
HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo));
auto log = builder.AddInstruction(
HloInstruction::CreateUnary(r0f32_, HloOpcode::kLog, foo));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
EXPECT_EQ(2, foo->user_count());
EXPECT_THAT(foo->users(), UnorderedElementsAre(exp, log));
@@ -514,8 +514,8 @@ TEST_F(HloInstructionTest, ReplaceAllUsesWithInBinaryOps) {
HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, foo));
builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd,
add_foobar, add_foofoo));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
EXPECT_EQ(2, foo->user_count());
EXPECT_EQ(1, bar->user_count());
@@ -544,8 +544,8 @@ TEST_F(HloInstructionTest, ReplaceAllUsesInMultipleOps) {
auto exp = builder.AddInstruction(
HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo));
auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({foo, bar}));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
EXPECT_EQ(3, foo->user_count());
EXPECT_EQ(2, bar->user_count());
@@ -609,8 +609,8 @@ TEST_F(HloInstructionTest, PostProcessAllVisitedNodes) {
HloInstruction::CreateUnary(r0f32_, HloOpcode::kLog, foo));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, exp, log));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
NodeCollectorAndPostProcessor visitor;
ASSERT_IS_OK(add->Accept(&visitor));
@@ -627,8 +627,8 @@ TEST_F(HloInstructionTest, SingletonFusionOp) {
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
auto exp = builder.AddInstruction(
HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant));
- HloModule module(TestName());
- auto* computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto* computation = module->AddEntryComputation(builder.Build());
auto* fusion = computation->CreateFusionInstruction(
{exp}, HloInstruction::FusionKind::kLoop);
@@ -645,8 +645,8 @@ TEST_F(HloInstructionTest, BinaryFusionOp) {
HloInstruction::CreateConstant(Literal::CreateR0<float>(42.1f)));
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
r0f32_, HloOpcode::kAdd, constant1, constant2));
- HloModule module(TestName());
- auto* computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto* computation = module->AddEntryComputation(builder.Build());
auto* fusion = computation->CreateFusionInstruction(
{add}, HloInstruction::FusionKind::kLoop);
@@ -667,8 +667,8 @@ TEST_F(HloInstructionTest, ChainFusionOp) {
auto exp3 = builder.AddInstruction(
HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp2));
- HloModule module(TestName());
- auto* computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto* computation = module->AddEntryComputation(builder.Build());
auto* fusion = computation->CreateFusionInstruction(
{exp3, exp2, exp1}, HloInstruction::FusionKind::kLoop);
@@ -690,8 +690,8 @@ TEST_F(HloInstructionTest, PreserveMetadataInFusionAndClone) {
exp1->set_metadata(metadata);
exp2->set_metadata(metadata);
- HloModule module(TestName());
- auto* computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto* computation = module->AddEntryComputation(builder.Build());
auto* fusion = computation->CreateFusionInstruction(
{exp2, exp1}, HloInstruction::FusionKind::kLoop);
@@ -746,13 +746,13 @@ TEST_F(HloInstructionTest, PreserveTupleShapeThroughClone) {
TEST_F(HloInstructionTest, FusionOpWithCalledComputations) {
// Create a fusion instruction containing a single unary operation.
const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
- HloModule module(TestName());
+ auto module = CreateNewModule();
auto make_map_computation = [&]() {
auto builder = HloComputation::Builder("FusionMap");
builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape, "param"));
- return module.AddEmbeddedComputation(builder.Build());
+ return module->AddEmbeddedComputation(builder.Build());
};
HloComputation* computation_x = make_map_computation();
@@ -767,7 +767,7 @@ TEST_F(HloInstructionTest, FusionOpWithCalledComputations) {
scalar_shape, {map_1_x}, computation_x, /*static_operands=*/{}));
auto map_3_y = builder.AddInstruction(HloInstruction::CreateMap(
scalar_shape, {map_2_x}, computation_y, /*static_operands=*/{}));
- auto* computation = module.AddEntryComputation(builder.Build());
+ auto* computation = module->AddEntryComputation(builder.Build());
auto* fusion = computation->CreateFusionInstruction(
{map_3_y}, HloInstruction::FusionKind::kLoop);
@@ -814,8 +814,8 @@ TEST_F(HloInstructionTest, ComplexFusionOp) {
auto tuple =
builder.AddInstruction(HloInstruction::CreateTuple({sub, sub, mul, c1}));
- HloModule module(TestName());
- auto* computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto* computation = module->AddEntryComputation(builder.Build());
auto* fusion = computation->CreateFusionInstruction(
{tuple, sub, mul, exp, clamp, add}, HloInstruction::FusionKind::kLoop);
@@ -940,8 +940,8 @@ TEST_F(HloInstructionTest, FunctionVisitor) {
HloInstruction::CreateUnary(f32, HloOpcode::kExp, param));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(f32, HloOpcode::kAdd, negate, exp));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
int visit_num = 0;
std::unordered_map<HloInstruction*, int> visit_order;
@@ -969,8 +969,8 @@ TEST_F(HloInstructionTest, FullyElementwise) {
builder.AddInstruction(HloInstruction::CreateParameter(1, r1f32, "y"));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, x, y));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
EXPECT_TRUE(add->IsElementwise());
for (int i = 0; i < add->operand_count(); ++i) {
@@ -1013,8 +1013,8 @@ TEST_F(HloInstructionTest, PartiallyElementwise) {
HloInstruction* max = builder.AddInstruction(
HloInstruction::CreateBinary(r2f32, HloOpcode::kMaximum, div, broadcast));
- HloModule module(TestName());
- auto* computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto* computation = module->AddEntryComputation(builder.Build());
HloInstruction* fusion = computation->CreateFusionInstruction(
{max, broadcast, div, mul}, HloInstruction::FusionKind::kLoop);
EXPECT_FALSE(fusion->IsElementwise());
@@ -1056,8 +1056,8 @@ TEST_F(HloInstructionTest, PartiallyElementwiseWithReuse) {
HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary(
r1f32, HloOpcode::kSubtract, min, broadcast));
- HloModule module(TestName());
- auto* computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto* computation = module->AddEntryComputation(builder.Build());
HloInstruction* fusion = computation->CreateFusionInstruction(
{sub, broadcast, min}, HloInstruction::FusionKind::kLoop);
EXPECT_FALSE(fusion->IsElementwise());
@@ -1099,8 +1099,8 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) {
HloInstruction* dot = builder.AddInstruction(
HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
- HloModule module(TestName());
- auto* computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto* computation = module->AddEntryComputation(builder.Build());
HloInstruction* fusion = computation->CreateFusionInstruction(
{dot, reshape}, HloInstruction::FusionKind::kTransposeDot);
@@ -1118,7 +1118,7 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) {
}
TEST_F(HloInstructionTest, FusionEquality) {
- HloModule module(TestName());
+ auto module = CreateNewModule();
HloComputation::Builder builder(TestName());
// Create two fusion instructions containing a single unary operation.
@@ -1128,7 +1128,7 @@ TEST_F(HloInstructionTest, FusionEquality) {
HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, parameter));
auto neg = builder.AddInstruction(
HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, parameter));
- auto* computation = module.AddEntryComputation(builder.Build());
+ auto* computation = module->AddEntryComputation(builder.Build());
auto* fusion = computation->CreateFusionInstruction(
{exp}, HloInstruction::FusionKind::kLoop);
auto* fusion2 = computation->CreateFusionInstruction(
@@ -1140,7 +1140,7 @@ TEST_F(HloInstructionTest, FusionEquality) {
}
TEST_F(HloInstructionTest, NestedFusionEquality) {
- HloModule module(TestName());
+ auto module = CreateNewModule();
HloComputation::Builder builder(TestName());
// Build a nested fusion computation.
@@ -1166,7 +1166,7 @@ TEST_F(HloInstructionTest, NestedFusionEquality) {
data_shape, HloOpcode::kSubtract, dot, add_operand));
builder.AddInstruction(
HloInstruction::CreateBinary(data_shape, HloOpcode::kMultiply, add, sub));
- auto computation = module.AddEntryComputation(builder.Build());
+ auto computation = module->AddEntryComputation(builder.Build());
auto nested_fusion = computation->CreateFusionInstruction(
{dot, b_t}, HloInstruction::FusionKind::kTransposeDot);
@@ -1244,8 +1244,8 @@ TEST_F(HloInstructionTest, Stringification) {
"%dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} "
"%transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}");
- HloModule module(TestName());
- auto* computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto* computation = module->AddEntryComputation(builder.Build());
HloInstruction* fusion = computation->CreateFusionInstruction(
{dot, reshape}, HloInstruction::FusionKind::kTransposeDot);
@@ -1295,8 +1295,8 @@ TEST_F(HloInstructionTest, StringifyGather_0) {
/*index_vector_dim=*/4),
/*window_bounds=*/{30, 29, 28, 27, 26}));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
EXPECT_EQ(gather_instruction->ToString(),
"%gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} "
@@ -1331,8 +1331,8 @@ TEST_F(HloInstructionTest, StringifyGather_1) {
/*index_vector_dim=*/2),
/*window_bounds=*/{30, 29, 28, 27, 26}));
- HloModule module(TestName());
- module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
EXPECT_EQ(gather_instruction->ToString(),
"%gather = f32[10,9,7,6,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} "
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index 08b9a29aed..d4bad16f79 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -41,9 +41,6 @@ HloModule::HloModule(const string& name,
entry_computation_handle_(entry_computation_handle),
unique_id_(next_unique_module_id_++) {}
-HloModule::HloModule(const string& name)
- : name_(NameUniquer::GetSanitizedName(name)),
- unique_id_(next_unique_module_id_++) {}
HloModule::HloModule(const string& name, const HloModuleConfig& config)
: name_(NameUniquer::GetSanitizedName(name)),
config_(config),
@@ -479,8 +476,7 @@ std::vector<HloComputation*> HloModule::MakeNonfusionComputations() const {
std::unique_ptr<HloModule> HloModule::Clone(const string& suffix) const {
VLOG(1) << "Cloning module :" << name_ << " --> " << suffix << "\n";
- auto module = MakeUnique<HloModule>(name_ + "-" + suffix);
- module->config_ = config_;
+ auto module = MakeUnique<HloModule>(name_ + "-" + suffix, config_);
module->entry_computation_handle_ = entry_computation_handle_;
module->has_entry_computation_handle_ = has_entry_computation_handle_;
diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h
index 9f7f25202b..aa843ead51 100644
--- a/tensorflow/compiler/xla/service/hlo_module.h
+++ b/tensorflow/compiler/xla/service/hlo_module.h
@@ -55,7 +55,6 @@ class HloModule {
// only be used for HloModules used outside of the XLA service (eg
// tests). The versioned handle is used by the service in the compilation
// cache. A default configuration is created for this module.
- explicit HloModule(const string& name);
explicit HloModule(const string& name, const HloModuleConfig& config);
// Adds an entry computation to the module. A module can only have one entry
diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc
index caa1a111ad..c7c4160345 100644
--- a/tensorflow/compiler/xla/service/transpose_folding_test.cc
+++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc
@@ -71,10 +71,10 @@ TEST_F(TransposeFoldingTest, FoldDotTranspose) {
HloInstruction::CreateDot(ShapeUtil::MakeShape(F32, {2, 2}), /*lhs=*/x,
/*rhs=*/transpose_y, dot_dnums));
- HloModule module("test_module");
+ auto module = CreateNewModule("test_module");
HloComputation* entry_computation =
- module.AddEntryComputation(builder.Build(dot));
- FoldTranspose(&module);
+ module->AddEntryComputation(builder.Build(dot));
+ FoldTranspose(module.get());
// Instructions after folding: x, y, and the fusion.
std::unordered_set<HloInstruction*> instruction_set(
@@ -114,10 +114,10 @@ TEST_F(TransposeFoldingTest, FoldDotTransposeConstant) {
ShapeUtil::MakeShape(F32, {1, 3}),
/*lhs=*/transpose0, /*rhs=*/transpose1, dot_dnums));
- HloModule module("test_module");
+ auto module = CreateNewModule("test_module");
HloComputation* entry_computation =
- module.AddEntryComputation(builder.Build(dot));
- FoldTranspose(&module);
+ module->AddEntryComputation(builder.Build(dot));
+ FoldTranspose(module.get());
for (auto* instruction : entry_computation->instructions()) {
if (instruction->opcode() == HloOpcode::kFusion) {
@@ -149,10 +149,10 @@ TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) {
HloInstruction* mul = builder.AddInstruction(HloInstruction::CreateBinary(
add->shape(), HloOpcode::kMultiply, add, sub));
- HloModule module("fuse_with_constant_operands");
+ auto module = CreateNewModule("fuse_with_constant_operands");
HloComputation* entry_computation =
- module.AddEntryComputation(builder.Build(mul));
- HloInstruction* call = module.OutlineExpressionFromComputation(
+ module->AddEntryComputation(builder.Build(mul));
+ HloInstruction* call = module->OutlineExpressionFromComputation(
{add, sub, mul}, "", entry_computation);
EXPECT_EQ(call, entry_computation->root_instruction());
HloComputation* callee_computation = call->to_apply();
@@ -182,14 +182,14 @@ TEST_F(TransposeFoldingTest, FoldDotTransposeInWhile) {
HloInstruction::CreateDot(ShapeUtil::MakeShape(F32, {2, 2}), /*lhs=*/x,
/*rhs=*/transpose_y, dot_dnums));
- HloModule module("test_module");
+ auto module = CreateNewModule("test_module");
HloComputation* entry_computation =
- module.AddEntryComputation(builder.Build(dot));
+ module->AddEntryComputation(builder.Build(dot));
- HloInstruction* call = module.OutlineExpressionFromComputation(
+ HloInstruction* call = module->OutlineExpressionFromComputation(
{transpose_y, dot}, "outlined", entry_computation);
- FoldTranspose(&module);
+ FoldTranspose(module.get());
// Instructions after folding: x, y, and the fusion.
std::unordered_set<HloInstruction*> instruction_set(
@@ -240,10 +240,10 @@ TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) {
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
conv_shape.ValueOrDie(), x, transpose_y, window, dnums));
- HloModule module("test_module");
+ auto module = CreateNewModule("test_module");
HloComputation* entry_computation =
- module.AddEntryComputation(builder.Build(conv));
- FoldTranspose(&module);
+ module->AddEntryComputation(builder.Build(conv));
+ FoldTranspose(module.get());
// Instructions after folding: x, y, and the convolution.
std::unordered_set<HloInstruction*> instruction_set(
@@ -293,10 +293,10 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) {
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
conv_shape.ValueOrDie(), x, transpose_y, window, dnums));
- HloModule module("test_module");
+ auto module = CreateNewModule("test_module");
HloComputation* entry_computation =
- module.AddEntryComputation(builder.Build(conv));
- FoldTranspose(&module);
+ module->AddEntryComputation(builder.Build(conv));
+ FoldTranspose(module.get());
// Instructions after folding: x, y, and the convolution.
std::unordered_set<HloInstruction*> instruction_set(
@@ -351,10 +351,10 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) {
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
conv_shape.ValueOrDie(), transpose_x, y, window, dnums));
- HloModule module("test_module");
+ auto module = CreateNewModule("test_module");
HloComputation* entry_computation =
- module.AddEntryComputation(builder.Build(conv));
- FoldTranspose(&module);
+ module->AddEntryComputation(builder.Build(conv));
+ FoldTranspose(module.get());
// Instructions after folding: x, y, and the convolution.
std::unordered_set<HloInstruction*> instruction_set(
@@ -415,10 +415,10 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) {
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
conv_shape.ValueOrDie(), transpose_x, y, window, dnums));
- HloModule module("test_module");
+ auto module = CreateNewModule("test_module");
HloComputation* entry_computation =
- module.AddEntryComputation(builder.Build(conv));
- FoldTranspose(&module);
+ module->AddEntryComputation(builder.Build(conv));
+ FoldTranspose(module.get());
// Instructions after folding: x, y, and the convolution.
std::unordered_set<HloInstruction*> instruction_set(
diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc
index 4f8cdc1e0e..a4e67cc9d9 100644
--- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc
+++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc
@@ -46,9 +46,9 @@ class ZeroSizedHloEliminationTest : public HloTestBase {
0, ShapeUtil::MakeShape(F32, {3, 0}), "zero sized param"))) {}
StatusOr<bool> RunZeroSizedElimination() {
- HloModule module("zero_sized_elimination_test_module");
- module.AddEntryComputation(builder_.Build());
- return ZeroSizedHloElimination{}.Run(&module);
+ auto module = CreateNewModule("zero_sized_elimination_test_module");
+ module->AddEntryComputation(builder_.Build());
+ return ZeroSizedHloElimination{}.Run(module.get());
}
HloComputation::Builder builder_;