aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/compute_constant_test.cc
blob: d423c78476dde18d209b5efac9e8f77da41bfeb4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace {

// An enumerator for the client types that we want to iterate over in
// the various tests.
enum class ClientType { kLocal, kCompileOnly };
ClientType client_types[] = {ClientType::kLocal, ClientType::kCompileOnly};

class ComputeConstantTest : public ::testing::Test {
 public:
  explicit ComputeConstantTest(
      perftools::gputools::Platform* platform = nullptr)
      : platform_(platform) {}

  string TestName() const {
    return ::testing::UnitTest::GetInstance()->current_test_info()->name();
  }

  Client* ClientOrDie(::perftools::gputools::Platform* platform,
                      ClientType client_type) {
    if (client_type == ClientType::kLocal) {
      StatusOr<Client*> result =
          ClientLibrary::GetOrCreateLocalClient(platform);
      TF_CHECK_OK(result.status())
          << "could not create LocalClient for testing";
      return result.ValueOrDie();
    } else if (client_type == ClientType::kCompileOnly) {
      StatusOr<Client*> result =
          ClientLibrary::GetOrCreateCompileOnlyClient(platform);
      TF_CHECK_OK(result.status())
          << "could not create CompileOnlyClient for testing";
      return result.ValueOrDie();
    }
    LOG(FATAL) << "invalid client_type value";
  }

  StatusOr<std::unique_ptr<Literal>> ComputeConstantLiteral(
      Client* client, const ComputationDataHandle& operand,
      ComputationBuilder* builder, Layout* output_layout = nullptr,
      tensorflow::gtl::ArraySlice<Literal> parameters = {}) {
    TF_ASSIGN_OR_RETURN(auto computed, builder->ComputeConstant(
                                           operand, output_layout, parameters));
    return std::move(computed);
  }

  template <class Scalar>
  StatusOr<Scalar> ComputeConstantScalar(
      Client* client, const ComputationDataHandle& operand,
      ComputationBuilder* builder,
      tensorflow::gtl::ArraySlice<Literal> parameters = {}) {
    TF_ASSIGN_OR_RETURN(
        auto literal,
        ComputeConstantLiteral(client, operand, builder, nullptr, parameters));
    return literal->Get<Scalar>({});
  }

  bool IsConstant(const ComputationDataHandle& operand,
                  ComputationBuilder* builder, int64 num_parameters = 0) {
    StatusOr<bool> result = builder->IsConstant(operand, num_parameters);
    EXPECT_TRUE(result.ok()) << result.status();
    return result.ok() ? result.ValueOrDie() : false;
  }

  perftools::gputools::Platform* platform_;
};

TEST_F(ComputeConstantTest, ScalarInt32Literal) {
  for (ClientType client_type : client_types) {
    Client* client = ClientOrDie(platform_, client_type);
    ComputationBuilder b(client, TestName());
    auto computation = b.ConstantR0<int32>(42);
    EXPECT_TRUE(IsConstant(computation, &b));

    auto value = ComputeConstantScalar<int32>(client, computation, &b);
    ASSERT_TRUE(value.ok()) << value.status();
    EXPECT_EQ(value.ValueOrDie(), 42);
  }
}

TEST_F(ComputeConstantTest, ScalarFloatAdd) {
  for (ClientType client_type : client_types) {
    Client* client = ClientOrDie(platform_, client_type);
    ComputationBuilder b(client, TestName());
    auto computation =
        b.Add(b.ConstantR0<float>(42.5f), b.ConstantR0<float>(1.5f));
    EXPECT_TRUE(IsConstant(computation, &b));

    auto value = ComputeConstantScalar<float>(client, computation, &b);
    ASSERT_TRUE(value.ok()) << value.status();
    EXPECT_EQ(value.ValueOrDie(), 44.0f);
  }
}

TEST_F(ComputeConstantTest, ScalarRng) {
  for (ClientType client_type : client_types) {
    Client* client = ClientOrDie(platform_, client_type);
    ComputationBuilder b(client, TestName());
    auto computation =
        b.RngUniform(b.ConstantR0<float>(1.1f), b.ConstantR0<float>(2.1f),
                     ShapeUtil::MakeShape(F32, {}));
    EXPECT_FALSE(IsConstant(computation, &b));

    auto value = ComputeConstantScalar<float>(client, computation, &b);
    ASSERT_FALSE(value.ok())
        << "computing a RNG value should not be considered a constant";
  }
}

TEST_F(ComputeConstantTest, Param) {
  for (ClientType client_type : client_types) {
    Client* client = ClientOrDie(platform_, client_type);
    ComputationBuilder b(client, TestName());
    auto param = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "lhs");
    auto computation = b.Add(param, b.ConstantR0<float>(1.5f));

    std::vector<Literal> arguments;
    arguments.emplace_back(*Literal::CreateR0(42.5f));
    EXPECT_TRUE(IsConstant(computation, &b, arguments.size()));

    auto value =
        ComputeConstantScalar<float>(client, computation, &b, arguments);
    ASSERT_TRUE(value.ok()) << value.status();
    EXPECT_EQ(value.ValueOrDie(), 44.0f);
  }
}

TEST_F(ComputeConstantTest, DirectParamMissing) {
  for (ClientType client_type : client_types) {
    Client* client = ClientOrDie(platform_, client_type);
    ComputationBuilder b(client, TestName());
    auto computation = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param");
    EXPECT_FALSE(IsConstant(computation, &b));

    auto value = ComputeConstantScalar<float>(client, computation, &b);
    EXPECT_TRUE(tensorflow::StringPiece(value.status().ToString())
                    .contains("depends on parameter"))
        << value.status();
  }
}

TEST_F(ComputeConstantTest, IndirectParamMissing) {
  for (ClientType client_type : client_types) {
    Client* client = ClientOrDie(platform_, client_type);
    ComputationBuilder b(client, TestName());
    auto computation =
        b.Add(b.ConstantR0<float>(1.0f),
              b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param"));
    EXPECT_FALSE(IsConstant(computation, &b));

    auto value = ComputeConstantScalar<float>(client, computation, &b);
    EXPECT_TRUE(tensorflow::StringPiece(value.status().ToString())
                    .contains("depends on parameter"))
        << value.status();
  }
}

// Test computation of an expression interspersed with param nodes but
// the expression does not depend on the param nodes.
TEST_F(ComputeConstantTest, UnrelatedParam) {
  for (ClientType client_type : client_types) {
    Client* client = ClientOrDie(platform_, client_type);
    ComputationBuilder b(client, TestName());

    auto param_a = b.Parameter(10, ShapeUtil::MakeShape(F32, {}), "param0");
    auto constant_4 =
        b.Add(b.ConstantR0<float>(2.5f), b.ConstantR0<float>(1.5f));
    auto not_constant_a = b.Add(constant_4, param_a);

    auto param_b = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "param1");
    auto constant_9 =
        b.Mul(b.ConstantR0<float>(2.0f), b.ConstantR0<float>(4.5f));
    auto not_constant_b = b.Add(param_b, constant_9);

    auto constant_13 = b.Add(constant_4, constant_9);
    b.Add(not_constant_b, b.Add(constant_13, not_constant_a));

    EXPECT_TRUE(IsConstant(constant_13, &b));

    auto value = ComputeConstantScalar<float>(client, constant_13, &b);
    ASSERT_TRUE(value.ok()) << value.status();
    EXPECT_EQ(value.ValueOrDie(), 13.0f);
  }
}

TEST_F(ComputeConstantTest, NonScalarAdd) {
  for (ClientType client_type : client_types) {
    Client* client = ClientOrDie(platform_, client_type);
    ComputationBuilder b(client, TestName());

    auto computation =
        b.Add(b.ConstantR1<int32>({1, 2}), b.ConstantR1<int32>({3, 4}));
    EXPECT_TRUE(IsConstant(computation, &b));

    auto computed = ComputeConstantLiteral(client, computation, &b);
    ASSERT_TRUE(computed.ok()) << computed.status();
    std::unique_ptr<Literal> expected_literal =
        Literal::CreateR1<int32>({4, 6});
    LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie());
  }
}

TEST_F(ComputeConstantTest, IntegerDivide) {
  for (ClientType client_type : client_types) {
    Client* client = ClientOrDie(platform_, client_type);
    ComputationBuilder b(client, TestName());
    auto computation = b.Div(b.ConstantR0<int32>(15), b.ConstantR0<int32>(3));
    EXPECT_TRUE(IsConstant(computation, &b));

    auto computed = ComputeConstantLiteral(client, computation, &b);
    ASSERT_TRUE(computed.ok()) << computed.status();
    std::unique_ptr<Literal> expected_literal = Literal::CreateR0<int32>(5);
    LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie());
  }
}

XLA_TEST_F(ComputeConstantTest, Layout) {
  for (ClientType client_type : client_types) {
    Client* client = ClientOrDie(platform_, client_type);
    ComputationBuilder b(client, TestName());

    std::vector<std::vector<int64>> layouts = {{0, 1}, {1, 0}};
    for (const std::vector<int64>& layout : layouts) {
      auto layout_proto = LayoutUtil::MakeLayout(layout);
      auto computed = ComputeConstantLiteral(
          client,
          b.Add(b.ConstantR2<int32>({{1, 2}, {3, 4}}),
                b.ConstantR2<int32>({{10, 20}, {30, 40}})),
          &b, &layout_proto);
      ASSERT_TRUE(computed.ok()) << computed.status();

      std::unique_ptr<Literal> expected_literal =
          test_utils::CreateR2LiteralWithLayout<int32>({{11, 22}, {33, 44}},
                                                       layout);
      LiteralTestUtil::AssertEqualShapesAndLayouts(
          expected_literal->shape(), computed.ValueOrDie()->shape());
      LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie());
    }
  }
}

}  // namespace
}  // namespace xla