aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_parser_test.cc
diff options
context:
space:
mode:
authorGravatar HyoukJoong Lee <hyouklee@google.com>2018-06-12 13:57:28 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-12 14:00:55 -0700
commit34b071f6b6a14bd4c8d5c30156c1670496b85f04 (patch)
treeb413892245024a28e63adf88078ad264f1cadd16 /tensorflow/compiler/xla/service/hlo_parser_test.cc
parent688a09dc6b70a81cae12a7e263515964311f8d86 (diff)
Support subgroup CrossReplicaSum
PiperOrigin-RevId: 200275384
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_parser_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser_test.cc18
1 files changed, 18 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index 1c5a47c875..f834d34d57 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -918,6 +918,24 @@ ENTRY CRS {
)"
},
+// cross-replica-sum with subgroups
+{
+"CrossReplicaSumWithSubgroups",
+R"(HloModule CRS_Subgroups
+
+add {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+}
+
+ENTRY CrossReplicaSumWithSubgroups {
+ input = f32[128,32]{0,1} parameter(0)
+ ROOT cross-replica-sum = f32[128,32]{0,1} cross-replica-sum(input), to_apply=add, replica_group_ids={0,0,1,1}, barrier="abc"
+}
+
+)"
+}
});
// clang-format on
}