diff options
Diffstat (limited to 'tensorflow/compiler/xla/tests/cross_replica_sum_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/tests/cross_replica_sum_test.cc | 27 |
1 files changed, 24 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc index c960b3c15f..b151187c4b 100644 --- a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc +++ b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc @@ -32,9 +32,16 @@ class TrivialCrossReplicaSumTest : public HloTestBase {}; XLA_TEST_F(TrivialCrossReplicaSumTest, OneOperand) { const char* module_str = R"( HloModule test + + add { + x = f32[] parameter(0) + y = f32[] parameter(1) + add = f32[] add(x, y) + } + ENTRY test_computation { p = f32[3] parameter(0) - ROOT crs = f32[3] cross-replica-sum(p) + ROOT crs = f32[3] cross-replica-sum(p), to_apply=add })"; auto module = ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); @@ -45,10 +52,17 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, OneOperand) { XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) { const char* module_str = R"( HloModule test + + add { + x = f32[] parameter(0) + y = f32[] parameter(1) + add = f32[] add(x, y) + } + ENTRY test_computation { p0 = f32[3] parameter(0) p1 = f32[2] parameter(1) - ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1) + ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1), to_apply=add })"; auto module = ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); @@ -65,10 +79,17 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) { XLA_TEST_F(TrivialCrossReplicaSumTest, ConstantOperand) { const char* module_str = R"( HloModule test + + add { + x = f32[] parameter(0) + y = f32[] parameter(1) + add = f32[] add(x, y) + } + ENTRY test_computation { p0 = f32[3] parameter(0) p1 = f32[2] constant({10, 20}) - ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1) + ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1), to_apply=add })"; auto module = ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); |