aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/copy_insertion_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/copy_insertion_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/copy_insertion_test.cc140
1 files changed, 75 insertions, 65 deletions
diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc
index e7539759ce..cd735256b8 100644
--- a/tensorflow/compiler/xla/service/copy_insertion_test.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include <set>
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
@@ -108,7 +108,7 @@ TEST_F(CopyInsertionTest, SingleConstant) {
// be copied before entering the tuple.
auto builder = HloComputation::Builder(TestName());
HloInstruction* constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
HloInstruction* tuple =
builder.AddInstruction(HloInstruction::CreateTuple({constant}));
@@ -125,21 +125,27 @@ TEST_F(CopyInsertionTest, SingleConstant) {
}
TEST_F(CopyInsertionTest, ExistingCopiesNotRemoved) {
- // Verify that an kCopy instructions which exist in the pass before
+ // Verify that kCopy instructions which change layout and exist before
// copy-insertion remain in the graph after copy-insertion.
auto module = CreateNewModule();
auto builder = HloComputation::Builder(TestName());
- HloInstruction* constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
- HloInstruction* copy_1 = builder.AddInstruction(HloInstruction::CreateUnary(
- constant->shape(), HloOpcode::kCopy, constant));
- HloInstruction* copy_2 = builder.AddInstruction(HloInstruction::CreateUnary(
- constant->shape(), HloOpcode::kCopy, constant));
+ HloInstruction* constant =
+ builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{0.f, 2.f}, {2.f, 4.f}})));
+ auto minor_to_major = LayoutUtil::MinorToMajor(constant->shape());
+ Layout reversed_layout =
+ LayoutUtil::MakeLayoutFromMajorToMinor(minor_to_major);
+ Shape copy_shape = constant->shape();
+ *copy_shape.mutable_layout() = reversed_layout;
+ HloInstruction* copy_1 = builder.AddInstruction(
+ HloInstruction::CreateUnary(copy_shape, HloOpcode::kCopy, constant));
+ HloInstruction* copy_2 = builder.AddInstruction(
+ HloInstruction::CreateUnary(copy_shape, HloOpcode::kCopy, constant));
HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary(
constant->shape(), HloOpcode::kAdd, copy_1, copy_2));
- HloInstruction* add_copy = builder.AddInstruction(
- HloInstruction::CreateUnary(constant->shape(), HloOpcode::kCopy, add));
+ builder.AddInstruction(
+ HloInstruction::CreateUnary(add->shape(), HloOpcode::kCopy, add));
module->AddEntryComputation(builder.Build());
@@ -147,12 +153,11 @@ TEST_F(CopyInsertionTest, ExistingCopiesNotRemoved) {
InsertCopies(module.get());
- EXPECT_EQ(CountCopies(*module), 3);
+ EXPECT_EQ(CountCopies(*module), 2);
- EXPECT_EQ(module->entry_computation()->root_instruction(), add_copy);
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- op::Copy(op::Add(op::Copy(op::Constant()), op::Copy(op::Constant()))));
+ EXPECT_EQ(module->entry_computation()->root_instruction(), add);
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ op::Add(op::Copy(op::Constant()), op::Copy(op::Constant())));
}
TEST_F(CopyInsertionTest, MultipleConstantsAndParameters) {
@@ -162,9 +167,9 @@ TEST_F(CopyInsertionTest, MultipleConstantsAndParameters) {
auto builder = HloComputation::Builder(TestName());
HloInstruction* constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
HloInstruction* constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
HloInstruction* x = builder.AddInstruction(
HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "x"));
@@ -192,11 +197,11 @@ TEST_F(CopyInsertionTest, AmbiguousPointsToSet) {
// the computation result. Verify that copies are added properly.
auto builder = HloComputation::Builder(TestName());
HloInstruction* constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
HloInstruction* constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
HloInstruction* constant3 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
HloInstruction* tuple1 = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
@@ -204,9 +209,9 @@ TEST_F(CopyInsertionTest, AmbiguousPointsToSet) {
HloInstruction::CreateTuple({constant3, constant2}));
HloInstruction* pred = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
builder.AddInstruction(HloInstruction::CreateTernary(
- tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2));
+ tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2));
EXPECT_THAT(constant1->users(), UnorderedElementsAre(tuple1));
EXPECT_THAT(constant2->users(), UnorderedElementsAre(tuple1, tuple2));
@@ -250,8 +255,9 @@ TEST_F(CopyInsertionTest, BitcastConstant) {
// The output of a bitcast is its operand (same buffer), so a bitcast
// constant feeding the result must have a copy added.
auto builder = HloComputation::Builder(TestName());
- HloInstruction* constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>({1.0, 42.0})));
+ HloInstruction* constant =
+ builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<float>({1.0, 42.0})));
HloInstruction* bitcast = builder.AddInstruction(HloInstruction::CreateUnary(
ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, constant));
@@ -365,9 +371,9 @@ TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) {
// copy is added.
auto builder = HloComputation::Builder(TestName());
HloInstruction* constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
HloInstruction* constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
HloInstruction* tuple1 = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
@@ -375,9 +381,9 @@ TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) {
HloInstruction::CreateTuple({constant2, constant1}));
HloInstruction* pred = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
HloInstruction* select = builder.AddInstruction(HloInstruction::CreateTernary(
- tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2));
+ tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2));
HloInstruction* gte =
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
ShapeUtil::GetSubshape(select->shape(), {0}), select, 0));
@@ -408,7 +414,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
const Shape& loop_state_shape) {
auto builder = HloComputation::Builder(TestName() + ".Condition");
auto limit_const = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(10)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(10)));
auto loop_state = builder.AddInstruction(
HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
auto induction_variable =
@@ -437,7 +443,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
induction_variable_shape_, loop_state, 0));
auto inc = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(1)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc));
// Update data GTE(1).
@@ -475,7 +481,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
induction_variable_shape_, loop_state, 0));
auto inc = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(1)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
// add0 = Add(in0, 1)
auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
@@ -544,7 +550,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
induction_variable_shape_, loop_state, 0));
auto inc = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(1)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
// add0 = Add(in0, 1)
auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc));
@@ -559,8 +565,9 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
data = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
}
- auto update = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
+ auto update = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
+ {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
// add1 = Add(in1, {1, 1, 1, 1, 1, 1, 1, 1})
auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
data_shape_, HloOpcode::kAdd, data, update));
@@ -593,7 +600,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
auto gte0 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
induction_variable_shape_, loop_state, 0));
auto inc = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(1)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
gte0->shape(), HloOpcode::kAdd, gte0, inc));
@@ -603,8 +610,9 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
// GTE(GTE(loop_state, 1), 0) -> Add
auto gte10 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(data_shape_, gte1, 0));
- auto update10 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
+ auto update10 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
+ {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
auto add10 = builder.AddInstruction(HloInstruction::CreateBinary(
data_shape_, HloOpcode::kAdd, gte10, update10));
@@ -628,10 +636,11 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
bool nested = false) {
auto builder = HloComputation::Builder(TestName() + ".While");
auto induction_var_init = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
- auto data_init = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f})));
+ auto data_init = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
+ {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f})));
if (nested) {
auto inner_init = builder.AddInstruction(
@@ -654,8 +663,9 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
HloInstruction* BuildWhileInstruction_InitPointsToConstant() {
auto builder = HloComputation::Builder(TestName() + ".While");
- auto data_init = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f})));
+ auto data_init = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
+ {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f})));
return BuildWhileInstructionWithCustomInit(loop_state_shape_, data_init,
&builder);
}
@@ -672,11 +682,11 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
auto builder = HloComputation::Builder(TestName() + ".While");
auto one = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto v1 = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape_, one, {1}));
auto zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto v2 = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
@@ -684,9 +694,9 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
auto tuple2 = builder.AddInstruction(HloInstruction::CreateTuple({v2, v1}));
auto pred = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
auto data_init = builder.AddInstruction(HloInstruction::CreateTernary(
- nested_tuple_shape_, HloOpcode::kSelect, pred, tuple1, tuple2));
+ nested_tuple_shape_, HloOpcode::kTupleSelect, pred, tuple1, tuple2));
return BuildWhileInstructionWithCustomInit(nested_loop_state_shape_,
data_init, &builder);
@@ -696,7 +706,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
auto builder = HloComputation::Builder(TestName() + ".While");
auto one = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto one_vec = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape_, one, {1}));
auto data_init =
@@ -709,11 +719,12 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
HloInstruction* BuildWhileInstruction_InitPointsToInterfering() {
auto builder = HloComputation::Builder(TestName() + ".While");
auto one = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto data_init = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape_, one, {1}));
- auto one_vec = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
+ auto one_vec = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
+ {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
// Take a reference to 'data_init' to make it interfere with while result.
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
data_shape_, HloOpcode::kAdd, data_init, one_vec));
@@ -745,7 +756,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
const bool nested =
ShapeUtil::Equal(loop_state_shape, nested_loop_state_shape_);
auto induction_var_init = builder->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
auto condition = module_->AddEmbeddedComputation(
BuildConditionComputation(loop_state_shape));
auto body = module_->AddEmbeddedComputation(
@@ -1247,7 +1258,6 @@ TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) {
auto loop_init = builder.AddInstruction(
HloInstruction::CreateTuple({iter_param, data_param, data_param}));
-
// Two while loops shares the same loop init tuple.
auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile(
loop_state_shape, condition1, body1, loop_init));
@@ -1305,7 +1315,7 @@ TEST_F(CopyInsertionTest, SwizzlingWhile) {
cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, loop_state_shape, "param"));
auto cond_constant = cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
cond_builder.AddInstruction(HloInstruction::CreateUnary(
cond_constant->shape(), HloOpcode::kNot, cond_constant));
HloComputation* condition =
@@ -1313,9 +1323,9 @@ TEST_F(CopyInsertionTest, SwizzlingWhile) {
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
auto xla_while = builder.AddInstruction(
@@ -1370,7 +1380,7 @@ TEST_F(CopyInsertionTest, SwizzlingWhileWithOneOp) {
cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, loop_state_shape, "param"));
auto cond_constant = cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
cond_builder.AddInstruction(HloInstruction::CreateUnary(
cond_constant->shape(), HloOpcode::kNot, cond_constant));
HloComputation* condition =
@@ -1378,9 +1388,9 @@ TEST_F(CopyInsertionTest, SwizzlingWhileWithOneOp) {
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
auto xla_while = builder.AddInstruction(
@@ -1430,7 +1440,7 @@ TEST_F(CopyInsertionTest, SwizzlingWhileSharedInput) {
cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, loop_state_shape, "param"));
auto cond_constant = cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
cond_builder.AddInstruction(HloInstruction::CreateUnary(
cond_constant->shape(), HloOpcode::kNot, cond_constant));
HloComputation* condition =
@@ -1438,7 +1448,7 @@ TEST_F(CopyInsertionTest, SwizzlingWhileSharedInput) {
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto tuple =
builder.AddInstruction(HloInstruction::CreateTuple({constant, constant}));
builder.AddInstruction(
@@ -1515,7 +1525,7 @@ TEST_F(CopyInsertionTest, SequentialWhiles) {
cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, loop_state_shape, "param"));
auto cond_constant = cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
cond_builder.AddInstruction(HloInstruction::CreateUnary(
cond_constant->shape(), HloOpcode::kNot, cond_constant));
HloComputation* condition =
@@ -1570,14 +1580,14 @@ TEST_F(CopyInsertionTest, WhileBodyWithConstantRoot) {
body_builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape_, "param"));
body_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(123.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0)));
HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());
auto cond_builder = HloComputation::Builder("condition");
cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape_, "param"));
cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
HloComputation* condition =
module->AddEmbeddedComputation(cond_builder.Build());
@@ -1639,7 +1649,7 @@ std::unique_ptr<HloComputation> MakeTrivialCondition(const Shape& shape) {
builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "loop_state"));
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
builder.AddInstruction(HloInstruction::CreateUnary(
constant->shape(), HloOpcode::kNot, constant));
return builder.Build();