diff options
author | 2018-06-05 16:23:23 -0700 | |
---|---|---|
committer | 2018-06-05 16:27:53 -0700 | |
commit | a57f0de68685fb537eb390fa87f04dbafecb28ef (patch) | |
tree | 1796b4af3d5e28164f7da755d0a47b96dbd01894 /tensorflow/compiler/xla/service/hlo_parser_test.cc | |
parent | 135a25971bfbac86b0aed2cf0433608966015c22 (diff) |
[XLA] Make CrossReplicaSum support general cross replica reduce. Also change the interface to be able to describe the common AllReduce semantic.
PiperOrigin-RevId: 199376926
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_parser_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_parser_test.cc | 18 |
1 files changed, 18 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 84a981675f..08068dc504 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -900,6 +900,24 @@ ENTRY Gather { )" }, +// cross-replica-sum +{ +"CrossReplicaSum", +R"(HloModule CRS + +add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY CRS { + input = f32[8]{0} parameter(0) + ROOT crs = f32[8]{0} cross-replica-sum(input), to_apply=add +} + +)" +}, }); // clang-format on } |