aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_rematerialization_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_rematerialization_test.cc16
1 files changed, 10 insertions, 6 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
index 8a1e705711..1a861cd16b 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
@@ -67,7 +67,8 @@ class HloRematerializationTest : public HloTestBase {
/*dimension=*/0));
auto slice_1 = builder.AddInstruction(HloInstruction::CreateSlice(
vec1_shape_, concat_1, /*start_indices=*/{0},
- /*limit_indices=*/{1}));
+ /*limit_indices=*/{1},
+ /*strides=*/{1}));
auto concat_2 = builder.AddInstruction(HloInstruction::CreateConcatenate(
ShapeUtil::MakeShape(xla::F32, {1025}), {bcast, slice_1},
/*dimension=*/0));
@@ -75,7 +76,8 @@ class HloRematerializationTest : public HloTestBase {
// which is necessary to use this computation in a while.
builder.AddInstruction(HloInstruction::CreateSlice(vec1_shape_, concat_2,
/*start_indices=*/{0},
- /*limit_indices=*/{1}));
+ /*limit_indices=*/{1},
+ /*strides=*/{1}));
return builder.Build();
}
@@ -103,7 +105,8 @@ class HloRematerializationTest : public HloTestBase {
HloInstruction::CreateBroadcast(vec1024_shape_, param, {}));
auto slice_1 = builder.AddInstruction(
HloInstruction::CreateSlice(vec1_shape_, bcast, /*start_indices=*/{0},
- /*limit_indices=*/{1}));
+ /*limit_indices=*/{1},
+ /*strides=*/{1}));
auto while_inst = builder.AddInstruction(HloInstruction::CreateWhile(
vec1_shape_, while_cond, while_body, slice_1));
auto concat = builder.AddInstruction(HloInstruction::CreateConcatenate(
@@ -111,7 +114,8 @@ class HloRematerializationTest : public HloTestBase {
/*dimension=*/0));
builder.AddInstruction(HloInstruction::CreateSlice(vec1_shape_, concat,
/*start_indices=*/{0},
- /*limit_indices=*/{1}));
+ /*limit_indices=*/{1},
+ /*strides=*/{1}));
return builder.Build();
}
@@ -353,7 +357,7 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) {
/*dimension=*/0));
builder.AddInstruction(HloInstruction::CreateSlice(
vec1024_shape_, concat, /*start_indices=*/{0},
- /*limit_indices=*/{1024}));
+ /*limit_indices=*/{1024}, /*slices=*/{1}));
subcomputation = module->AddEmbeddedComputation(builder.Build());
}
@@ -469,7 +473,7 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) {
/*dimension=*/0));
builder.AddInstruction(HloInstruction::CreateSlice(
vec1024_shape_, concat, /*start_indices=*/{0},
- /*limit_indices=*/{1024}));
+ /*limit_indices=*/{1024}, /*slices=*/{1}));
subcomputation = module->AddEmbeddedComputation(builder.Build());
}