aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
blob: ae3f887240d0ccffcc9c51a2c409de457a94f967 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <numeric>
#include <vector>

#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/local_service.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/platform/types.h"

namespace se = ::perftools::gputools;

namespace xla {
namespace {

class DynamicSliceTest : public ClientLibraryTestBase {
 protected:
  template <typename IndexT, typename DataT>
  void TestR1() {
    // Slice at dimension start.
    RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {0}, {5}, {0, 1, 2, 3, 4});
    // Slice in the middle.
    RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {2}, {3}, {2, 3, 4});
    // Slice at dimension boundaries.
    RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {5}, {3}, {5, 6, 7});
    // Zero element slice.
    RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {2}, {0}, {});
  }

  template <typename IndexT, typename DataT>
  void TestR1Wrap() {
    // Slice at dimension boundaries, but with sizes that cause indices to wrap.
    RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {6}, {4}, {6, 7, 0, 1});
  }

  template <typename IndexT, typename DataT>
  void TestR2() {
    // Slice at dimension start.
    RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {0, 0}, {2, 2},
                         {{1, 2}, {4, 5}});
    // Slice in the middle.
    RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {1, 1}, {2, 1},
                         {{5}, {8}});
    // Slice at dimension boundaries.
    RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {1, 1}, {2, 1},
                         {{5}, {8}});
    // Zero element slice: 2x0.
    RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {0, 0}, {2, 0},
                         {{}, {}});
    // Zero element slice: 0x2.
    RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {0, 0}, {0, 2},
                         Array2D<int>(0, 2));
  }

  template <typename IndexT, typename DataT>
  void TestR2Wrap() {
    // Slice at dimension boundaries, but with sizes that cause indices to wrap.
    RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {1, 1}, {3, 3},
                         {{5, 6, 4}, {8, 9, 7}, {2, 3, 1}});
  }

  template <typename IndexT, typename DataT>
  void TestR3() {
    // R3 Shape: [2, 3, 2]
    // clang-format off

    // Slice at dimension start.
    RunR3<IndexT, DataT>(
      {{{1, 2}, {3, 4}, {5, 6}},
       {{7, 8}, {9, 10}, {11, 12}}},
      {0, 0, 0}, {2, 1, 2},
      {{{1, 2}}, {{7, 8}}});

    // Slice in the middle.
    RunR3<IndexT, DataT>(
      {{{1, 2}, {3, 4}, {5, 6}},
       {{7, 8}, {9, 10}, {11, 12}}},
      {0, 1, 1}, {2, 2, 1},
      {{{4}, {6}}, {{10}, {12}}});
    // clang-format on
  }

  template <typename IndexT, typename DataT>
  void TestR3Wrap() {
    // Slice at dimension boundaries, but with sizes that cause indices to wrap.
    RunR3<IndexT, DataT>(
      {{{1, 2}, {3, 4}, {5, 6}},
       {{7, 8}, {9, 10}, {11, 12}}},
      {0, 2, 1}, {2, 1, 2},
      {{{6, 5}}, {{12, 11}}});
  }

  template <typename IndexT, typename DataT>
  void RunR1(tensorflow::gtl::ArraySlice<int> input_values_int,
             const std::vector<IndexT> slice_starts,
             const std::vector<int64>& slice_sizes,
             tensorflow::gtl::ArraySlice<int> expected_values_int) {
    // bfloat16 has explicit constructors, so it does not implicitly convert the
    // way built-in types do, which is why we can't take the parameter as an
    // ArraySlice<DataT>. We also can't convert it to a vector, because
    // vector<bool> is special so that it cannot be an ArraySlice<bool>, which
    // is what the code below wants. So instead we do this.
    Literal input_values =
        std::move(*Literal::CreateR1(input_values_int)
                       ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
                       .ValueOrDie());
    Literal expected_values =
        std::move(*Literal::CreateR1(expected_values_int)
                       ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
                       .ValueOrDie());

    ComputationBuilder builder(client_, TestName());
    // Initialize and transfer dynamic slice start indices parameter.
    ComputationDataHandle starts;
    std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>(
        slice_starts, 0, "slice_starts", &builder, &starts);
    // Build dynamic slice computation.
    auto input = builder.ConstantLiteral(input_values);
    builder.DynamicSlice(input, starts, slice_sizes);
    // Run computation and compare against expected values.
    ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()});
  }

  template <typename IndexT, typename DataT>
  void RunR2(const Array2D<int>& input_values_int,
             const std::vector<IndexT> slice_starts,
             const std::vector<int64>& slice_sizes,
             const Array2D<int>& expected_values_int) {
    Literal input_values =
        std::move(*Literal::CreateR2FromArray2D(input_values_int)
                       ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
                       .ValueOrDie());
    Literal expected_values =
        std::move(*Literal::CreateR2FromArray2D(expected_values_int)
                       ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
                       .ValueOrDie());

    ComputationBuilder builder(client_, TestName());
    // Initialize and transfer dynamic slice start indices parameter.
    ComputationDataHandle starts;
    std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>(
        slice_starts, 0, "slice_starts", &builder, &starts);
    // Build dynamic slice computation.
    auto input = builder.ConstantLiteral(input_values);
    builder.DynamicSlice(input, starts, slice_sizes);
    // Run computation and compare against expected values.
    ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()});
  }

  template <typename IndexT, typename DataT>
  void RunR3(const Array3D<int>& input_values_int,
             const std::vector<IndexT> slice_starts,
             const std::vector<int64>& slice_sizes,
             const Array3D<int>& expected_values_int) {
    Literal input_values =
        std::move(*Literal::CreateR3FromArray3D(input_values_int)
                       ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
                       .ValueOrDie());
    Literal expected_values =
        std::move(*Literal::CreateR3FromArray3D(expected_values_int)
                       ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
                       .ValueOrDie());

    ComputationBuilder builder(client_, TestName());
    // Initialize and transfer dynamic slice start indices parameter.
    ComputationDataHandle starts;
    std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>(
        slice_starts, 0, "slice_starts", &builder, &starts);
    // Build dynamic slice computation.
    auto input = builder.ConstantLiteral(input_values);
    builder.DynamicSlice(input, starts, slice_sizes);
    // Run computation and compare against expected values.
    ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()});
  }
};

// DynamicSlice cases instantiated over a spread of start-index types (int32,
// int64, uint64) and element types (bfloat16, int32, float, double).

// Rank 1.
XLA_TEST_F(DynamicSliceTest, Int32R1BF16) {
  TestR1<int32, bfloat16>();
}
XLA_TEST_F(DynamicSliceTest, Int32R1) {
  TestR1<int32, int32>();
}
XLA_TEST_F(DynamicSliceTest, Int32R1Wrap) {
  TestR1Wrap<int32, int32>();
}
XLA_TEST_F(DynamicSliceTest, Int64R1) {
  TestR1<int64, float>();
}
XLA_TEST_F(DynamicSliceTest, UInt64R1) {
  TestR1<uint64, double>();
}

// Rank 2.
XLA_TEST_F(DynamicSliceTest, Int32R2BF16) {
  TestR2<int32, bfloat16>();
}
XLA_TEST_F(DynamicSliceTest, Int32R2) {
  TestR2<int32, int32>();
}
XLA_TEST_F(DynamicSliceTest, Int32R2Wrap) {
  TestR2Wrap<int32, int32>();
}
XLA_TEST_F(DynamicSliceTest, Int64R2) {
  TestR2<int64, double>();
}
XLA_TEST_F(DynamicSliceTest, UInt64R2) {
  TestR2<uint64, int32>();
}

// Rank 3.
XLA_TEST_F(DynamicSliceTest, Int32R3BF16) {
  TestR3<int32, bfloat16>();
}
XLA_TEST_F(DynamicSliceTest, Int32R3) {
  TestR3<int32, float>();
}
XLA_TEST_F(DynamicSliceTest, Int32R3Wrap) {
  TestR3Wrap<int32, float>();
}
XLA_TEST_F(DynamicSliceTest, Int64R3) {
  TestR3<int64, float>();
}
XLA_TEST_F(DynamicSliceTest, UInt64R3) {
  TestR3<uint64, double>();
}

XLA_TEST_F(DynamicSliceTest, Int32R1Pred) {
  // Eight-element predicate operand shared by every case below; the values
  // are spelled as bools but stored as ints, which RunR1 converts to PRED.
  const std::vector<int> operand = {true,  false, false, true,
                                    false, true,  true,  false};

  // Slice at dimension start.
  RunR1<int32, bool>(operand, {0}, {5}, {true, false, false, true, false});
  // Slice in the middle.
  RunR1<int32, bool>(operand, {2}, {3}, {false, true, false});
  // Slice at dimension boundaries.
  RunR1<int32, bool>(operand, {5}, {3}, {true, true, false});
  // Zero element slice.
  RunR1<int32, bool>(operand, {2}, {0}, {});
}

XLA_TEST_F(DynamicSliceTest, Int32R2Pred) {
  // Slice at dimension start.
  RunR2<int32, bool>(
      {{true, false, true}, {false, false, true}, {true, true, false}}, {0, 0},
      {2, 2}, {{true, false}, {false, false}});
  // Slice in the middle.
  RunR2<int32, bool>(
      {{true, false, true}, {false, false, true}, {true, true, false}}, {1, 1},
      {2, 1}, {{false}, {true}});
  // Slice at dimension boundaries: the 2x2 slice ends exactly on the last row
  // and the last column, so no index wraps.
  RunR2<int32, bool>(
      {{true, false, true}, {false, false, true}, {true, true, false}}, {1, 1},
      {2, 2}, {{false, true}, {true, false}});
  // Zero element slice: 2x0.
  RunR2<int32, bool>(
      {{true, false, true}, {false, false, true}, {true, true, false}}, {0, 0},
      {2, 0}, {{}, {}});
  // Zero element slice: 0x2.
  RunR2<int32, bool>(
      {{true, false, true}, {false, false, true}, {true, true, false}}, {0, 0},
      {0, 2}, Array2D<int>(0, 2));
}

XLA_TEST_F(DynamicSliceTest, Int32R3Pred) {
  // R3 Shape: [2, 3, 2]
  // clang-format off

  // Slice at dimension start.
  RunR3<int32, bool>(
    {{{true, false}, {false, true}, {true, true}},
     {{false, true}, {true, false}, {false, false}}},
    {0, 0, 0}, {2, 1, 2},
    {{{true, false}}, {{false, true}}});

  // Slice in the middle.
  RunR3<int32, bool>(
    {{{true, false}, {false, true}, {true, true}},
     {{false, true}, {true, false}, {false, false}}},
    {0, 1, 1}, {2, 2, 1},
    {{{true}, {true}}, {{false}, {false}}});

  // Slice at dimension boundaries: the slice ends on the last index of every
  // dimension, so no index wraps.
  RunR3<int32, bool>(
    {{{true, false}, {false, true}, {true, true}},
     {{false, true}, {true, false}, {false, false}}},
    {1, 1, 0}, {1, 2, 2},
    {{{true, false}, {false, false}}});

  // clang-format on
}

// Tests for ComputationBuilder::DynamicUpdateSlice. The Test* helpers
// enumerate update cases and the Run* helpers execute a single case: operand,
// update and expected result are given as int arrays and converted to DataT,
// while the start indices are fed to the computation as an R1 IndexT
// parameter.
class DynamicUpdateSliceTest : public ClientLibraryTestBase {
 protected:
  template <typename IndexT, typename DataT>
  void TestR1() {
    // Slice at dimension start.
    RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {8, 9, 10}, {0},
                         {8, 9, 10, 3, 4, 5, 6, 7});
    // Slice in the middle.
    RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {8, 9, 10}, {2},
                         {0, 1, 8, 9, 10, 5, 6, 7});
    // Slice at dimension boundaries.
    RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {8, 9, 10}, {5},
                         {0, 1, 2, 3, 4, 8, 9, 10});
    // Zero-sized update.
    RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {}, {2},
                         {0, 1, 2, 3, 4, 5, 6, 7});
  }

  template <typename IndexT, typename DataT>
  void TestR2() {
    // Slice at dimension start.
    RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {{10, 11}}, {0, 0},
                         {{10, 11, 3}, {4, 5, 6}, {7, 8, 9}});
    // Slice in the middle.
    RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {{10, 11}}, {1, 1},
                         {{1, 2, 3}, {4, 10, 11}, {7, 8, 9}});
    // Slice at dimension boundaries.
    RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {{10, 11}}, {2, 1},
                         {{1, 2, 3}, {4, 5, 6}, {7, 10, 11}});
    // Zero-sized update.
    RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {{}}, {2, 1},
                         {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  }

  template <typename IndexT, typename DataT>
  void TestR3() {
    // R3 Shape: [2, 3, 2]
    // Slice at dimension start.
    RunR3<IndexT, DataT>(
        {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}},
        {{{13, 14}, {15, 16}}, {{17, 18}, {19, 20}}}, {0, 0, 0},
        {{{13, 14}, {15, 16}, {5, 6}}, {{17, 18}, {19, 20}, {11, 12}}});
    // Slice in the middle.
    RunR3<IndexT, DataT>(
        {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}, {{{13}, {15}}},
        {1, 1, 1}, {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 13}, {11, 15}}});
    // Slice at dimension boundaries: the update ends on the last index of
    // every dimension, so no index wraps.
    RunR3<IndexT, DataT>(
        {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}},
        {{{13, 14}, {15, 16}}}, {1, 1, 0},
        {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {13, 14}, {15, 16}}});
  }

  template <typename IndexT, typename DataT>
  void TestWrap() {
    // Slice at dimension boundaries, but with sizes that cause indices to wrap.
    RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {8, 9, 10}, {6},
                         {10, 1, 2, 3, 4, 5, 8, 9});
    // R2 Shape: [3, 3]
    RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {{10, 11}}, {2, 2},
                         {{1, 2, 3}, {4, 5, 6}, {11, 8, 10}});
    // R3 Shape: [2, 3, 2]
    RunR3<IndexT, DataT>(
        {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}, {{{13}, {15}}},
        {1, 2, 1}, {{{1, 2}, {3, 4}, {5, 6}}, {{7, 15}, {9, 10}, {11, 13}}});
  }

  // Writes `update_values_int` into `input_values_int` at `slice_starts`
  // (all values converted to DataT) and compares against
  // `expected_values_int`. Operand and update are embedded as constants.
  template <typename IndexT, typename DataT>
  void RunR1(tensorflow::gtl::ArraySlice<int> input_values_int,
             tensorflow::gtl::ArraySlice<int> update_values_int,
             const std::vector<IndexT>& slice_starts,
             tensorflow::gtl::ArraySlice<int> expected_values_int) {
    Literal input_values =
        std::move(*Literal::CreateR1(input_values_int)
                       ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
                       .ValueOrDie());
    Literal update_values =
        std::move(*Literal::CreateR1(update_values_int)
                       ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
                       .ValueOrDie());
    Literal expected_values =
        std::move(*Literal::CreateR1(expected_values_int)
                       ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
                       .ValueOrDie());

    ComputationBuilder builder(client_, TestName());
    // Initialize and transfer dynamic slice start indices parameter.
    ComputationDataHandle starts;
    std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>(
        slice_starts, 0, "slice_starts", &builder, &starts);
    // Build dynamic update slice computation.
    auto input = builder.ConstantLiteral(input_values);
    auto update = builder.ConstantLiteral(update_values);
    builder.DynamicUpdateSlice(input, update, starts);
    // Run computation and compare against expected values.
    ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()});
  }

  // R2 counterpart of RunR1, using Array2D<int> operand/update/expectation.
  template <typename IndexT, typename DataT>
  void RunR2(const Array2D<int>& input_values_int,
             const Array2D<int>& update_values_int,
             const std::vector<IndexT>& slice_starts,
             const Array2D<int>& expected_values_int) {
    Literal input_values =
        std::move(*Literal::CreateR2FromArray2D(input_values_int)
                       ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
                       .ValueOrDie());
    Literal update_values =
        std::move(*Literal::CreateR2FromArray2D(update_values_int)
                       ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
                       .ValueOrDie());
    Literal expected_values =
        std::move(*Literal::CreateR2FromArray2D(expected_values_int)
                       ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
                       .ValueOrDie());

    ComputationBuilder builder(client_, TestName());
    // Initialize and transfer dynamic slice start indices parameter.
    ComputationDataHandle starts;
    std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>(
        slice_starts, 0, "slice_starts", &builder, &starts);
    // Build dynamic update slice computation.
    auto input = builder.ConstantLiteral(input_values);
    auto update = builder.ConstantLiteral(update_values);
    builder.DynamicUpdateSlice(input, update, starts);
    // Run computation and compare against expected values.
    ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()});
  }

  // R3 counterpart of RunR1, using Array3D<int> operand/update/expectation.
  template <typename IndexT, typename DataT>
  void RunR3(const Array3D<int>& input_values_int,
             const Array3D<int>& update_values_int,
             const std::vector<IndexT>& slice_starts,
             const Array3D<int>& expected_values_int) {
    Literal input_values =
        std::move(*Literal::CreateR3FromArray3D(input_values_int)
                       ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
                       .ValueOrDie());
    Literal update_values =
        std::move(*Literal::CreateR3FromArray3D(update_values_int)
                       ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
                       .ValueOrDie());
    Literal expected_values =
        std::move(*Literal::CreateR3FromArray3D(expected_values_int)
                       ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
                       .ValueOrDie());

    ComputationBuilder builder(client_, TestName());
    // Initialize and transfer dynamic slice start indices parameter.
    ComputationDataHandle starts;
    std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>(
        slice_starts, 0, "slice_starts", &builder, &starts);
    // Build dynamic update slice computation.
    auto input = builder.ConstantLiteral(input_values);
    auto update = builder.ConstantLiteral(update_values);
    builder.DynamicUpdateSlice(input, update, starts);
    // Run computation and compare against expected values.
    ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()});
  }

  // Updates `size` planes of the leading dimension, starting at `index`, of an
  // iota-filled operand whose dimensions are operand_shape[0..2] (the vector
  // is expected to hold exactly three entries). The update is iota-filled
  // starting at 10 and keeps the operand's two minor dimensions. Both operand
  // and update are passed as runtime parameters. The expected result is
  // computed on the host with the leading index taken modulo operand_shape[0].
  template <class T>
  void RunR3Contiguous(const std::vector<int32>& operand_shape, int32 index,
                       int32 size) {
#ifdef XLA_TEST_BACKEND_CPU_PARALLEL
    // TODO(b/71820067): The CPU parallel backend failed for this on 2018-01-10.
    if (std::is_same<bfloat16, T>::value) {
      return;
    }
#endif

    const int32 kSeq = operand_shape[0];
    const int32 kBatch = operand_shape[1];
    const int32 kDim = operand_shape[2];
    Array3D<T> input_values(kSeq, kBatch, kDim);
    Array3D<T> update_values(size, kBatch, kDim);
    Array3D<T> expected_values(kSeq, kBatch, kDim);

    input_values.FillIota(static_cast<T>(0));
    T value = static_cast<T>(10);
    update_values.FillIota(value);

    // TODO(b/34128753) Expected values may vary depending on backend when
    // the update wraps. According to documentation, the results are technically
    // implementation specific where the update is out of bounds, and hence
    // we don't really know what to pass into ComputeAndCompareR3.
    expected_values.FillIota(static_cast<T>(0));
    for (int i = 0; i < size; i++) {
      for (int j = 0; j < kBatch; j++) {
        for (int k = 0; k < kDim; k++) {
          expected_values((index + i) % kSeq, j, k) = value++;
        }
      }
    }
    if (VLOG_IS_ON(1)) {
      DumpArray<T>("input", input_values);
      DumpArray<T>("update", update_values);
      DumpArray<T>("expected", expected_values);
    }

    // Build dynamic update slice computation.
    ComputationBuilder builder(client_, TestName());
    // Initialize and transfer input parameter.
    ComputationDataHandle input;
    std::unique_ptr<GlobalData> input_data =
        CreateR3Parameter<T>(input_values, 0, "input_values", &builder, &input);
    // Initialize and transfer update parameter.
    ComputationDataHandle update;
    std::unique_ptr<GlobalData> update_data = CreateR3Parameter<T>(
        update_values, 1, "update_values", &builder, &update);
    auto starts = builder.ConstantR1<int32>({index, 0, 0});
    builder.DynamicUpdateSlice(input, update, starts);

    // Run computation and compare against expected values.
    ComputeAndCompareR3<T>(&builder, expected_values,
                           {input_data.get(), update_data.get()},
                           ErrorSpec(0.000001));
  }

  // Logs `values` (as an R3 literal) at INFO level, prefixed with `name`.
  // Taken by const reference so large operands are not copied just to log.
  template <typename NativeT>
  void DumpArray(const string& name, const Array3D<NativeT>& values) {
    std::unique_ptr<Literal> literal =
        Literal::CreateR3FromArray3D<NativeT>(values);
    LOG(INFO) << name << ":" << literal->ToString();
  }
};

// DynamicUpdateSlice cases instantiated over a spread of start-index types
// (int32, int64, uint64) and element types (bfloat16, float, int32, int64,
// uint64, double).

// Rank 1.
// TODO(b/71820067): The CPU parallel backend failed for this on 2018-01-10.
XLA_TEST_F(DynamicUpdateSliceTest, DISABLED_ON_CPU_PARALLEL(Int32R1BF16)) {
  TestR1<int32, bfloat16>();
}
XLA_TEST_F(DynamicUpdateSliceTest, Int32R1) {
  TestR1<int32, float>();
}
XLA_TEST_F(DynamicUpdateSliceTest, Int64R1) {
  TestR1<int64, float>();
}
XLA_TEST_F(DynamicUpdateSliceTest, UInt64R1) {
  TestR1<uint64, double>();
}

// Rank 2.
// TODO(b/71820067): The CPU parallel backend failed for this on 2018-01-10.
XLA_TEST_F(DynamicUpdateSliceTest, DISABLED_ON_CPU_PARALLEL(Int32R2BF16)) {
  TestR2<int32, bfloat16>();
}
XLA_TEST_F(DynamicUpdateSliceTest, Int32R2) {
  TestR2<int32, float>();
}
XLA_TEST_F(DynamicUpdateSliceTest, Int64R2) {
  TestR2<int64, int64>();
}
XLA_TEST_F(DynamicUpdateSliceTest, UInt64R2) {
  TestR2<uint64, int32>();
}

// Rank 3.
// TODO(b/71820067): The CPU parallel backend failed for this on 2018-01-10.
XLA_TEST_F(DynamicUpdateSliceTest, DISABLED_ON_CPU_PARALLEL(Int32R3BF16)) {
  TestR3<int32, bfloat16>();
}
XLA_TEST_F(DynamicUpdateSliceTest, Int32R3) {
  TestR3<int32, float>();
}
XLA_TEST_F(DynamicUpdateSliceTest, Int64R3) {
  TestR3<int64, int64>();
}
XLA_TEST_F(DynamicUpdateSliceTest, UInt64R3) {
  TestR3<uint64, uint64>();
}

// Wrapping updates across ranks 1-3.
// NOTE(review): disabled on CPU parallel like the other BF16 cases; presumably
// the same b/71820067 failure — TODO confirm.
XLA_TEST_F(DynamicUpdateSliceTest, DISABLED_ON_CPU_PARALLEL(Int32WrapBF16)) {
  TestWrap<int32, bfloat16>();
}
XLA_TEST_F(DynamicUpdateSliceTest, Int32Wrap) {
  TestWrap<int32, float>();
}
XLA_TEST_F(DynamicUpdateSliceTest, Int64Wrap) {
  TestWrap<int64, int64>();
}
XLA_TEST_F(DynamicUpdateSliceTest, UInt64Wrap) {
  TestWrap<uint64, uint64>();
}

XLA_TEST_F(DynamicUpdateSliceTest, Int32R1Pred) {
  // Eight-element predicate operand that every case updates; the values are
  // spelled as bools but stored as ints, which RunR1 converts to PRED.
  const std::vector<int> operand = {false, false, true, true,
                                    false, true,  true, false};

  // Slice at dimension start.
  RunR1<int32, bool>(operand, {true, true, false}, {0},
                     {true, true, false, true, false, true, true, false});
  // Slice in the middle.
  RunR1<int32, bool>(operand, {false, true, true}, {2},
                     {false, false, false, true, true, true, true, false});
  // Slice at dimension boundaries.
  RunR1<int32, bool>(operand, {false, true, true}, {5},
                     {false, false, true, true, false, false, true, true});
  // Zero-sized update: the operand must come back unchanged.
  RunR1<int32, bool>(operand, {}, {2}, operand);
}

XLA_TEST_F(DynamicUpdateSliceTest, Int32R2Pred) {
  // 3x3 predicate operand that every case updates.
  const Array2D<int> operand = {
      {false, true, false}, {true, false, true}, {false, true, true}};

  // Slice at dimension start.
  RunR2<int32, bool>(
      operand, {{true, false}}, {0, 0},
      {{true, false, false}, {true, false, true}, {false, true, true}});
  // Slice in the middle.
  RunR2<int32, bool>(
      operand, {{true, false}}, {1, 1},
      {{false, true, false}, {true, true, false}, {false, true, true}});
  // Slice at dimension boundaries.
  RunR2<int32, bool>(
      operand, {{true, false}}, {2, 1},
      {{false, true, false}, {true, false, true}, {false, true, false}});
  // Zero-sized update: the operand must come back unchanged.
  RunR2<int32, bool>(operand, {{}}, {2, 1}, operand);
}

XLA_TEST_F(DynamicUpdateSliceTest, Int32R3Pred) {
  // R3 Shape: [2, 3, 2]
  // Slice at dimension start.
  RunR3<int32, bool>(
      {{{true, false}, {false, true}, {true, true}},
       {{false, false}, {false, true}, {true, false}}},
      {{{false, true}, {true, false}}, {{true, true}, {false, true}}},
      {0, 0, 0},
      {{{false, true}, {true, false}, {true, true}},
       {{true, true}, {false, true}, {true, false}}});
  // Slice in the middle.
  RunR3<int32, bool>({{{true, false}, {false, true}, {true, true}},
                      {{false, false}, {false, true}, {true, false}}},
                     {{{false}, {true}}}, {1, 1, 1},
                     {{{true, false}, {false, true}, {true, true}},
                      {{false, false}, {false, false}, {true, true}}});
  // Slice at dimension boundaries: the 1x2x2 update ends on the last index of
  // every dimension, so no index wraps.
  RunR3<int32, bool>({{{true, false}, {false, true}, {true, true}},
                      {{false, false}, {false, true}, {true, false}}},
                     {{{true, false}, {false, true}}}, {1, 1, 0},
                     {{{true, false}, {false, true}, {true, true}},
                      {{false, false}, {true, false}, {false, true}}});
}

// R3 DynamicUpdateSlice cases where the update is contiguous: only the leading
// dimension is sliced, and the two minor dimensions are written whole. Every
// case runs once with float elements and once with bfloat16 elements.
XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousSingleElement) {
  // One plane written at index 1; nothing wraps.
  const std::vector<int32> operand_shape = {4, 5, 2};
  RunR3Contiguous<float>(operand_shape, /*index=*/1, /*size=*/1);
  RunR3Contiguous<bfloat16>(operand_shape, /*index=*/1, /*size=*/1);
}

XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleElements) {
  // Two planes written starting at index 1; nothing wraps.
  const std::vector<int32> operand_shape = {4, 5, 2};
  RunR3Contiguous<float>(operand_shape, /*index=*/1, /*size=*/2);
  RunR3Contiguous<bfloat16>(operand_shape, /*index=*/1, /*size=*/2);
}

XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleWrapping) {
  // Two planes starting at the last index (3 of 4), so the update wraps.
  const std::vector<int32> operand_shape = {4, 5, 2};
  RunR3Contiguous<float>(operand_shape, /*index=*/3, /*size=*/2);
  RunR3Contiguous<bfloat16>(operand_shape, /*index=*/3, /*size=*/2);
}

XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousTooLarge) {
  // Two planes whose start index (5) lies past the end of the operand's
  // leading dimension (4).
  const std::vector<int32> operand_shape = {4, 5, 2};
  RunR3Contiguous<float>(operand_shape, /*index=*/5, /*size=*/2);
  RunR3Contiguous<bfloat16>(operand_shape, /*index=*/5, /*size=*/2);
}

XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousUnaligned) {
  // Odd-sized minor dimensions (123 x 247).
  const std::vector<int32> operand_shape = {3, 123, 247};
  RunR3Contiguous<float>(operand_shape, /*index=*/1, /*size=*/1);
  RunR3Contiguous<bfloat16>(operand_shape, /*index=*/1, /*size=*/1);
}

// TODO(b/34134076) Disabled on GPU 2016-01-06 due to out-of-memory error.
XLA_TEST_F(DynamicUpdateSliceTest, DISABLED_ON_GPU(R3ContiguousLarger)) {
  // Large 32 x 128 x 1024 operand; one plane written at index 7.
  const std::vector<int32> operand_shape = {32, 128, 1024};
  RunR3Contiguous<float>(operand_shape, /*index=*/7, /*size=*/1);
  RunR3Contiguous<bfloat16>(operand_shape, /*index=*/7, /*size=*/1);
}

// Benchmarks execution of a compiled computation containing a single
// DynamicSlice of a constant R4 operand, with the start indices supplied as a
// device-resident parameter. Client setup, compilation, the parameter
// transfer and two warm-up runs all happen before timing is restarted.
void BM_DynamicSlice(int num_iters) {
  tensorflow::testing::StopTiming();

  se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie();
  auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie();
  StreamExecutorMemoryAllocator allocator(platform, executors);
  LocalClient* client =
      ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie();
  auto* transfer_manager =
      TransferManager::GetForPlatform(platform).ValueOrDie();
  int device_ordinal = client->default_device_ordinal();

  ComputationBuilder builder(client, "DynamicSlice");

  // Create input as a constant: shape [1, 2, 3, 4]
  auto input_literal = Literal::CreateR4(
      {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
        {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}});
  auto input = builder.ConstantLiteral(*input_literal);

  // Create dynamic slice start indices as a parameter: shape [4]
  auto start_indices_shape = ShapeUtil::MakeShape(S32, {4});
  auto start_indices =
      builder.Parameter(0, start_indices_shape, "start_indices");
  // Add DynamicSlice op to the computation.
  builder.DynamicSlice(input, start_indices, {1, 1, 1, 1});
  auto computation = builder.Build().ConsumeValueOrDie();

  // Initialize and transfer parameter buffer. The buffer is filled through
  // executors[device_ordinal] below, so it must be allocated on that same
  // device rather than on a hard-coded device 0.
  auto buffer = client->backend()
                    .transfer_manager()
                    ->AllocateScopedShapedBuffer(start_indices_shape,
                                                 &allocator, device_ordinal)
                    .ConsumeValueOrDie();

  auto start_indices_literal = Literal::CreateR1<int32>({0, 1, 2, 3});
  ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice(
      executors[device_ordinal], *start_indices_literal, *buffer));

  std::unique_ptr<LocalExecutable> executable =
      client
          ->Compile(computation, {&buffer->on_host_shape()},
                    ExecutableBuildOptions())
          .ConsumeValueOrDie();

  // Run some warm-up executions.
  ExecutableRunOptions options;
  options.set_allocator(&allocator);
  const int kWarmups = 2;
  for (int i = 0; i < kWarmups; ++i) {
    auto result = executable->Run({buffer.get()}, options);
    ASSERT_TRUE(result.ok());
  }

  // Run benchmark.
  tensorflow::testing::StartTiming();
  for (int i = 0; i < num_iters; ++i) {
    auto result = executable->Run({buffer.get()}, options);
    ASSERT_TRUE(result.ok());
  }
}
BENCHMARK(BM_DynamicSlice);

}  // namespace
}  // namespace xla