aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/while_test.cc
blob: ccd2a95658928ba56a75c2c503e951015182a7b7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <string>
#include <vector>

#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/platform/types.h"

namespace se = ::perftools::gputools;

namespace xla {
namespace {

// Test fixture shared by every While test in this file. It adds no state or
// helpers of its own; everything comes from ClientLibraryTestBase.
class WhileTest : public ClientLibraryTestBase {};

// Tests a while node when the result type T is S32.
//
// int32 result = 0;
// while (result < 5) {
//   result = result + 1;
// }
TEST_F(WhileTest, WhileWithScalarResult) {
  const auto scalar_s32 = ShapeUtil::MakeShape(S32, {});

  // Loop condition: keep iterating while 5 > state.
  const Computation condition = [&] {
    ComputationBuilder cond_builder(client_, "condition");
    auto state = cond_builder.Parameter(0, scalar_s32, "prev");
    auto limit = cond_builder.ConstantR0<int32>(5);
    cond_builder.Gt(limit, state);
    return cond_builder.Build().ConsumeValueOrDie();
  }();

  // Loop body: produce state + 1.
  const Computation body = [&] {
    ComputationBuilder body_builder(client_, "body");
    auto state = body_builder.Parameter(0, scalar_s32, "prev");
    auto one = body_builder.ConstantR0<int32>(1);
    body_builder.Add(one, state);
    return body_builder.Build().ConsumeValueOrDie();
  }();

  // Top-level computation: a While node that starts counting from 0.
  ComputationBuilder builder(client_, TestName());
  auto zero = builder.ConstantR0<int32>(0);
  auto loop = builder.While(condition, body, zero);
  // The shape itself is not inspected; ConsumeValueOrDie() makes the test die
  // if the shape query returns an error.
  auto loop_shape = builder.GetShape(loop).ConsumeValueOrDie();

  ComputeAndCompareR0<int32>(&builder, 5, {});
}

// Tests a while node with an S32 scalar result whose init value is produced by
// another op (a Reduce) instead of being a literal constant.
//
// int32 result = reduce_add(ConstantR1<int32>(2, 1));
// while (result < 5) {
//   result = result + 1;
// }
TEST_F(WhileTest, WhileWithScalarResultNonConstInit) {
  auto result_shape = ShapeUtil::MakeShape(S32, {});

  // Create a computation for the condition: keep looping while 5 > prev.
  Computation condition;
  {
    ComputationBuilder builder(client_, "condition");
    auto prev = builder.Parameter(0, result_shape, "prev");
    builder.Gt(builder.ConstantR0<int32>(5), prev);
    condition = builder.Build().ConsumeValueOrDie();
  }

  // Create a computation for the body: add 1 to the result variable.
  Computation body;
  {
    ComputationBuilder builder(client_, "body");
    auto prev = builder.Parameter(0, result_shape, "prev");
    auto input = builder.ConstantR0<int32>(1);
    builder.Add(input, prev);
    body = builder.Build().ConsumeValueOrDie();
  }

  // Create a While node with computations for the condition and the body.
  // The init value is the add-reduction (starting from 0) of a constant
  // rank-1 operand over dimension 0, so it is not a plain constant.
  ComputationBuilder builder(client_, TestName());
  auto init = builder.Reduce(builder.ConstantR1<int32>(2, 1),
                             builder.ConstantR0<int32>(0),
                             CreateScalarAddComputation(S32, &builder), {0});
  auto result = builder.While(condition, body, init);
  // The shape itself is not inspected; ConsumeValueOrDie() makes the test die
  // if the shape query returns an error.
  auto shape = builder.GetShape(result).ConsumeValueOrDie();

  ComputeAndCompareR0<int32>(&builder, 5, {});
}

// Tests a while node whose loop state is a single PRED scalar.
TEST_F(WhileTest, WhileWithPredicateResult) {
  const auto pred_shape = ShapeUtil::MakeShape(PRED, {});

  // Loop condition: continue while (true != state).
  const Computation condition = [&] {
    ComputationBuilder cond_builder(client_, "condition");
    auto state = cond_builder.Parameter(0, pred_shape, "prev");
    auto true_value = cond_builder.ConstantR0<bool>(true);
    cond_builder.Ne(true_value, state);
    return cond_builder.Build().ConsumeValueOrDie();
  }();

  // Loop body: produce (state || true).
  const Computation body = [&] {
    ComputationBuilder body_builder(client_, "body");
    auto state = body_builder.Parameter(0, pred_shape, "prev");
    auto true_value = body_builder.ConstantR0<bool>(true);
    body_builder.LogicalOr(state, true_value);
    return body_builder.Build().ConsumeValueOrDie();
  }();

  // Top-level computation: the initial predicate is the result of comparing
  // two constants (false != true) rather than a constant itself.
  ComputationBuilder builder(client_, TestName());
  auto false_value = builder.ConstantR0<bool>(false);
  auto true_value = builder.ConstantR0<bool>(true);
  auto init = builder.Ne(false_value, true_value);
  auto loop = builder.While(condition, body, init);

  ComputeAndCompareR0<bool>(&builder, true, {});
}

// Tests a while node when the result type T is a vector.
//
// All constants are chosen to produce exact results.
// vector<float> result(0);
// while (result.sum() < 15.5f) {
//   result = result + vector<float>(0);
// }
// TODO(b/29185393): does not terminate on CPU.
TEST_F(WhileTest, DISABLED_WhileWithEmptyVectorResult) {
  // The loop state is a zero-element F32 vector.
  Shape result_shape = ShapeUtil::MakeShape(F32, {0});

  // Create a computation for the reduction: scalar F32 addition.
  Computation add;
  {
    ComputationBuilder builder(client_, "add");
    auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
    auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
    builder.Add(x, y);
    add = builder.Build().ConsumeValueOrDie();
  }

  // Create a computation for the condition.
  // Keep looping while the sum of the result vector is less than 15.5f.
  // NOTE(review): with a zero-element vector the reduced sum presumably stays
  // at the init value 0.0f, so this condition would never become false --
  // confirm against the TODO(b/29185393) on this test.
  Computation condition;
  {
    ComputationBuilder builder(client_, "condition");
    auto prev = builder.Parameter(0, result_shape, "prev");
    auto sum = builder.Reduce(prev, builder.ConstantR0<float>(0.0f), add,
                              /*dimensions_to_reduce=*/{0});
    builder.Gt(builder.ConstantR0<float>(15.5f), sum);
    condition = builder.Build().ConsumeValueOrDie();
  }

  // Create a computation for the body.
  // Add an empty constant vector to the (empty) result vector.
  Computation body;
  {
    ComputationBuilder builder(client_, "body");
    auto prev = builder.Parameter(0, result_shape, "prev");
    auto input = builder.ConstantR1<float>({});
    builder.Add(input, prev);
    body = builder.Build().ConsumeValueOrDie();
  }

  // Create a While node with computations for the condition and the body.
  ComputationBuilder builder(client_, "while");
  auto init = builder.ConstantR1<float>({});
  auto result = builder.While(condition, body, init);
  VLOG(2) << "while = " << ShapeUtil::HumanString(
                               *builder.GetShape(result).ConsumeValueOrDie());

  // The expected result is an empty vector.
  ComputeAndCompareR1<float>(&builder, {}, {}, ErrorSpec(0.0001));
}

// Tests a while node when the result type T is a vector.
//
// All constants are chosen to produce exact results.
// vector<float> result(8, 0.0f);
// while (result.sum() < 15.5f) {
//   result = result + vector<float>(8, 0.125f);
// }
TEST_F(WhileTest, WhileWithVectorResult) {
  // The loop state is an 8-element F32 vector.
  Shape result_shape = ShapeUtil::MakeShape(F32, {8});

  // Create a computation for the reduction: scalar F32 addition.
  Computation add;
  {
    ComputationBuilder builder(client_, "add");
    auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
    auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
    builder.Add(x, y);
    add = builder.Build().ConsumeValueOrDie();
  }

  // Create a computation for the condition.
  // Keep looping while the sum of the result vector is less than 15.5f.
  Computation condition;
  {
    ComputationBuilder builder(client_, "condition");
    auto prev = builder.Parameter(0, result_shape, "prev");
    auto sum = builder.Reduce(prev, builder.ConstantR0<float>(0.0f), add,
                              /*dimensions_to_reduce=*/{0});
    builder.Gt(builder.ConstantR0<float>(15.5f), sum);
    condition = builder.Build().ConsumeValueOrDie();
  }

  // Create a computation for the body.
  // Add a constant vector of 0.125f to the result vector.
  Computation body;
  {
    ComputationBuilder builder(client_, "body");
    auto prev = builder.Parameter(0, result_shape, "prev");
    auto input = builder.ConstantR1<float>(8, 0.125f);
    builder.Add(input, prev);
    body = builder.Build().ConsumeValueOrDie();
  }

  // Create a While node with computations for the condition and the body.
  ComputationBuilder builder(client_, "while");
  auto init = builder.ConstantR1<float>(8, 0.f);
  auto result = builder.While(condition, body, init);
  VLOG(2) << "while = " << ShapeUtil::HumanString(
                               *builder.GetShape(result).ConsumeValueOrDie());

  // Individual elements will increase by 1/8 each time through the loop, so
  // the sum will increase by 1.0.  It will first be >15.5 when the elements
  // have all reached 2.0.
  std::vector<float> expected = {2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f};
  ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}

// Tests a while node when the result type T is a Tuple.
//
// tuple<int32, vector<float>> result(0, vector<float>(10, 0.0f));
// while (get<0>(result) < 5) {
//   get<0>(result) = get<0>(result) + 1;
//   get<1>(result) = get<1>(result) + vector<float>(10, 1.0f);
// }
TEST_F(WhileTest, WhileWithTupleResult) {
  // Loop state: (S32 counter, F32[10] weights).
  const Shape state_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {10})});

  // Loop condition: keep iterating while 5 > counter.
  const Computation condition = [&] {
    ComputationBuilder cond_builder(client_, "condition");
    auto state = cond_builder.Parameter(0, state_shape, "prev");
    auto counter = cond_builder.GetTupleElement(state, 0);
    auto limit = cond_builder.ConstantR0<int32>(5);
    cond_builder.Gt(limit, counter);
    return cond_builder.Build().ConsumeValueOrDie();
  }();

  // Loop body: bump the counter by 1 and add 1.0f to every weight, then
  // repackage both as a tuple.
  const Computation body = [&] {
    ComputationBuilder body_builder(client_, "body");
    auto state = body_builder.Parameter(0, state_shape, "prev");
    auto counter = body_builder.GetTupleElement(state, 0);
    auto weights = body_builder.GetTupleElement(state, 1);
    auto ones = body_builder.ConstantR1<float>(10, 1.f);
    auto next_weights = body_builder.Add(weights, ones);
    auto one = body_builder.ConstantR0<int32>(1);
    auto next_counter = body_builder.Add(counter, one);
    body_builder.Tuple({next_counter, next_weights});
    return body_builder.Build().ConsumeValueOrDie();
  }();

  // Top-level computation: While over an all-zero initial state.
  ComputationBuilder builder(client_, "while");
  auto zero_counter = builder.ConstantR0<int32>(0);
  auto zero_weights = builder.ConstantR1<float>(10, 0.f);
  auto init = builder.Tuple({zero_counter, zero_weights});
  auto loop = builder.While(condition, body, init);
  VLOG(2) << "while = "
          << ShapeUtil::HumanString(
                 *builder.GetShape(loop).ConsumeValueOrDie());

  // Expect the counter at 5 and every weight at 5.0f.
  auto want_counter = Literal::CreateR0<int32>(5);
  auto want_weights = Literal::CreateR1<float>(
      {5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f});
  auto want = Literal::MakeTuple({want_counter.get(), want_weights.get()});
  VLOG(2) << "expected = " << ShapeUtil::HumanString(want->shape());
  ComputeAndCompareTuple(&builder, *want, {}, ErrorSpec(0.0001));
}

// Tests a while node whose loop state is a tuple of an S32 counter and a PRED
// scalar.
TEST_F(WhileTest, WhileWithPredicateTupleResult) {
  // Loop state: (S32 counter, PRED flag).
  const Shape state_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(PRED, {})});

  // Loop condition: keep iterating while 5 > counter.
  const Computation condition = [&] {
    ComputationBuilder cond_builder(client_, "condition");
    auto state = cond_builder.Parameter(0, state_shape, "prev");
    auto counter = cond_builder.GetTupleElement(state, 0);
    auto limit = cond_builder.ConstantR0<int32>(5);
    cond_builder.Gt(limit, counter);
    return cond_builder.Build().ConsumeValueOrDie();
  }();

  // Loop body: bump the counter by 1 and OR the flag with true.
  const Computation body = [&] {
    ComputationBuilder body_builder(client_, "body");
    auto state = body_builder.Parameter(0, state_shape, "prev");
    auto counter = body_builder.GetTupleElement(state, 0);
    auto flag = body_builder.GetTupleElement(state, 1);
    auto true_value = body_builder.ConstantR0<bool>(true);
    auto next_flag = body_builder.LogicalOr(flag, true_value);
    auto one = body_builder.ConstantR0<int32>(1);
    auto next_counter = body_builder.Add(counter, one);
    body_builder.Tuple({next_counter, next_flag});
    return body_builder.Build().ConsumeValueOrDie();
  }();

  // Top-level computation: the initial flag is the result of comparing two
  // constants (false != true) rather than a constant itself.
  ComputationBuilder builder(client_, "while");
  auto zero_counter = builder.ConstantR0<int32>(0);
  auto false_value = builder.ConstantR0<bool>(false);
  auto true_value = builder.ConstantR0<bool>(true);
  auto init_flag = builder.Ne(false_value, true_value);
  auto init = builder.Tuple({zero_counter, init_flag});
  auto loop = builder.While(condition, body, init);
  VLOG(2) << "while = "
          << ShapeUtil::HumanString(
                 *builder.GetShape(loop).ConsumeValueOrDie());

  // Expect the counter at 5 and the flag set.
  auto want_counter = Literal::CreateR0<int32>(5);
  auto want_flag = Literal::CreateR0<bool>(true);
  auto want = Literal::MakeTuple({want_counter.get(), want_flag.get()});
  ComputeAndCompareTuple(&builder, *want, {}, ErrorSpec(0));
}

// Tests two while nodes when the result type T is a Tuple and the second
// while node uses the result of the first while node which is used in two
// nodes.
// tuple<int32, vector<float>> w0(0, vector<float>(10, 0.0f));
// w0 = while (get<0>(w0) < c1) {
//        get<0>(w0) = get<0>(w0) + 1;
//        get<1>(w0) = get<1>(w0) + vector<float>(10, 1.0f);
//      }
// tuple<int32, vector<float>> w1(get<0>(w0), get<1>(w0));
// w1 = while (get<0>(w1) < c2) {
//        get<0>(w1) = get<0>(w1) + 1;
//        get<1>(w1) = get<1>(w1) + vector<float>(10, 1.0f);
//      }
// result = get<1>(w0) + get<1>(w1)
TEST_F(WhileTest, TwoWhileWithTupleResult) {
  // Counter bounds used by the first and the second loop condition.
  const int c1 = 5;
  const int c2 = 7;

  // Loop state: (S32 counter, F32[10] weights).
  const Shape state_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {10})});

  // Condition of the first loop: counter < c1.
  Computation condition;
  {
    ComputationBuilder cond_builder(client_, "condition");
    auto state = cond_builder.Parameter(0, state_shape, "prev");
    auto counter = cond_builder.GetTupleElement(state, 0);
    auto bound = cond_builder.ConstantR0<int32>(c1);
    cond_builder.Lt(counter, bound);
    TF_ASSIGN_OR_ASSERT_OK(condition, cond_builder.Build());
  }

  // Condition of the second loop: counter < c2.
  Computation condition2;
  {
    ComputationBuilder cond_builder(client_, "condition2");
    auto state = cond_builder.Parameter(0, state_shape, "prev");
    auto counter = cond_builder.GetTupleElement(state, 0);
    auto bound = cond_builder.ConstantR0<int32>(c2);
    cond_builder.Lt(counter, bound);
    TF_ASSIGN_OR_ASSERT_OK(condition2, cond_builder.Build());
  }

  // Body of the first loop: bump the counter by 1 and add 1.0f to every
  // weight, then repackage both as a tuple.
  Computation body;
  {
    ComputationBuilder body_builder(client_, "body");
    auto state = body_builder.Parameter(0, state_shape, "prev");
    auto counter = body_builder.GetTupleElement(state, 0);
    auto weights = body_builder.GetTupleElement(state, 1);
    auto ones = body_builder.ConstantR1<float>(10, 1.f);
    auto next_weights = body_builder.Add(weights, ones);
    auto one = body_builder.ConstantR0<int32>(1);
    auto next_counter = body_builder.Add(counter, one);
    body_builder.Tuple({next_counter, next_weights});
    TF_ASSIGN_OR_ASSERT_OK(body, body_builder.Build());
  }

  // Body of the second loop: a separately built computation that performs the
  // same update as `body`.
  // NOTE(review): this builder is also named "body" (not "body2") -- confirm
  // whether the duplicate name is intentional.
  Computation body2;
  {
    ComputationBuilder body_builder(client_, "body");
    auto state = body_builder.Parameter(0, state_shape, "prev");
    auto counter = body_builder.GetTupleElement(state, 0);
    auto weights = body_builder.GetTupleElement(state, 1);
    auto ones = body_builder.ConstantR1<float>(10, 1.f);
    auto next_weights = body_builder.Add(weights, ones);
    auto one = body_builder.ConstantR0<int32>(1);
    auto next_counter = body_builder.Add(counter, one);
    body_builder.Tuple({next_counter, next_weights});
    TF_ASSIGN_OR_ASSERT_OK(body2, body_builder.Build());
  }

  // Top-level computation: the second While consumes the first While's result
  // as its init value, and the weights of both loops are added together.
  ComputationBuilder builder(client_, "while");
  auto zero_counter = builder.ConstantR0<int32>(0);
  auto zero_weights = builder.ConstantR1<float>(10, 0.f);
  auto init = builder.Tuple({zero_counter, zero_weights});
  auto first_loop = builder.While(condition, body, init);
  auto second_loop = builder.While(condition2, body2, first_loop);

  auto first_weights = builder.GetTupleElement(first_loop, 1);
  auto second_weights = builder.GetTupleElement(second_loop, 1);
  VLOG(2) << "while_result2 = "
          << ShapeUtil::HumanString(
                 *builder.GetShape(second_weights).ConsumeValueOrDie());
  auto total = builder.Add(first_weights, second_weights);
  VLOG(2) << "result = "
          << ShapeUtil::HumanString(
                 *builder.GetShape(total).ConsumeValueOrDie());

  // Every element of the result is expected to equal c1 + c2.
  const std::vector<float> expected(10, static_cast<float>(c1 + c2));
  ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}

// Test while nodes that share the while body computation.
TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) {
  // Counter bounds used by the first and the second loop condition.
  const int c1 = 5;
  const int c2 = 7;

  // Loop state: (S32 counter, F32[10] weights).
  const Shape state_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {10})});

  // Condition of the first loop: counter < c1.
  Computation condition;
  {
    ComputationBuilder cond_builder(client_, "condition");
    auto state = cond_builder.Parameter(0, state_shape, "prev");
    auto counter = cond_builder.GetTupleElement(state, 0);
    auto bound = cond_builder.ConstantR0<int32>(c1);
    cond_builder.Lt(counter, bound);
    TF_ASSIGN_OR_ASSERT_OK(condition, cond_builder.Build());
  }

  // Condition of the second loop: counter < c2.
  Computation condition2;
  {
    ComputationBuilder cond_builder(client_, "condition2");
    auto state = cond_builder.Parameter(0, state_shape, "prev");
    auto counter = cond_builder.GetTupleElement(state, 0);
    auto bound = cond_builder.ConstantR0<int32>(c2);
    cond_builder.Lt(counter, bound);
    TF_ASSIGN_OR_ASSERT_OK(condition2, cond_builder.Build());
  }

  // The single body used by both loops: bump the counter by 1 and add 1.0f to
  // every weight, then repackage both as a tuple.
  Computation body;
  {
    ComputationBuilder body_builder(client_, "body");
    auto state = body_builder.Parameter(0, state_shape, "prev");
    auto counter = body_builder.GetTupleElement(state, 0);
    auto weights = body_builder.GetTupleElement(state, 1);
    auto ones = body_builder.ConstantR1<float>(10, 1.f);
    auto next_weights = body_builder.Add(weights, ones);
    auto one = body_builder.ConstantR0<int32>(1);
    auto next_counter = body_builder.Add(counter, one);
    body_builder.Tuple({next_counter, next_weights});
    TF_ASSIGN_OR_ASSERT_OK(body, body_builder.Build());
  }

  // Top-level computation: both While nodes use `body`; the second one takes
  // the first one's result as its init value.
  ComputationBuilder builder(client_, "while");
  auto zero_counter = builder.ConstantR0<int32>(0);
  auto zero_weights = builder.ConstantR1<float>(10, 0.f);
  auto init = builder.Tuple({zero_counter, zero_weights});
  auto first_loop = builder.While(condition, body, init);
  auto second_loop = builder.While(condition2, body, first_loop);

  auto first_weights = builder.GetTupleElement(first_loop, 1);
  auto second_weights = builder.GetTupleElement(second_loop, 1);
  VLOG(2) << "while_result2 = "
          << ShapeUtil::HumanString(
                 *builder.GetShape(second_weights).ConsumeValueOrDie());
  auto total = builder.Add(first_weights, second_weights);
  VLOG(2) << "result = "
          << ShapeUtil::HumanString(
                 *builder.GetShape(total).ConsumeValueOrDie());

  // Every element of the result is expected to equal c1 + c2.
  const std::vector<float> expected(10, static_cast<float>(c1 + c2));
  ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}

// Test while nodes that share both the while body computation and the init
// value.
// TODO(b/37245345): Fails on GPU backend.
TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) {
  // Counter bounds used by the first and the second loop condition.
  const int c1 = 5;
  const int c2 = 7;

  // Loop state: (S32 counter, F32[10] weights).
  const Shape state_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {10})});

  // Condition of the first loop: counter < c1.
  Computation condition;
  {
    ComputationBuilder cond_builder(client_, "condition");
    auto state = cond_builder.Parameter(0, state_shape, "prev");
    auto counter = cond_builder.GetTupleElement(state, 0);
    auto bound = cond_builder.ConstantR0<int32>(c1);
    cond_builder.Lt(counter, bound);
    TF_ASSIGN_OR_ASSERT_OK(condition, cond_builder.Build());
  }

  // Condition of the second loop: counter < c2.
  Computation condition2;
  {
    ComputationBuilder cond_builder(client_, "condition2");
    auto state = cond_builder.Parameter(0, state_shape, "prev");
    auto counter = cond_builder.GetTupleElement(state, 0);
    auto bound = cond_builder.ConstantR0<int32>(c2);
    cond_builder.Lt(counter, bound);
    TF_ASSIGN_OR_ASSERT_OK(condition2, cond_builder.Build());
  }

  // The single body used by both loops: bump the counter by 1 and add 1.0f to
  // every weight, then repackage both as a tuple.
  Computation body;
  {
    ComputationBuilder body_builder(client_, "body");
    auto state = body_builder.Parameter(0, state_shape, "prev");
    auto counter = body_builder.GetTupleElement(state, 0);
    auto weights = body_builder.GetTupleElement(state, 1);
    auto ones = body_builder.ConstantR1<float>(10, 1.f);
    auto next_weights = body_builder.Add(weights, ones);
    auto one = body_builder.ConstantR0<int32>(1);
    auto next_counter = body_builder.Add(counter, one);
    body_builder.Tuple({next_counter, next_weights});
    TF_ASSIGN_OR_ASSERT_OK(body, body_builder.Build());
  }

  // Top-level computation: both While nodes use `body` and start from the very
  // same `init` tuple.
  ComputationBuilder builder(client_, "while");
  auto zero_counter = builder.ConstantR0<int32>(0);
  auto zero_weights = builder.ConstantR1<float>(10, 0.f);
  auto init = builder.Tuple({zero_counter, zero_weights});
  auto first_loop = builder.While(condition, body, init);
  auto second_loop = builder.While(condition2, body, init);

  auto first_weights = builder.GetTupleElement(first_loop, 1);
  auto second_weights = builder.GetTupleElement(second_loop, 1);
  VLOG(2) << "while_result2 = "
          << ShapeUtil::HumanString(
                 *builder.GetShape(second_weights).ConsumeValueOrDie());
  auto total = builder.Add(first_weights, second_weights);
  VLOG(2) << "result = "
          << ShapeUtil::HumanString(
                 *builder.GetShape(total).ConsumeValueOrDie());

  // Every element of the result is expected to equal c1 + c2.
  const std::vector<float> expected(10, static_cast<float>(c1 + c2));
  ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}

// WhileTest that uses DynamicUpdateSlice instruction in body computation.
// Loop state tuple element 1 has as its single user operand(0) of
// DynamicUpdateSlice, which will trigger in-place dynamic slice update on GPU.
XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) {
  // Loop state: (S32 iteration counter, F32[10] data vector).
  std::vector<Shape> shape_elements = {ShapeUtil::MakeShape(S32, {}),
                                       ShapeUtil::MakeShape(F32, {10})};
  Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements);

  // Create a computation for the condition.
  // Keep looping while 5 > iteration.
  Computation condition;
  {
    ComputationBuilder builder(client_, "condition");
    auto prev = builder.Parameter(0, result_shape, "prev");
    auto iteration = builder.GetTupleElement(prev, 0);
    builder.Gt(builder.ConstantR0<int32>(5), iteration);
    condition = builder.Build().ConsumeValueOrDie();
  }

  // Create a computation for the body.
  // Add 1 to the iteration variable, and write the incremented value
  // (converted to F32 and broadcast to two elements) into the data vector at
  // offset iteration * 2 using DynamicUpdateSlice.
  Computation body;
  {
    ComputationBuilder builder(client_, "body");
    auto prev = builder.Parameter(0, result_shape, "prev");
    // TupleElement 0
    auto iteration = builder.GetTupleElement(prev, 0);
    auto out0 = builder.Add(iteration, builder.ConstantR0<int32>(1));
    // TupleElement 1
    auto input = builder.GetTupleElement(prev, 1);
    // Update.
    auto update = builder.ConvertElementType(builder.Broadcast(out0, {2}), F32);
    // Starts = iteration * 2;
    auto starts = builder.Reshape(
        builder.Mul(iteration, builder.ConstantR0<int32>(2)), {1});
    // UpdateSlice.
    auto out1 = builder.DynamicUpdateSlice(input, update, starts);

    builder.Tuple({out0, out1});
    body = builder.Build().ConsumeValueOrDie();
  }

  // Create a While node with computations for the condition and the body.
  ComputationBuilder builder(client_, "while");
  auto init = builder.Tuple(
      {builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)});
  auto result = builder.While(condition, body, init);
  VLOG(2) << "while = "
          << ShapeUtil::HumanString(
                 *builder.GetShape(result).ConsumeValueOrDie());

  // Iteration k (0-based) stores k + 1 into elements [2k, 2k + 1], so after
  // five iterations the data is {1, 1, 2, 2, ..., 5, 5} and the counter is 5.
  auto expected_counter = Literal::CreateR0<int32>(5);
  auto expected_data = Literal::CreateR1<float>(
      {1.0f, 1.0f, 2.0f, 2.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f});
  auto expected =
      Literal::MakeTuple({expected_counter.get(), expected_data.get()});
  VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
  ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
}

// Tests a while node when the result type T is a vector of S32.
//
// int32 result = (0, 0, 0, 0, 0, 0);
// while (result[0] < count) {
//   result += (1, U[0, 100], U[0, 100], U[0, 100], U[0, 100], U[0, 100]);
// }
//
// This test misuses a vector to represent a pair:
//   ((iteration, (random vector))).
//
// Note: this test currently only tests generating random values within a loop.
// Per backend the values generated can be different as the different backends
// use different random number generators.
// TODO(b/32240857): Extend test to verify outputs.
TEST_F(WhileTest, WhileWithPrngScalarResult) {
  // Loop state: an S32[6] vector whose element 0 serves as iteration counter.
  const auto state_shape = ShapeUtil::MakeShape(S32, {6});

  // Builds the loop condition for a given bound: slice element 0 out of the
  // state, reshape it to a scalar, and continue while count > that scalar.
  auto make_condition = [this, state_shape](int count) {
    ComputationBuilder cond_builder(client_, TestName());
    auto state = cond_builder.Parameter(0, state_shape, "prev");
    auto head = cond_builder.Slice(state, {0}, {1});
    auto counter = cond_builder.Reshape(head, {0}, {});
    auto bound = cond_builder.ConstantR0<int32>(count);
    cond_builder.Gt(bound, counter);
    return cond_builder.Build().ConsumeValueOrDie();
  };

  // Loop body: add the increment vector (1 followed by five RngUniform values
  // drawn with bounds 0 and 100) to the state.
  const Computation body = [&] {
    ComputationBuilder body_builder(client_, "body");
    auto state = body_builder.Parameter(0, state_shape, "prev");
    auto step = body_builder.ConstantR1<int32>({1});
    auto low = body_builder.ConstantR0<int32>(0);
    auto high = body_builder.ConstantR0<int32>(100);
    auto noise =
        body_builder.RngUniform(low, high, ShapeUtil::MakeShape(S32, {5}));
    auto increment = body_builder.ConcatInDim({step, noise}, 0);
    body_builder.Add(increment, state);
    return body_builder.Build().ConsumeValueOrDie();
  }();

  // Builds the top-level computation: a While node over an all-zero state
  // that uses the condition for the requested bound.
  auto make_loop = [this, &body, make_condition](int count) {
    ComputationBuilder loop_builder(client_, TestName());
    auto zeros = loop_builder.ConstantR1<int32>({0, 0, 0, 0, 0, 0});
    auto loop = loop_builder.While(make_condition(count), body, zeros);
    // The shape itself is not inspected; ConsumeValueOrDie() makes the test
    // die if the shape query returns an error.
    auto loop_shape = loop_builder.GetShape(loop).ConsumeValueOrDie();
    return loop_builder.Build();
  };

  // Run the loop for bounds 1, 2 and 3 with a fixed seed. Only successful
  // execution is checked; the produced values are not compared.
  for (int count : {1, 2, 3}) {
    TF_ASSIGN_OR_ASSERT_OK(auto computation, make_loop(count));

    ExecutionOptions execution_options = execution_options_;
    execution_options.set_seed(65);
    TF_ASSIGN_OR_ASSERT_OK(
        auto result,
        client_->ExecuteAndTransfer(computation, {}, &execution_options));
  }
}

// Exercises a while loop nested inside the body of another while loop.
//
// Equivalent pseudocode:
//
//   int32 result = 0;
//   while (result < 30) {
//     int i = 0;
//     while (i < 7) {
//       result = result + 2;
//       i = i + 1;
//     }
//   }
//
// Each outer iteration adds 7 * 2 = 14, so result goes 0 -> 14 -> 28 -> 42.
XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) {
  auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
  // Outer loop state: just `result`.
  auto outer_state_shape = scalar_s32;
  // Inner loop state: the tuple (i, result).
  auto inner_state_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32});

  // Inner loop condition: repeat while i < 7.
  Computation inner_condition;
  {
    ComputationBuilder builder(client_, "inner_condition");
    auto state = builder.Parameter(0, inner_state_shape, "prev");
    auto index = builder.GetTupleElement(state, 0);
    auto limit = builder.ConstantR0<int32>(7);
    builder.Lt(index, limit);
    inner_condition = builder.Build().ConsumeValueOrDie();
  }

  // Inner loop body: produce (i + 1, result + 2).
  Computation inner_body;
  {
    ComputationBuilder builder(client_, "inner_body");
    auto state = builder.Parameter(0, inner_state_shape, "prev");
    auto index = builder.GetTupleElement(state, 0);
    auto total = builder.GetTupleElement(state, 1);
    auto next_index = builder.Add(builder.ConstantR0<int32>(1), index);
    auto next_total = builder.Add(builder.ConstantR0<int32>(2), total);
    // The tuple is added last so that it is the body's result.
    builder.Tuple({next_index, next_total});
    inner_body = builder.Build().ConsumeValueOrDie();
  }

  // Outer loop condition: repeat while result < 30.
  Computation outer_condition;
  {
    ComputationBuilder builder(client_, "outer_condition");
    auto total = builder.Parameter(0, outer_state_shape, "prev");
    auto limit = builder.ConstantR0<int32>(30);
    builder.Lt(total, limit);
    outer_condition = builder.Build().ConsumeValueOrDie();
  }

  // Outer loop body: run the inner loop starting from i = 0 and extract the
  // updated result (tuple element 1) from its final state.
  Computation outer_body;
  {
    ComputationBuilder builder(client_, "outer_body");
    auto total = builder.Parameter(0, outer_state_shape, "prev");
    auto zero = builder.ConstantR0<int32>(0);
    auto inner_init = builder.Tuple({zero, total});
    auto inner_loop = builder.While(inner_condition, inner_body, inner_init);
    builder.GetTupleElement(inner_loop, 1);
    outer_body = builder.Build().ConsumeValueOrDie();
  }

  // Top-level computation: the outer While node, starting from result = 0.
  ComputationBuilder builder(client_, TestName());
  auto outer_init = builder.ConstantR0<int32>(0);
  auto outer_loop = builder.While(outer_condition, outer_body, outer_init);
  auto outer_shape = builder.GetShape(outer_loop).ConsumeValueOrDie();

  ComputeAndCompareR0<int32>(&builder, 42, {});
}

void BM_WhileLoop(int num_iters) {
  // Benchmark a simple kernel to measure while loop overheads.
  tensorflow::testing::StopTiming();

  se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie();
  auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie();
  StreamExecutorMemoryAllocator allocator(platform, executors);
  LocalClient* client =
      ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie();

  Shape loop_state_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {10})});

  // Create while condition computation with 'loop_limit'.
  const int32 loop_limit = 100;
  Computation condition;
  {
    ComputationBuilder builder(client, "condition");
    auto prev = builder.Parameter(0, loop_state_shape, "prev");
    auto iteration = builder.GetTupleElement(prev, 0);
    builder.Lt(iteration, builder.ConstantR0<int32>(loop_limit));
    condition = builder.Build().ConsumeValueOrDie();
  }

  // Create while body computation with unit loop increment.
  Computation body;
  {
    ComputationBuilder builder(client, "body");
    auto prev = builder.Parameter(0, loop_state_shape, "prev");
    auto iteration = builder.GetTupleElement(prev, 0);
    auto weights = builder.GetTupleElement(prev, 1);
    auto one = builder.ConstantR0<int32>(1);
    auto next_iteration = builder.Add(iteration, one);
    auto one_vec = builder.ConstantR1<float>(10, 1.f);
    auto new_weights = builder.Add(weights, one_vec);
    auto result = builder.Tuple({next_iteration, new_weights});
    body = builder.Build().ConsumeValueOrDie();
  }

  // Create a While instruction.
  ComputationBuilder builder(client, "while");
  auto init = builder.Tuple(
      {builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)});
  builder.While(condition, body, init);
  auto computation = builder.Build().ConsumeValueOrDie();

  std::unique_ptr<LocalExecutable> executable =
      client->Compile(computation, {}, ExecutableBuildOptions())
          .ConsumeValueOrDie();

  // Run some warm-up executions.
  ExecutableRunOptions options;
  options.set_allocator(&allocator);
  const int kWarmups = 2;
  for (int i = 0; i < kWarmups; ++i) {
    auto result = executable->Run({}, options);
    ASSERT_TRUE(result.ok());
  }

  // Run benchmark.
  tensorflow::testing::StartTiming();
  for (int i = 0; i < num_iters; ++i) {
    auto result = executable->Run({}, options);
    ASSERT_TRUE(result.ok());
  }
}

// Register the while loop benchmark, except when XLA_TEST_BACKEND_CPU_PARALLEL
// is defined.
// TODO(b/32470510): Benchmark fails on parallel CPU backend.
#if !defined(XLA_TEST_BACKEND_CPU_PARALLEL)
BENCHMARK(BM_WhileLoop);
#endif

}  // namespace
}  // namespace xla

int main(int argc, char** argv) {
  std::vector<tensorflow::Flag> flag_list;
  xla::legacy_flags::AppendDebugOptionsFlags(&flag_list);
  xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
  const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
  if (!parse_result) {
    LOG(ERROR) << "\n" << usage;
    return 2;
  }
  testing::InitGoogleTest(&argc, argv);
  if (argc > 1) {
    LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
    return 2;
  }
  tensorflow::testing::RunBenchmarks();
  return RUN_ALL_TESTS();
}