aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/map_test.cc
blob: 2ef392508d14cf6dc14b2c979f07a79bc60d7426 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>

#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace {

// Fixture for the Map tests in this file. The constructor adds "algsimp" and
// "inline" to the list of disabled HLO passes in the debug options; the member
// functions build small computations that the tests hand to Map.
class MapTest : public ClientLibraryTestBase {
 public:
  explicit MapTest(perftools::gputools::Platform* platform = nullptr)
      : ClientLibraryTestBase(platform) {
    auto* debug_options = mutable_debug_options();
    debug_options->add_xla_disable_hlo_passes("algsimp");
    debug_options->add_xla_disable_hlo_passes("inline");
  }

  // Builds the scalar computation x -> x + 1.0.
  //
  // x {R0F32} ----> (add)
  //                /
  // 1.0f ---------/
  Computation CreateAdderToOne() {
    ComputationBuilder b(client_, TestName());
    auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
    b.Add(x, b.ConstantR0<float>(1.0));
    auto built = b.Build();
    TF_CHECK_OK(built.status());
    return built.ConsumeValueOrDie();
  }

  // Builds the binary scalar computation (x, y) -> max(x, y) over F32.
  Computation CreateMax() {
    ComputationBuilder max_builder(client_, TestName());
    auto x = max_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
    auto y = max_builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
    max_builder.Max(x, y);
    auto built = max_builder.Build();
    TF_CHECK_OK(built.status());
    return built.ConsumeValueOrDie();
  }

  // Builds a computation that declares one F32 scalar parameter, never reads
  // it, and ends with the constant T(1).
  template <class T>
  Computation CreateScalarOne() {
    ComputationBuilder b(client_, "scalar_one");
    // The parameter handle is deliberately discarded; only its declaration
    // matters.
    static_cast<void>(b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"));
    b.ConstantR0<T>(1);
    auto built = b.Build();
    TF_CHECK_OK(built.status());
    return built.ConsumeValueOrDie();
  }

  // Builds the scalar computation x -> x * 2.0.
  //
  // x {R0F32} ----> (mul)
  //                /
  // 2.0f ---------/
  Computation CreateMulByTwo() {
    ComputationBuilder b(client_, TestName());
    auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
    b.Mul(x, b.ConstantR0<float>(2.0));
    auto built = b.Build();
    TF_CHECK_OK(built.status());
    return built.ConsumeValueOrDie();
  }

  // Builds the scalar computation x -> x * (x + 1.0).
  //
  //           /------------------|
  //          /                   |
  // x {R0F32} ----> (add) ----> (mul)
  //                /
  // 1.0f ---------/
  Computation CreateAdderToOneTimesItself() {
    ComputationBuilder b(client_, TestName());
    auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
    auto x_plus_one = b.Add(x, b.ConstantR0<float>(1.0));
    b.Mul(x, x_plus_one);
    auto built = b.Build();
    TF_CHECK_OK(built.status());
    return built.ConsumeValueOrDie();
  }

  // Builds a unary scalar computation that maps "embedded_computation" over
  // its parameter and then adds the constant "n" to the mapped value.
  //
  // x {R0F32} -----------> (map) ----> (add)
  //                         /           /
  // embedded_computation --/       n --/
  Computation CreateMapPlusN(const Computation& embedded_computation, float n) {
    ComputationBuilder b(client_, TestName());
    auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
    auto mapped = b.Map({x}, embedded_computation, {});
    b.Add(mapped, b.ConstantR0<float>(n));
    auto built = b.Build();
    TF_CHECK_OK(built.status());
    return built.ConsumeValueOrDie();
  }

  // Builds the binary computation (x, y) -> x > y, taking two F32 scalars and
  // producing a predicate.
  Computation CreateGt() {
    ComputationBuilder gt_builder(client_, "Gt");
    auto x = gt_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
    auto y = gt_builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
    gt_builder.Gt(x, y);
    auto built = gt_builder.Build();
    TF_CHECK_OK(built.status());
    return built.ConsumeValueOrDie();
  }

  // Builds the ternary scalar computation (x, y, z) -> (x + y) + z.
  //
  // x {R0F32} -------|
  //                  |
  // y {R0F32} ----> (add) ---> (add)
  //                           /
  // z {R0F32} ---------------/
  Computation CreateTernaryAdder() {
    ComputationBuilder b(client_, "TernaryAdder");
    auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
    auto y = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
    auto z = b.Parameter(2, ShapeUtil::MakeShape(F32, {}), "z");
    b.Add(b.Add(x, y), z);
    auto built = b.Build();
    TF_CHECK_OK(built.status());
    return built.ConsumeValueOrDie();
  }
};

TEST_F(MapTest, MapEachElemPlusOneR0) {
  // Applies (lambda (x) (+ x 1)) to the F32 scalar 42 and expects 43.
  ComputationBuilder builder(client_, TestName());
  std::unique_ptr<Literal> input_literal = Literal::CreateR0<float>(42.0);
  std::unique_ptr<GlobalData> input_data =
      client_->TransferToServer(*input_literal).ConsumeValueOrDie();

  auto input = builder.Parameter(0, input_literal->shape(), "param0");
  builder.Map({input}, CreateAdderToOne(), {});

  ComputeAndCompareR0<float>(&builder, 43.0, {input_data.get()},
                             ErrorSpec(0.01f));
}

XLA_TEST_F(MapTest, MapEachElemPlusOneR1S0) {
  // Maps (lambda (x) (+ x 1)) over an empty R1F32 vector; the expected result
  // is empty as well.
  ComputationBuilder builder(client_, TestName());
  std::unique_ptr<Literal> input_literal = Literal::CreateR1<float>({});
  std::unique_ptr<GlobalData> input_data =
      client_->TransferToServer(*input_literal).ConsumeValueOrDie();

  auto input = builder.Parameter(0, input_literal->shape(), "param0");
  builder.Map({input}, CreateAdderToOne(), {0});

  ComputeAndCompareR1<float>(&builder, {}, {input_data.get()},
                             ErrorSpec(0.01f));
}

TEST_F(MapTest, MapEachElemPlusOneR1S4) {
  // Maps (lambda (x) (+ x 1)) over a four-element R1F32 vector and expects
  // every element to grow by one.
  ComputationBuilder builder(client_, TestName());
  std::unique_ptr<Literal> input_literal =
      Literal::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
  std::unique_ptr<GlobalData> input_data =
      client_->TransferToServer(*input_literal).ConsumeValueOrDie();

  auto input = builder.Parameter(0, input_literal->shape(), "param0");
  builder.Map({input}, CreateAdderToOne(), {0});

  ComputeAndCompareR1<float>(&builder, {3.2f, 4.3f, 5.4f, 6.5f},
                             {input_data.get()}, ErrorSpec(0.01f));
}

TEST_F(MapTest, MapEachF32ElementToS32Constant) {
  // Maps a computation that ignores its F32 argument and yields int32 1 over a
  // four-element F32 vector; the result is compared as an int32 vector of ones.
  ComputationBuilder builder(client_, TestName());
  std::unique_ptr<Literal> input_literal =
      Literal::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
  std::unique_ptr<GlobalData> input_data =
      client_->TransferToServer(*input_literal).ConsumeValueOrDie();

  auto input = builder.Parameter(0, input_literal->shape(), "param0");
  builder.Map({input}, CreateScalarOne<int32>(), {0});

  ComputeAndCompareR1<int32>(&builder, {1, 1, 1, 1}, {input_data.get()});
}

TEST_F(MapTest, MapEachF32ElementToU32Constant) {
  // Same as the S32 variant, but the mapped computation yields uint32 1 and
  // the result is compared as a uint32 vector of ones.
  ComputationBuilder builder(client_, TestName());
  std::unique_ptr<Literal> input_literal =
      Literal::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
  std::unique_ptr<GlobalData> input_data =
      client_->TransferToServer(*input_literal).ConsumeValueOrDie();

  auto input = builder.Parameter(0, input_literal->shape(), "param0");
  builder.Map({input}, CreateScalarOne<uint32>(), {0});

  ComputeAndCompareR1<uint32>(&builder, {1, 1, 1, 1}, {input_data.get()});
}

TEST_F(MapTest, MapEachElemLongerChainR1) {
  // Maps (lambda (x) (* (+ x 1) x)) over an R1F32 vector; each expected value
  // is x * (x + 1) for the matching input element.
  ComputationBuilder builder(client_, TestName());
  std::unique_ptr<Literal> input_literal =
      Literal::CreateR1<float>({2.6f, -5.1f, 0.1f, 0.2f, 999.0f, 255.5f});
  std::unique_ptr<GlobalData> input_data =
      client_->TransferToServer(*input_literal).ConsumeValueOrDie();

  auto input = builder.Parameter(0, input_literal->shape(), "param0");
  builder.Map({input}, CreateAdderToOneTimesItself(), {0});

  ComputeAndCompareR1<float>(
      &builder, {9.36f, 20.91f, 0.11f, 0.24f, 999000.0f, 65535.75f},
      {input_data.get()}, ErrorSpec(0.01f));
}

XLA_TEST_F(MapTest, MapMultipleMapsR1S0) {
  // Chains two maps over an empty R1F32 vector: first (lambda (x) (+ x 1)),
  // then (lambda (x) (* x 2)) on the first map's output.
  ComputationBuilder builder(client_, TestName());
  std::unique_ptr<Literal> input_literal = Literal::CreateR1<float>({});
  std::unique_ptr<GlobalData> input_data =
      client_->TransferToServer(*input_literal).ConsumeValueOrDie();

  auto input = builder.Parameter(0, input_literal->shape(), "param0");
  auto plus_one = builder.Map({input}, CreateAdderToOne(), {0});
  builder.Map({plus_one}, CreateMulByTwo(), {0});

  ComputeAndCompareR1<float>(&builder, {}, {input_data.get()},
                             ErrorSpec(0.01f));
}

TEST_F(MapTest, MapMultipleMapsR1S4) {
  // Chains two maps over a four-element R1F32 vector: first
  // (lambda (x) (+ x 1)), then (lambda (x) (* x 2)), so each expected element
  // is (x + 1) * 2.
  ComputationBuilder builder(client_, TestName());
  std::unique_ptr<Literal> input_literal =
      Literal::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
  std::unique_ptr<GlobalData> input_data =
      client_->TransferToServer(*input_literal).ConsumeValueOrDie();

  auto input = builder.Parameter(0, input_literal->shape(), "param0");
  auto plus_one = builder.Map({input}, CreateAdderToOne(), {0});
  builder.Map({plus_one}, CreateMulByTwo(), {0});

  ComputeAndCompareR1<float>(&builder, {6.4f, 8.6f, 10.8f, 13.0f},
                             {input_data.get()}, ErrorSpec(0.01f));
}

TEST_F(MapTest, MapEachElemPlusOneR2) {
  // Maps (lambda (x) (+ x 1)) over a 3x2 R2F32 array.
  ComputationBuilder builder(client_, TestName());
  std::unique_ptr<Literal> input_literal = Literal::CreateR2<float>(
      {{13.25f, 14.0f}, {-7.1f, -7.2f}, {-8.8f, 8.8f}});
  std::unique_ptr<GlobalData> input_data =
      client_->TransferToServer(*input_literal).ConsumeValueOrDie();

  auto input = builder.Parameter(0, input_literal->shape(), "param0");
  builder.Map({input}, CreateAdderToOne(), {0, 1});

  // Every input element incremented by one.
  Array2D<float> expected(
      {{14.25f, 15.0f}, {-6.1f, -6.2f}, {-7.8f, 9.8f}});
  ComputeAndCompareR2<float>(&builder, expected, {input_data.get()},
                             ErrorSpec(0.01f));
}

XLA_TEST_F(MapTest, ComplexNestedMaps) {
  // Builds a graph of computations that embed one another through Map, to
  // exercise the order in which embedded computations are lowered. In Python
  // terms:
  //
  //   embed1 = lambda x: x + 1                  #  x + 1
  //   embed2 = lambda x: embed1(x) + 2          #  x + 3
  //   embed3 = lambda x: embed1(x) + 4          #  x + 5
  //   embed4 = lambda x: embed2(x) + embed3(x)  # 2x + 8
  //   embed5 = lambda x: embed2(x) + 6          #  x + 9
  //   result = embed5(42) + embed4(7)           # (42 + 9) + (2 * 7 + 8) = 73

  Shape f32_scalar = ShapeUtil::MakeShape(F32, {});

  auto embed1 = CreateAdderToOne();
  auto embed2 = CreateMapPlusN(embed1, 2.0);
  auto embed3 = CreateMapPlusN(embed1, 4.0);

  // embed4 is assembled by hand because it sums two maps of the same input.
  ComputationBuilder b4(client_, "embed4");
  auto x4 = b4.Parameter(0, f32_scalar, "x");
  auto via_embed2 = b4.Map({x4}, embed2, {});
  auto via_embed3 = b4.Map({x4}, embed3, {});
  b4.Add(via_embed2, via_embed3);
  auto embed4_or = b4.Build();
  ASSERT_IS_OK(embed4_or.status());
  auto embed4 = embed4_or.ConsumeValueOrDie();

  auto embed5 = CreateMapPlusN(embed2, 6.0);

  // Top-level computation: embed5(42) + embed4(7).
  ComputationBuilder builder(client_, TestName());
  auto forty_two = builder.ConstantR0<float>(42.0);
  auto seven = builder.ConstantR0<float>(7.0);
  auto embed5_of_42 = builder.Map({forty_two}, embed5, {});
  auto embed4_of_7 = builder.Map({seven}, embed4, {});
  builder.Add(embed5_of_42, embed4_of_7);

  ComputeAndCompareR0<float>(&builder, 73.0, {}, ErrorSpec(0.01f));
}

TEST_F(MapTest, VersionedEmbeddedComputation) {
  // Build a computation X, use it in a map, then add an additional operation to
  // computation X and use it again in a different map. Verify that the proper
  // versions of computation X are used in each of the maps.

  // Create an (embedded) computation which adds one to its parameter argument.
  ComputationBuilder embedded_builder(client_, "EmbeddedComputation");
  auto param_0 =
      embedded_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param0");
  auto constant_one = embedded_builder.ConstantR0<float>(1.0);
  auto adder_to_one = embedded_builder.Add(param_0, constant_one);
  auto computation_status = embedded_builder.Build();
  ASSERT_IS_OK(computation_status.status());
  auto embedded_computation = computation_status.ConsumeValueOrDie();

  // First use: map the computation, as built so far, over a constant vector.
  ComputationBuilder builder(client_, TestName());
  auto constant_vector = builder.ConstantR1<float>({1.0, 2.0, 3.0, 4.0});
  auto map_plus_1 = builder.Map({constant_vector}, embedded_computation, {0});

  // Add another Add(1) operation to the existing embedded computation. This
  // requires using the stub interface because the ComputationBuilder does not
  // allow modification to the Computation objects after they have been built.
  BinaryOpRequest request;
  request.set_binop(BINOP_ADD);
  *request.mutable_lhs() = adder_to_one;
  *request.mutable_rhs() = constant_one;
  OpRequest op_request;
  *op_request.mutable_computation() = embedded_computation.handle();
  *op_request.mutable_binary_op_request() = request;
  OpResponse response;
  tensorflow::Status s = client_->stub()->Op(&op_request, &response);
  // Use ASSERT_IS_OK, as for the build status above, so that a failing status
  // is reported through the status itself rather than as a bare `false`.
  ASSERT_IS_OK(s);

  // Second use: map the same computation handle again, now over the output of
  // the first map.
  auto map_plus_2 = builder.Map({map_plus_1}, embedded_computation, {0});

  // The original vector has Add(1) applied to it with a map, followed by
  // Add(1+1) resulting in a net Add(3).
  ComputeAndCompareR1<float>(&builder, {4.0, 5.0, 6.0, 7.0}, {},
                             ErrorSpec(0.01f));
}

TEST_F(MapTest, MapBinaryAdder) {
  // Maps (lambda (x y) (+ x y)) over two four-element R1F32 vectors and
  // expects their element-wise sum.
  ComputationBuilder builder(client_, TestName());
  std::unique_ptr<Literal> lhs_literal =
      Literal::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
  std::unique_ptr<GlobalData> lhs_data =
      client_->TransferToServer(*lhs_literal).ConsumeValueOrDie();
  std::unique_ptr<Literal> rhs_literal =
      Literal::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
  std::unique_ptr<GlobalData> rhs_data =
      client_->TransferToServer(*rhs_literal).ConsumeValueOrDie();

  auto lhs = builder.Parameter(0, lhs_literal->shape(), "param0");
  auto rhs = builder.Parameter(1, rhs_literal->shape(), "param1");
  builder.Map({lhs, rhs}, CreateScalarAddComputation(F32, &builder), {0});

  ComputeAndCompareR1<float>(&builder, {7.3f, 7.7, 4.3f, 0},
                             {lhs_data.get(), rhs_data.get()},
                             ErrorSpec(0.01f));
}

// Element-wise addition, via Map, of two rank-2 arrays whose literals are
// created with different layouts ({1, 0} versus {0, 1}). This path used to
// fail in shape inference for Map (b/28989438).
XLA_TEST_F(MapTest, AddWithMixedLayouts) {
  ComputationBuilder builder(client_, TestName());
  std::unique_ptr<Literal> lhs_literal =
      test_utils::CreateR2LiteralWithLayout({{1, 2}, {3, 4}}, {1, 0});
  std::unique_ptr<GlobalData> lhs_data =
      client_->TransferToServer(*lhs_literal).ConsumeValueOrDie();

  std::unique_ptr<Literal> rhs_literal =
      test_utils::CreateR2LiteralWithLayout({{10, 20}, {30, 40}}, {0, 1});
  std::unique_ptr<GlobalData> rhs_data =
      client_->TransferToServer(*rhs_literal).ConsumeValueOrDie();

  auto lhs = builder.Parameter(0, lhs_literal->shape(), "param0");
  auto rhs = builder.Parameter(1, rhs_literal->shape(), "param1");
  builder.Map({lhs, rhs}, CreateScalarAddComputation(S32, &builder), {0, 1});

  // Expected element-wise sum, filled in one cell at a time.
  Array2D<int32> sum(2, 2);
  sum(0, 0) = 11;
  sum(0, 1) = 22;
  sum(1, 0) = 33;
  sum(1, 1) = 44;
  ComputeAndCompareR2<int32>(&builder, sum,
                             {lhs_data.get(), rhs_data.get()});
}

XLA_TEST_F(MapTest, AddR3_3x0x2) {
  // Maps scalar S32 addition over two 3x0x2 arrays, which hold no elements
  // because the middle dimension is zero; the expected result is again 3x0x2.
  ComputationBuilder builder(client_, TestName());
  std::unique_ptr<Literal> lhs_literal =
      Literal::CreateR3FromArray3D<int32>(Array3D<int32>(3, 0, 2));
  std::unique_ptr<GlobalData> lhs_data =
      client_->TransferToServer(*lhs_literal).ConsumeValueOrDie();

  std::unique_ptr<Literal> rhs_literal =
      Literal::CreateR3FromArray3D<int32>(Array3D<int32>(3, 0, 2));
  std::unique_ptr<GlobalData> rhs_data =
      client_->TransferToServer(*rhs_literal).ConsumeValueOrDie();

  auto lhs = builder.Parameter(0, lhs_literal->shape(), "param0");
  auto rhs = builder.Parameter(1, rhs_literal->shape(), "param1");
  builder.Map({lhs, rhs}, CreateScalarAddComputation(S32, &builder),
              {0, 1, 2});

  ComputeAndCompareR3<int32>(&builder, Array3D<int32>(3, 0, 2),
                             {lhs_data.get(), rhs_data.get()});
}

TEST_F(MapTest, MapTernaryAdder) {
  // Maps (lambda (x y z) (+ x y z)) over three four-element R1F32 vectors and
  // expects their element-wise sum.
  ComputationBuilder builder(client_, TestName());
  std::unique_ptr<Literal> first_literal =
      Literal::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
  std::unique_ptr<GlobalData> first_data =
      client_->TransferToServer(*first_literal).ConsumeValueOrDie();
  std::unique_ptr<Literal> second_literal =
      Literal::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
  std::unique_ptr<GlobalData> second_data =
      client_->TransferToServer(*second_literal).ConsumeValueOrDie();
  std::unique_ptr<Literal> third_literal =
      Literal::CreateR1<float>({-10.0f, -100.0f, -900.0f, -400.0f});
  std::unique_ptr<GlobalData> third_data =
      client_->TransferToServer(*third_literal).ConsumeValueOrDie();

  auto first = builder.Parameter(0, first_literal->shape(), "param0");
  auto second = builder.Parameter(1, second_literal->shape(), "param1");
  auto third = builder.Parameter(2, third_literal->shape(), "param2");
  builder.Map({first, second, third}, CreateTernaryAdder(), {0});

  ComputeAndCompareR1<float>(
      &builder, {-2.7f, -92.3f, -895.7f, -400.0f},
      {first_data.get(), second_data.get(), third_data.get()},
      ErrorSpec(0.01f));
}

TEST_F(MapTest, MapGt) {
  // Maps (x, y) -> x > y over the constant vectors {1, 20} and {10, 2}; the
  // expected predicate vector is {false, true}.
  ComputationBuilder builder(client_, TestName());
  auto greater_than = CreateGt();
  auto lhs = builder.ConstantR1<float>({1, 20});
  auto rhs = builder.ConstantR1<float>({10, 2});
  builder.Map({lhs, rhs}, greater_than, {0});
  ComputeAndCompareR1<bool>(&builder, {false, true}, {});
}

TEST_F(MapTest, NestedBinaryMap) {
  // Outer map applies max_with_square to every element of a vector;
  // max_with_square itself uses a (binary) map of CreateMax over x and x * x.
  Computation max_with_square;
  {
    // max_with_square(x) = max(x, x^2), expressed through a map.
    ComputationBuilder inner(client_, "max_with_square");
    auto x = inner.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
    inner.Map({x, inner.Mul(x, x)}, CreateMax(), {});
    auto built = inner.Build();
    ASSERT_IS_OK(built.status());
    max_with_square = built.ConsumeValueOrDie();
  }
  ComputationBuilder outer(client_, TestName());
  auto values = outer.ConstantR1<float>({0.1f, 0.5f, -0.5f, 1.0f, 2.0f});
  outer.Map({values}, max_with_square, {0});
  ComputeAndCompareR1<float>(&outer, {0.1f, 0.5f, 0.25f, 1.0f, 4.0f}, {});
}

TEST_F(MapTest, MapOperantionWithBuildError) {
  // Maps (lambda (x y) (+ x y)) over two R1F32 vectors, where the mapped
  // computation adds an F32 to a U16. That type combination is unsupported,
  // and the test checks that the resulting error surfaces when the outermost
  // ComputationBuilder is built.
  // NOTE(review): "Operantion" in the test name looks like a typo for
  // "Operation"; it is left unchanged so existing test filters keep matching.
  ComputationBuilder builder(client_, TestName());

  // Sub-computation with the mismatched operand types.
  auto error_builder = builder.CreateSubBuilder("ErrorAdd");
  auto x = error_builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
  auto y = error_builder->Parameter(1, ShapeUtil::MakeShape(U16, {}), "y");
  error_builder->Add(x, y);
  auto error_add = error_builder->BuildAndNoteError();

  std::unique_ptr<Literal> lhs_literal =
      Literal::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
  std::unique_ptr<GlobalData> lhs_data =
      client_->TransferToServer(*lhs_literal).ConsumeValueOrDie();
  std::unique_ptr<Literal> rhs_literal =
      Literal::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
  std::unique_ptr<GlobalData> rhs_data =
      client_->TransferToServer(*rhs_literal).ConsumeValueOrDie();

  auto lhs = builder.Parameter(0, lhs_literal->shape(), "param0");
  auto rhs = builder.Parameter(1, rhs_literal->shape(), "param1");
  builder.Map({lhs, rhs}, error_add, {0});

  // Building the outer computation must fail and name the sub-builder.
  StatusOr<Computation> build_result = builder.Build();
  ASSERT_TRUE(!build_result.ok());
  EXPECT_THAT(
      build_result.status().ToString(),
      ::testing::HasSubstr("error from: ErrorAdd: binary op BINOP_ADD with "
                           "different element types: f32[] and u16[]"));
}

// MapTest disables the "inline" and "algsimp" passes in its constructor.
// MapTestWithFullOpt is the plain ClientLibraryTestBase fixture, so tests
// declared on it run without this file disabling any optimization pass.
using MapTestWithFullOpt = ClientLibraryTestBase;

// Regression test for b/31466798. The inliner rewrites map(param0, param1,
// power) into power(param0, param1) but leaves the old subcomputation in
// place, and that subcomputation is identical to the new entry computation.
// HloSubcomputationUnification used to mishandle such patterns and could
// invalidate the pointer to the entry computation.
TEST_F(MapTestWithFullOpt, MapScalarPower) {
  ComputationBuilder builder(client_, TestName());

  // Mapped computation: (x, y) -> pow(x, y).
  auto power_builder = builder.CreateSubBuilder("power");
  auto base = power_builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
  auto exponent =
      power_builder->Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
  power_builder->Pow(base, exponent);
  auto power = power_builder->BuildAndNoteError();

  std::unique_ptr<Literal> base_literal = Literal::CreateR0<float>(2.0f);
  std::unique_ptr<Literal> exponent_literal = Literal::CreateR0<float>(5.0f);
  std::unique_ptr<GlobalData> base_data =
      client_->TransferToServer(*base_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> exponent_data =
      client_->TransferToServer(*exponent_literal).ConsumeValueOrDie();

  auto lhs = builder.Parameter(0, base_literal->shape(), "param0");
  auto rhs = builder.Parameter(1, exponent_literal->shape(), "param1");
  builder.Map({lhs, rhs}, power, {});

  // 2^5 = 32.
  ComputeAndCompareR0<float>(&builder, 32.0f,
                             {base_data.get(), exponent_data.get()},
                             ErrorSpec(0.01f));
}

// Regression test for b/35786417: the inliner failed to notice that the
// mapped computation uses its parameters in a different order than the map
// supplies them.
TEST_F(MapTestWithFullOpt, MapSubtractOppositeOrder) {
  ComputationBuilder builder(client_, TestName());

  // NOTE(review): the sub-builder is named "power" although the body is a
  // subtraction; this looks copied from MapScalarPower. The name is kept
  // unchanged here.
  auto swapped_sub_builder = builder.CreateSubBuilder("power");
  auto x =
      swapped_sub_builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
  auto y =
      swapped_sub_builder->Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
  // Operands are deliberately swapped: this computes y - x, not x - y.
  swapped_sub_builder->Sub(y, x);
  auto sub_opposite = swapped_sub_builder->BuildAndNoteError();

  std::unique_ptr<Literal> lhs_literal = Literal::CreateR0<float>(2.0f);
  std::unique_ptr<Literal> rhs_literal = Literal::CreateR0<float>(5.0f);
  std::unique_ptr<GlobalData> lhs_data =
      client_->TransferToServer(*lhs_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> rhs_data =
      client_->TransferToServer(*rhs_literal).ConsumeValueOrDie();

  auto lhs = builder.Parameter(0, lhs_literal->shape(), "param0");
  auto rhs = builder.Parameter(1, rhs_literal->shape(), "param1");
  builder.Map({lhs, rhs}, sub_opposite, {});

  // 5 - 2 = 3.
  ComputeAndCompareR0<float>(
      &builder, 3.0f, {lhs_data.get(), rhs_data.get()}, ErrorSpec(0.01f));
}

// Regression test for b/35786417: the inliner used to CHECK-fail because the
// mul inside the mapped computation has more operands than the map has
// parameters.
TEST_F(MapTestWithFullOpt, MapSquare) {
  ComputationBuilder builder(client_, TestName());

  // Mapped computation: x -> x * x.
  // NOTE(review): the sub-builder is named "power" although the body is a
  // multiplication; this looks copied from MapScalarPower. The name is kept
  // unchanged here.
  auto square_builder = builder.CreateSubBuilder("power");
  auto x = square_builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
  square_builder->Mul(x, x);
  auto square = square_builder->BuildAndNoteError();

  std::unique_ptr<Literal> input_literal = Literal::CreateR0<float>(10.0f);
  std::unique_ptr<GlobalData> input_data =
      client_->TransferToServer(*input_literal).ConsumeValueOrDie();

  auto input = builder.Parameter(0, input_literal->shape(), "param0");
  builder.Map({input}, square, {});

  // 10 * 10 = 100.
  ComputeAndCompareR0<float>(&builder, 100.0f, {input_data.get()},
                             ErrorSpec(0.01f));
}

}  // namespace
}  // namespace xla