aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/tuple_test.cc
blob: cea9316a6d684bb8512ceac62bd2e7e666fb934e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <initializer_list>
#include <memory>

#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/test.h"

namespace xla {
namespace {

// Fixture for every tuple test in this file. All client/compute helpers come
// from ClientLibraryTestBase; this class only contributes a shared tolerance.
class TupleTest : public ClientLibraryTestBase {
 public:
  // Comparison tolerance passed to the ComputeAndCompare* helpers by most of
  // the tests below.
  ErrorSpec error_spec_ = ErrorSpec(0.0001);
};

// Packs three constants of different ranks (scalar, vector, matrix) into one
// tuple and checks that the computation produces exactly that tuple.
XLA_TEST_F(TupleTest, TupleCreate) {
  ComputationBuilder builder(client_, TestName());

  const float scalar_value = 7.3f;
  std::initializer_list<float> vector_values = {1.1f, 2.0f, 3.3f};
  std::initializer_list<std::initializer_list<float>> matrix_values = {
      {1.1f, 2.2f, 3.5f},  // row 0
      {4.8f, 5.0f, 6.7f},  // row 1
  };

  // The handle is not needed afterwards; the comparison below evaluates the
  // builder's computation, which ends with this tuple.
  builder.Tuple({builder.ConstantR0<float>(scalar_value),
                 builder.ConstantR1<float>(vector_values),
                 builder.ConstantR2<float>(matrix_values)});

  auto scalar_literal = LiteralUtil::CreateR0<float>(scalar_value);
  auto vector_literal = LiteralUtil::CreateR1<float>(vector_values);
  auto matrix_literal = LiteralUtil::CreateR2<float>(matrix_values);
  auto expected = LiteralUtil::MakeTuple(
      {scalar_literal.get(), vector_literal.get(), matrix_literal.get()});
  ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
}

// Tests the creation of a tuple whose members include a zero-element array:
// a scalar next to an empty R1 vector.
XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) {
  ComputationBuilder builder(client_, TestName());

  auto result = builder.Tuple(
      {builder.ConstantR0<float>(7.0), builder.ConstantR1<float>({})});

  // The expected literal mirrors the built tuple: (7.0, {}).
  auto expected =
      LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(7.0).get(),
                              LiteralUtil::CreateR1<float>({}).get()});
  ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
}

// A tuple with no members must compare equal to an empty tuple literal.
XLA_TEST_F(TupleTest, EmptyTupleCreate) {
  ComputationBuilder builder(client_, TestName());
  builder.Tuple({});

  auto empty_tuple = LiteralUtil::MakeTuple({});
  ComputeAndCompareTuple(&builder, *empty_tuple, {}, error_spec_);
}

// Simplest GetTupleElement case: pack (vector, matrix) into a tuple, read back
// element 1 and expect the original matrix unchanged.
XLA_TEST_F(TupleTest, GetTupleElement) {
  ComputationBuilder builder(client_, TestName());
  std::initializer_list<float> vec = {1.f, 2.f, 3.f};
  std::initializer_list<std::initializer_list<float>> mat = {
      {1.f, 2.f, 3.f},  // row 0
      {4.f, 5.f, 6.f},  // row 1
  };

  auto vec_op = builder.ConstantR1<float>(vec);
  auto mat_op = builder.ConstantR2<float>(mat);
  auto packed = builder.Tuple({vec_op, mat_op});
  builder.GetTupleElement(packed, 1);  // Index 1 is the matrix.

  ComputeAndCompareR2<float>(&builder, Array2D<float>(mat), {}, error_spec_);
}

// Tests GetTupleElement on zero-sized tuple members: the tuple holds an empty
// R1 vector and a 0x101 matrix, and the (empty) matrix is extracted.
XLA_TEST_F(TupleTest, GetTupleElementWithZeroElements) {
  ComputationBuilder builder(client_, TestName());
  auto tuple_data = builder.Tuple(
      {builder.ConstantR1<float>({}),
       builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 101))});
  auto matrix_element = builder.GetTupleElement(tuple_data, 1);
  ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 101), {}, error_spec_);
}

// Unpacks both members of a (vector, matrix) tuple with GetTupleElement,
// verifies their shapes, and adds them with the vector broadcast along
// dimension 1 of the matrix.
XLA_TEST_F(TupleTest, AddTupleElements) {
  ComputationBuilder builder(client_, TestName());
  std::initializer_list<float> vec = {1.f, 2.f, 3.f};
  std::initializer_list<std::initializer_list<float>> mat = {
      {1.f, 2.f, 3.f},  // row 0
      {4.f, 5.f, 6.f},  // row 1
  };

  auto packed = builder.Tuple(
      {builder.ConstantR1<float>(vec), builder.ConstantR2<float>(mat)});
  auto vec_gte = builder.GetTupleElement(packed, 0);
  auto mat_gte = builder.GetTupleElement(packed, 1);
  auto vec_shape = builder.GetShape(vec_gte).ConsumeValueOrDie();
  auto mat_shape = builder.GetShape(mat_gte).ConsumeValueOrDie();
  // Every row of the 2x3 matrix gets the 3-element vector added to it.
  builder.Add(mat_gte, vec_gte, /*broadcast_dimensions=*/{1});

  ASSERT_TRUE(ShapeUtil::ShapeIs(*vec_shape, F32, {3}));
  ASSERT_TRUE(ShapeUtil::ShapeIs(*mat_shape, F32, {/*y=*/2, /*x=*/3}));

  Array2D<float> expected({
      {2.f, 4.f, 6.f},  // row 0
      {5.f, 7.f, 9.f},  // row 1
  });
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}

// Reverses a tuple's members: unpack (vector, matrix) with GetTupleElement and
// repack the pieces as (matrix, vector).
XLA_TEST_F(TupleTest, TupleGTEToTuple) {
  ComputationBuilder builder(client_, TestName());
  std::initializer_list<float> vec = {1.f, 2.f, 3.f};
  std::initializer_list<std::initializer_list<float>> mat = {
      {1.f, 2.f, 3.f},  // row 0
      {4.f, 5.f, 6.f},  // row 1
  };

  auto original = builder.Tuple(
      {builder.ConstantR1<float>(vec), builder.ConstantR2<float>(mat)});
  auto second = builder.GetTupleElement(original, 1);
  auto first = builder.GetTupleElement(original, 0);
  builder.Tuple({second, first});

  auto mat_literal = LiteralUtil::CreateR2<float>(mat);
  auto vec_literal = LiteralUtil::CreateR1<float>(vec);
  auto expected =
      LiteralUtil::MakeTuple({mat_literal.get(), vec_literal.get()});
  ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
}

// Builds two new tuples from an existing tuple (by means of GetTupleElement),
// then adds up the components of the new tuples.
XLA_TEST_F(TupleTest, TupleGTEToTupleToGTEAdd) {
  // Dataflow of the computation built below. The picture lives in a block
  // comment because several rows end in a backslash, and a trailing backslash
  // in a `//` comment splices the next source line into it (-Wcomment).
  /*
   * v------           --(GTE 0)--             --(GTE 0)----------
   *        \         /           \           /                   \
   *         (tuple)--             (tuple01)--                     \
   *        /   |     \           /           \                     \
   * m------    |      --(GTE 1)--             --(GTE 1)------------ \
   *            |                                                   \ \
   *            |                                                    (add)
   *            |                                                   / /
   *            |--------(GTE 1)--             --(GTE 0)------------ /
   *             \                \           /                     /
   *              \                (tuple10)--                     /
   *               \              /           \                   /
   *                -----(GTE 0)--             --(GTE 1)----------
   */
  ComputationBuilder builder(client_, TestName());
  std::initializer_list<float> constant_vector = {1.f, 2.f, 3.f};
  std::initializer_list<std::initializer_list<float>> constant_matrix = {
      {1.f, 2.f, 3.f},  // row 0
      {4.f, 5.f, 6.f},  // row 1
  };
  auto tuple_data = builder.Tuple({builder.ConstantR1<float>(constant_vector),
                                   builder.ConstantR2<float>(constant_matrix)});
  auto new_tuple01 = builder.Tuple({builder.GetTupleElement(tuple_data, 0),
                                    builder.GetTupleElement(tuple_data, 1)});
  auto new_tuple10 = builder.Tuple({builder.GetTupleElement(tuple_data, 1),
                                    builder.GetTupleElement(tuple_data, 0)});
  auto vector_from_01 = builder.GetTupleElement(new_tuple01, 0);
  auto vector_from_10 = builder.GetTupleElement(new_tuple10, 1);
  auto matrix_from_01 = builder.GetTupleElement(new_tuple01, 1);
  auto matrix_from_10 = builder.GetTupleElement(new_tuple10, 0);

  auto addvectors = builder.Add(vector_from_01, vector_from_10);
  auto addmatrices = builder.Add(matrix_from_01, matrix_from_10);

  // Both vector reads and both matrix reads see the same data, so the result
  // is 2*m + 2*v with v broadcast across the rows of m.
  auto result = builder.Add(addmatrices, addvectors,
                            /*broadcast_dimensions=*/{1});

  Array2D<float> expected({
      {4.f, 8.f, 12.f},    // row 0
      {10.f, 14.f, 18.f},  // row 1
  });
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}

XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnFalse)) {
  // Select with a constant false predicate between two tuples; the result
  // must be the on_false tuple, (second, first).
  ComputationBuilder builder(client_, TestName());

  std::initializer_list<float> first = {1.f, 2.f, 3.f};
  std::initializer_list<float> second = {2.f, 4.f, 6.f};
  auto on_true = builder.Tuple(
      {builder.ConstantR1<float>(first), builder.ConstantR1<float>(second)});
  auto on_false = builder.Tuple(
      {builder.ConstantR1<float>(second), builder.ConstantR1<float>(first)});
  auto pred = builder.ConstantR0<bool>(false);
  builder.Select(pred, on_true, on_false);

  auto second_literal = LiteralUtil::CreateR1<float>(second);
  auto first_literal = LiteralUtil::CreateR1<float>(first);
  auto expected =
      LiteralUtil::MakeTuple({second_literal.get(), first_literal.get()});
  ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
}

XLA_TEST_F(TupleTest, TuplesInAMap) {
  Computation tuple_computation;
  {
    // Scalar function f(x) = max(x, x^2) + 100 * min(x, x^2), expressed by
    // selecting between the tuples (x, x^2) and (x^2, x). The select is what
    // stops HLO-level optimizations from eliminating the tuples.
    ComputationBuilder sub(client_, "sort_square");
    auto x = sub.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
    auto x_squared = sub.Mul(x, x);
    auto ascending = sub.Tuple({x, x_squared});
    auto descending = sub.Tuple({x_squared, x});
    auto x_is_smaller = sub.Lt(x, x_squared);
    auto ordered = sub.Select(x_is_smaller, ascending, descending);
    auto lo = sub.GetTupleElement(ordered, 0);
    auto hi = sub.GetTupleElement(ordered, 1);
    sub.Add(hi, sub.Mul(sub.ConstantR0<float>(100.0f), lo));

    auto built = sub.Build();
    ASSERT_IS_OK(built.status());
    tuple_computation = built.ConsumeValueOrDie();
  }

  // f(-1) = 1 - 100 = -99; f(1) = 1 + 100 = 101; f(2.1) = 4.41 + 210 = 214.41.
  ComputationBuilder builder(client_, TestName());
  auto input = builder.ConstantR1<float>({-1.0f, 1.0f, 2.1f});
  builder.Map({input}, tuple_computation);
  ComputeAndCompareR1<float>(&builder, {-99.0f, 101.0f, 214.41f}, {},
                             error_spec_);
}

XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnTrue)) {
  // Select with a constant true predicate between two tuples; the result
  // must be the on_true tuple, (first, second).
  ComputationBuilder builder(client_, TestName());

  std::initializer_list<float> first = {1.f, 2.f, 3.f};
  std::initializer_list<float> second = {2.f, 4.f, 6.f};
  auto on_true = builder.Tuple(
      {builder.ConstantR1<float>(first), builder.ConstantR1<float>(second)});
  auto on_false = builder.Tuple(
      {builder.ConstantR1<float>(second), builder.ConstantR1<float>(first)});
  auto pred = builder.ConstantR0<bool>(true);
  builder.Select(pred, on_true, on_false);

  auto first_literal = LiteralUtil::CreateR1<float>(first);
  auto second_literal = LiteralUtil::CreateR1<float>(second);
  auto expected =
      LiteralUtil::MakeTuple({first_literal.get(), second_literal.get()});
  ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
}

XLA_TEST_F(TupleTest, SelectBetweenTuplesElementResult) {
  // Selects between two tuples, but the computation's output is one element
  // of the selected tuple rather than the tuple itself. With a false
  // predicate, element 0 of the on_false tuple is `second`.
  ComputationBuilder builder(client_, TestName());

  std::initializer_list<float> first = {1.f, 2.f, 3.f};
  std::initializer_list<float> second = {2.f, 4.f, 6.f};
  auto on_true = builder.Tuple(
      {builder.ConstantR1<float>(first), builder.ConstantR1<float>(second)});
  auto on_false = builder.Tuple(
      {builder.ConstantR1<float>(second), builder.ConstantR1<float>(first)});
  auto pred = builder.ConstantR0<bool>(false);
  auto chosen = builder.Select(pred, on_true, on_false);
  builder.GetTupleElement(chosen, 0);

  ComputeAndCompareR1<float>(&builder, second, {}, error_spec_);
}

// Cascaded selects between tuple types.
XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesCascaded)) {
  // Dataflow of the computation built below. The picture lives in a block
  // comment because one row ends in a backslash, and a trailing backslash in
  // a `//` comment splices the next source line into it (-Wcomment).
  /*
   *                       vec1     vec2   vec2     vec1
   *                        |        |      |        |
   *                        |        |      |        |
   *                        (tuple 12)      (tuple 21)
   *                               \            /
   *                                \          /
   *                                 \        /
   *  true  --            --(GTE 0)--(select 1)
   *          \          /             |
   *       (pred tuple)--              |          --(GTE 0)--
   *          /          \             V         /           \
   *  false --            --(GTE 1)--(select 2)--             --(add)
   *                                 /           \           /
   *                                /             --(GTE 1)--
   *                               /
   *                          (tuple 21)
   */
  ComputationBuilder builder(client_, TestName());

  std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
  std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};

  auto pred_tuple = builder.Tuple(
      {builder.ConstantR0<bool>(true), builder.ConstantR0<bool>(false)});
  auto tuple12 = builder.Tuple(
      {builder.ConstantR1<float>(vec1), builder.ConstantR1<float>(vec2)});
  auto tuple21 = builder.Tuple(
      {builder.ConstantR1<float>(vec2), builder.ConstantR1<float>(vec1)});

  // select1 takes its true branch (tuple12); select2 takes its false branch
  // (select1), so the final sum is vec1 + vec2.
  auto select1 =
      builder.Select(builder.GetTupleElement(pred_tuple, 0), tuple12, tuple21);
  auto select2 =
      builder.Select(builder.GetTupleElement(pred_tuple, 1), tuple21, select1);
  auto result = builder.Add(builder.GetTupleElement(select2, 0),
                            builder.GetTupleElement(select2, 1));

  ComputeAndCompareR1<float>(&builder, {3.f, 6.f, 9.f}, {}, error_spec_);
}

XLA_TEST_F(TupleTest,
           DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesReuseConstants)) {
  // Same as SelectBetweenTuplesOnFalse, except that both input tuples are
  // built from the same two constant instructions instead of separate copies.
  ComputationBuilder builder(client_, TestName());

  std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
  std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};
  auto c1 = builder.ConstantR1<float>(vec1);
  auto c2 = builder.ConstantR1<float>(vec2);
  auto tuple12 = builder.Tuple({c1, c2});
  auto tuple21 = builder.Tuple({c2, c1});

  // A false predicate picks tuple21, i.e. (vec2, vec1).
  auto select =
      builder.Select(builder.ConstantR0<bool>(false), tuple12, tuple21);
  auto expected =
      LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>(vec2).get(),
                              LiteralUtil::CreateR1<float>(vec1).get()});
  ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
}

// Builds the nested tuple ((vector, scalar), vector) and compares it with the
// equivalently nested tuple literal.
XLA_TEST_F(TupleTest, NestedTuples) {
  ComputationBuilder builder(client_, TestName());
  auto inner_vector = builder.ConstantR1<float>({1.0, 2.0});
  auto inner_scalar = builder.ConstantR0<float>(42.0);
  auto inner = builder.Tuple({inner_vector, inner_scalar});
  auto outer_vector = builder.ConstantR1<float>({22.0, 44.0});
  builder.Tuple({inner, outer_vector});

  auto expected_inner =
      LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1.0, 2.0}).get(),
                              LiteralUtil::CreateR0<float>(42.0).get()});
  auto expected = LiteralUtil::MakeTuple(
      {expected_inner.get(),
       LiteralUtil::CreateR1<float>({22.0, 44.0}).get()});

  ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
}

// Reads element 1 of the inner tuple of a ((f32[3], f32[3]), f32[3]) parameter
// and adds a constant vector to it.
XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) {
  ComputationBuilder builder(client_, TestName());

  const Shape vec3 = ShapeUtil::MakeShape(F32, {3});
  const Shape inner_shape = ShapeUtil::MakeTupleShape({vec3, vec3});
  const Shape outer_shape = ShapeUtil::MakeTupleShape({inner_shape, vec3});

  auto param = builder.Parameter(0, outer_shape, "input");
  auto inner = builder.GetTupleElement(param, 0);
  auto picked = builder.GetTupleElement(inner, 1);
  builder.Add(picked, builder.ConstantR1<float>({10.0, 11.0, 12.0}));

  auto argument_literal = LiteralUtil::MakeTuple({
      LiteralUtil::MakeTuple({
          LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0}).get(),
          LiteralUtil::CreateR1<float>({4.0, 5.0, 6.0}).get(),
      }).get(),
      LiteralUtil::CreateR1<float>({7.0, 8.0, 9.0}).get(),
  });
  std::unique_ptr<GlobalData> data =
      client_->TransferToServer(*argument_literal).ConsumeValueOrDie();

  std::vector<GlobalData*> arguments = {data.get()};
  // Inner element 1 is {4, 5, 6}; the constant adds {10, 11, 12}.
  const std::vector<float> expected = {4.0 + 10.0, 5.0 + 11.0, 6.0 + 12.0};
  ComputeAndCompareR1<float>(&builder, expected, arguments, ErrorSpec(1e-5));
}

}  // namespace
}  // namespace xla

int main(int argc, char** argv) {
  std::vector<tensorflow::Flag> flag_list;
  xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
  xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
  const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
  if (!parse_result) {
    LOG(ERROR) << "\n" << usage;
    return 2;
  }
  testing::InitGoogleTest(&argc, argv);
  if (argc > 1) {
    LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
    return 2;
  }
  return RUN_ALL_TESTS();
}