aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc
blob: d12a4e7fcd7813775a81677bcaa07af60ff9b477 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"

namespace xla {
namespace {

// Test fixture shared by the cross-replica-sum tests in this file. It declares
// no members of its own; it exists only to give the tests a named suite
// derived from HloTestBase. (Must stay non-final: XLA_TEST_F derives from it.)
class TrivialCrossReplicaSumTest : public HloTestBase {
};

// Currently the CPU and GPU backends support CrossReplicaSum only with a single
// replica, so these tests check just that case: summing across one replica
// must return each operand unchanged.

// A single f32[3] operand reduced with the scalar `add` computation; the
// transferred result is expected to equal the argument exactly.
XLA_TEST_F(TrivialCrossReplicaSumTest, OneOperand) {
  constexpr char kHloText[] = R"(
  HloModule test

  add {
    x = f32[] parameter(0)
    y = f32[] parameter(1)
    add = f32[] add(x, y)
  }

  ENTRY test_computation {
    p = f32[3] parameter(0)
    ROOT crs = f32[3] cross-replica-sum(p), to_apply=add
  })";

  auto input = LiteralUtil::CreateR1<float>({1, 2, 3});
  auto hlo_module =
      ParseHloString(kHloText, GetModuleConfigForTest()).ValueOrDie();

  // Run the module on `input` and bring the result back to the host.
  auto output = ExecuteAndTransfer(std::move(hlo_module), {input.get()});
  EXPECT_EQ(*input, *output);
}

// Two operands of different shapes (f32[3] and f32[2]) go through one
// tuple-shaped cross-replica-sum. Each tuple element of the result is expected
// to equal the corresponding argument.
XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) {
  constexpr char kHloText[] = R"(
  HloModule test

  add {
    x = f32[] parameter(0)
    y = f32[] parameter(1)
    add = f32[] add(x, y)
  }

  ENTRY test_computation {
    p0 = f32[3] parameter(0)
    p1 = f32[2] parameter(1)
    ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1), to_apply=add
  })";

  auto first_arg = LiteralUtil::CreateR1<float>({1, 2, 3});
  auto second_arg = LiteralUtil::CreateR1<float>({10, 20});
  // Expected output: the arguments themselves, packed as a tuple.
  auto expected = LiteralUtil::MakeTuple({first_arg.get(), second_arg.get()});

  auto hlo_module =
      ParseHloString(kHloText, GetModuleConfigForTest()).ValueOrDie();
  auto actual = ExecuteAndTransfer(std::move(hlo_module),
                                   {first_arg.get(), second_arg.get()});
  EXPECT_EQ(*expected, *actual);
}

// The GPU backend handles constants specially, and feeding a constant to a
// cross-replica-sum is a plausible idiom (for instance, to count how many
// replicas there are). This test covers that path: only p0 is passed in at run
// time, while the second operand is the constant {10, 20} embedded in the HLO.
XLA_TEST_F(TrivialCrossReplicaSumTest, ConstantOperand) {
  constexpr char kHloText[] = R"(
  HloModule test

  add {
    x = f32[] parameter(0)
    y = f32[] parameter(1)
    add = f32[] add(x, y)
  }

  ENTRY test_computation {
    p0 = f32[3] parameter(0)
    p1 = f32[2] constant({10, 20})
    ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1), to_apply=add
  })";

  auto runtime_arg = LiteralUtil::CreateR1<float>({1, 2, 3});
  // Host-side copy of the HLO constant; used only to build the expectation.
  auto constant_value = LiteralUtil::CreateR1<float>({10, 20});
  auto expected =
      LiteralUtil::MakeTuple({runtime_arg.get(), constant_value.get()});

  auto hlo_module =
      ParseHloString(kHloText, GetModuleConfigForTest()).ValueOrDie();
  auto actual = ExecuteAndTransfer(std::move(hlo_module), {runtime_arg.get()});
  EXPECT_EQ(*expected, *actual);
}

}  // namespace
}  // namespace xla