diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/copy_insertion_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/copy_insertion_test.cc | 39 |
1 files changed, 39 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index 684fff8a6f..ed1a50f516 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -1595,6 +1595,45 @@ TEST_F(CopyInsertionTest, WhileBodyWithConstantRoot) { EXPECT_THAT(condition->root_instruction(), op::Constant()); } +TEST_F(CopyInsertionTest, TokensShouldNotBeCopied) { + string module_string = R"( +HloModule TokensShouldNotBeCopied + +%Body (param.1: (s32[], token[])) -> (s32[], token[]) { + %param.1 = (s32[], token[]) parameter(0) + %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0 + %constant.1 = s32[] constant(1) + %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1) + %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1 + %generate-token = token[] generate-token(token[] %get-tuple-element.2) + ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %generate-token) +} + +%Cond (param: (s32[], token[])) -> pred[] { + %param = (s32[], token[]) parameter(0) + %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0 + %constant = s32[] constant(42) + ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant) +} + +ENTRY %TokensShouldNotBeCopied () -> s32[] { + %one = s32[] constant(1) + %negative_one = s32[] negate(%one) + %init_token = token[] generate-token() + %init_tuple = (s32[], token[]) tuple(s32[] %negative_one, token[] %init_token) + %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body + ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + HloRunner::CreateModuleFromString( + module_string, GetDebugOptionsForTest())); + InsertCopies(module.get()); + + // There should be no copies added because tokens should not be copied. + EXPECT_EQ(CountCopies(*module), 0); +} + std::unique_ptr<HloComputation> MakeTrivialCondition(const Shape& shape) { auto builder = HloComputation::Builder("trivial_condition"); builder.AddInstruction( |