diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_rematerialization_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_rematerialization_test.cc | 16 |
1 files changed, 10 insertions, 6 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index 8a1e705711..1a861cd16b 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -67,7 +67,8 @@ class HloRematerializationTest : public HloTestBase { /*dimension=*/0)); auto slice_1 = builder.AddInstruction(HloInstruction::CreateSlice( vec1_shape_, concat_1, /*start_indices=*/{0}, - /*limit_indices=*/{1})); + /*limit_indices=*/{1}, + /*strides=*/{1})); auto concat_2 = builder.AddInstruction(HloInstruction::CreateConcatenate( ShapeUtil::MakeShape(xla::F32, {1025}), {bcast, slice_1}, /*dimension=*/0)); @@ -75,7 +76,8 @@ class HloRematerializationTest : public HloTestBase { // which is necessary to use this computation in a while. builder.AddInstruction(HloInstruction::CreateSlice(vec1_shape_, concat_2, /*start_indices=*/{0}, - /*limit_indices=*/{1})); + /*limit_indices=*/{1}, + /*strides=*/{1})); return builder.Build(); } @@ -103,7 +105,8 @@ class HloRematerializationTest : public HloTestBase { HloInstruction::CreateBroadcast(vec1024_shape_, param, {})); auto slice_1 = builder.AddInstruction( HloInstruction::CreateSlice(vec1_shape_, bcast, /*start_indices=*/{0}, - /*limit_indices=*/{1})); + /*limit_indices=*/{1}, + /*strides=*/{1})); auto while_inst = builder.AddInstruction(HloInstruction::CreateWhile( vec1_shape_, while_cond, while_body, slice_1)); auto concat = builder.AddInstruction(HloInstruction::CreateConcatenate( @@ -111,7 +114,8 @@ class HloRematerializationTest : public HloTestBase { /*dimension=*/0)); builder.AddInstruction(HloInstruction::CreateSlice(vec1_shape_, concat, /*start_indices=*/{0}, - /*limit_indices=*/{1})); + /*limit_indices=*/{1}, + /*strides=*/{1})); return builder.Build(); } @@ -353,7 +357,7 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { /*dimension=*/0)); builder.AddInstruction(HloInstruction::CreateSlice( vec1024_shape_, concat, /*start_indices=*/{0}, - /*limit_indices=*/{1024})); + /*limit_indices=*/{1024}, /*slices=*/{1})); subcomputation = module->AddEmbeddedComputation(builder.Build()); } @@ -469,7 +473,7 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) { /*dimension=*/0)); builder.AddInstruction(HloInstruction::CreateSlice( vec1024_shape_, concat, /*start_indices=*/{0}, - /*limit_indices=*/{1024})); + /*limit_indices=*/{1024}, /*slices=*/{1})); subcomputation = module->AddEmbeddedComputation(builder.Build()); } |