aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/tests/cross_replica_sum_test.cc')
-rw-r--r--tensorflow/compiler/xla/tests/cross_replica_sum_test.cc27
1 files changed, 24 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc
index c960b3c15f..b151187c4b 100644
--- a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc
+++ b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc
@@ -32,9 +32,16 @@ class TrivialCrossReplicaSumTest : public HloTestBase {};
XLA_TEST_F(TrivialCrossReplicaSumTest, OneOperand) {
const char* module_str = R"(
HloModule test
+
+ add {
+ x = f32[] parameter(0)
+ y = f32[] parameter(1)
+ add = f32[] add(x, y)
+ }
+
ENTRY test_computation {
p = f32[3] parameter(0)
- ROOT crs = f32[3] cross-replica-sum(p)
+ ROOT crs = f32[3] cross-replica-sum(p), to_apply=add
})";
auto module =
ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie();
@@ -45,10 +52,17 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, OneOperand) {
XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) {
const char* module_str = R"(
HloModule test
+
+ add {
+ x = f32[] parameter(0)
+ y = f32[] parameter(1)
+ add = f32[] add(x, y)
+ }
+
ENTRY test_computation {
p0 = f32[3] parameter(0)
p1 = f32[2] parameter(1)
- ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1)
+ ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1), to_apply=add
})";
auto module =
ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie();
@@ -65,10 +79,17 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) {
XLA_TEST_F(TrivialCrossReplicaSumTest, ConstantOperand) {
const char* module_str = R"(
HloModule test
+
+ add {
+ x = f32[] parameter(0)
+ y = f32[] parameter(1)
+ add = f32[] add(x, y)
+ }
+
ENTRY test_computation {
p0 = f32[3] parameter(0)
p1 = f32[2] constant({10, 20})
- ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1)
+ ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1), to_apply=add
})";
auto module =
ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie();