aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/copy_insertion_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/copy_insertion_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/copy_insertion_test.cc39
1 files changed, 39 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc
index 684fff8a6f..ed1a50f516 100644
--- a/tensorflow/compiler/xla/service/copy_insertion_test.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc
@@ -1595,6 +1595,45 @@ TEST_F(CopyInsertionTest, WhileBodyWithConstantRoot) {
EXPECT_THAT(condition->root_instruction(), op::Constant());
}
+TEST_F(CopyInsertionTest, TokensShouldNotBeCopied) {
+ string module_string = R"(
+HloModule TokensShouldNotBeCopied
+
+%Body (param.1: (s32[], token[])) -> (s32[], token[]) {
+ %param.1 = (s32[], token[]) parameter(0)
+ %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0
+ %constant.1 = s32[] constant(1)
+ %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1)
+ %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1
+ %generate-token = token[] generate-token(token[] %get-tuple-element.2)
+ ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %generate-token)
+}
+
+%Cond (param: (s32[], token[])) -> pred[] {
+ %param = (s32[], token[]) parameter(0)
+ %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0
+ %constant = s32[] constant(42)
+ ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant)
+}
+
+ENTRY %TokensShouldNotBeCopied () -> s32[] {
+ %one = s32[] constant(1)
+ %negative_one = s32[] negate(%one)
+ %init_token = token[] generate-token()
+ %init_tuple = (s32[], token[]) tuple(s32[] %negative_one, token[] %init_token)
+ %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body
+ ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ HloRunner::CreateModuleFromString(
+ module_string, GetDebugOptionsForTest()));
+ InsertCopies(module.get());
+
+ // There should be no copies added because tokens should not be copied.
+ EXPECT_EQ(CountCopies(*module), 0);
+}
+
std::unique_ptr<HloComputation> MakeTrivialCondition(const Shape& shape) {
auto builder = HloComputation::Builder("trivial_condition");
builder.AddInstruction(