aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_parser_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-05 16:23:23 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-05 16:27:53 -0700
commita57f0de68685fb537eb390fa87f04dbafecb28ef (patch)
tree1796b4af3d5e28164f7da755d0a47b96dbd01894 /tensorflow/compiler/xla/service/hlo_parser_test.cc
parent135a25971bfbac86b0aed2cf0433608966015c22 (diff)
[XLA] Make CrossReplicaSum support general cross replica reduce. Also change the interface to be able to describe the common AllReduce semantic.
PiperOrigin-RevId: 199376926
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_parser_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser_test.cc18
1 files changed, 18 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index 84a981675f..08068dc504 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -900,6 +900,24 @@ ENTRY Gather {
)"
},
+// cross-replica-sum
+{
+"CrossReplicaSum",
+R"(HloModule CRS
+
+add {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+}
+
+ENTRY CRS {
+ input = f32[8]{0} parameter(0)
+ ROOT crs = f32[8]{0} cross-replica-sum(input), to_apply=add
+}
+
+)"
+},
});
// clang-format on
}