aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/conditional_test.cc
blob: ee3c83039bfc13f6ad78111d92ba0f8387a3ade3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"

namespace xla {
namespace {

class ConditionalOpTest : public ClientLibraryTestBase {
 protected:
  // Finalizes `builder` into an XlaComputation. A failing build status is
  // recorded as a non-fatal test failure before the value is consumed.
  XlaComputation BuildChecked(XlaBuilder* builder) {
    auto computation_or = builder->Build();
    EXPECT_IS_OK(computation_or.status());
    return computation_or.ConsumeValueOrDie();
  }

  // Computation that accepts an empty tuple and yields the F32 scalar `value`.
  XlaComputation CreateR0ConstantComputation(float value) {
    XlaBuilder b("Constant");
    Parameter(&b, 0, empty_tuple_, "tuple");
    ConstantR0<float>(&b, value);
    return BuildChecked(&b);
  }

  // Computation that returns its F32 scalar parameter unchanged.
  XlaComputation CreateR0IdentityComputation() {
    XlaBuilder b("Identity");
    Parameter(&b, 0, r0f32_, "x");
    return BuildChecked(&b);
  }

  // Computation applying elementwise Ceil to a single parameter of `shape`.
  XlaComputation CreateCeilComputation(const Shape& shape) {
    XlaBuilder b("Ceil");
    Ceil(Parameter(&b, 0, shape, "param"));
    return BuildChecked(&b);
  }

  XlaComputation CreateR0CeilComputation() {
    return CreateCeilComputation(r0f32_);
  }

  XlaComputation CreateR1CeilComputation() {
    return CreateCeilComputation(r1s2f32_);
  }

  // Computation applying elementwise Floor to a single parameter of `shape`.
  XlaComputation CreateFloorComputation(const Shape& shape) {
    XlaBuilder b("Floor");
    Floor(Parameter(&b, 0, shape, "param"));
    return BuildChecked(&b);
  }

  XlaComputation CreateR0FloorComputation() {
    return CreateFloorComputation(r0f32_);
  }

  XlaComputation CreateR1FloorComputation() {
    return CreateFloorComputation(r1s2f32_);
  }

  // Computation mapping a 2-tuple (x, y) of `tuple_shape` to
  // (ceil(x), ceil(y)).
  XlaComputation CreateTupleCeilComputation(const string& computation_name,
                                            const Shape& tuple_shape) {
    XlaBuilder b(computation_name);
    XlaOp input = Parameter(&b, 0, tuple_shape, "tuple");
    XlaOp lhs = GetTupleElement(input, 0);
    XlaOp rhs = GetTupleElement(input, 1);
    // Braced-init-list elements are evaluated left to right, so the Ceil ops
    // are created in the same order as with named temporaries.
    Tuple(&b, {Ceil(lhs), Ceil(rhs)});
    return BuildChecked(&b);
  }

  XlaComputation CreateR0TupleCeilComputation() {
    return CreateTupleCeilComputation("CeilR0", tuple_2_r0f32_);
  }

  XlaComputation CreateR1TupleCeilComputation() {
    return CreateTupleCeilComputation("CeilR1", tuple_2_r1s2f32_);
  }

  // Computation mapping a 2-tuple (x, y) of `tuple_shape` to
  // (floor(x), floor(y)).
  XlaComputation CreateTupleFloorComputation(const string& computation_name,
                                             const Shape& tuple_shape) {
    XlaBuilder b(computation_name);
    XlaOp input = Parameter(&b, 0, tuple_shape, "tuple");
    XlaOp lhs = GetTupleElement(input, 0);
    XlaOp rhs = GetTupleElement(input, 1);
    Tuple(&b, {Floor(lhs), Floor(rhs)});
    return BuildChecked(&b);
  }

  XlaComputation CreateR0TupleFloorComputation() {
    return CreateTupleFloorComputation("FloorR0", tuple_2_r0f32_);
  }

  XlaComputation CreateR1TupleFloorComputation() {
    return CreateTupleFloorComputation("FloorR1", tuple_2_r1s2f32_);
  }

  // Computation mapping a 2-tuple (x, y) of `tuple_shape` to x + y.
  XlaComputation CreateTupleAddComputation(const string& computation_name,
                                           const Shape& tuple_shape) {
    XlaBuilder b(computation_name);
    XlaOp input = Parameter(&b, 0, tuple_shape, "tuple");
    XlaOp lhs = GetTupleElement(input, 0);
    XlaOp rhs = GetTupleElement(input, 1);
    Add(lhs, rhs);
    return BuildChecked(&b);
  }

  XlaComputation CreateR0TupleAddComputation() {
    return CreateTupleAddComputation("AddR0", tuple_2_r0f32_);
  }

  XlaComputation CreateR1TupleAddComputation() {
    return CreateTupleAddComputation("AddR1", tuple_2_r1s2f32_);
  }

  // Computation mapping a 2-tuple (x, y) of `tuple_shape` to x - y.
  XlaComputation CreateTupleSubComputation(const string& computation_name,
                                           const Shape& tuple_shape) {
    XlaBuilder b(computation_name);
    XlaOp input = Parameter(&b, 0, tuple_shape, "tuple");
    XlaOp lhs = GetTupleElement(input, 0);
    XlaOp rhs = GetTupleElement(input, 1);
    Sub(lhs, rhs);
    return BuildChecked(&b);
  }

  XlaComputation CreateR0TupleSubComputation() {
    return CreateTupleSubComputation("SubR0", tuple_2_r0f32_);
  }

  XlaComputation CreateR1TupleSubComputation() {
    return CreateTupleSubComputation("SubR1", tuple_2_r1s2f32_);
  }

  // Shapes shared by the tests. Non-static members are initialized in
  // declaration order, so the tuple shapes may reuse the array shapes above.
  Shape r0f32_ = ShapeUtil::MakeShape(F32, {});
  Shape r1s2f32_ = ShapeUtil::MakeShape(F32, {2});
  Shape tuple_2_r0f32_ = ShapeUtil::MakeTupleShape({r0f32_, r0f32_});
  Shape tuple_2_r1s2f32_ = ShapeUtil::MakeTupleShape({r1s2f32_, r1s2f32_});
  Shape empty_tuple_ = ShapeUtil::MakeTupleShape({});
  ErrorSpec error_spec_{0.001};
};

// Both branches ignore their empty-tuple operand and return a constant; with a
// true predicate the result must come from the true branch (56).
XLA_TEST_F(ConditionalOpTest, Parameters0) {
  XlaComputation on_true = CreateR0ConstantComputation(56.0f);
  XlaComputation on_false = CreateR0ConstantComputation(12.0f);

  XlaBuilder builder(TestName());
  XlaOp predicate = ConstantR0<bool>(&builder, true);
  XlaOp no_args = Tuple(&builder, {});
  Conditional(predicate, no_args, on_true, no_args, on_false);

  const float expected = 56.0f;
  ComputeAndCompareR0<float>(&builder, expected, {}, error_spec_);
}

// Each branch is the identity on a single F32 scalar; the false predicate must
// pass the second operand (12) through unchanged.
XLA_TEST_F(ConditionalOpTest, Parameters1) {
  XlaBuilder builder(TestName());
  XlaOp predicate = ConstantR0<bool>(&builder, false);
  XlaOp true_arg = ConstantR0<float>(&builder, 56.0f);
  XlaOp false_arg = ConstantR0<float>(&builder, 12.0f);
  XlaComputation passthrough = CreateR0IdentityComputation();
  Conditional(predicate, true_arg, passthrough, false_arg, passthrough);

  const float expected = 12.0f;
  ComputeAndCompareR0<float>(&builder, expected, {}, error_spec_);
}

// Ceil on the true side and Floor on the false side, each fed its own operand;
// the false predicate yields floor(12.6) == 12.
XLA_TEST_F(ConditionalOpTest, DiffComputationsDiffArgs) {
  XlaComputation ceil_branch = CreateR0CeilComputation();
  XlaComputation floor_branch = CreateR0FloorComputation();

  XlaBuilder builder(TestName());
  XlaOp predicate = ConstantR0<bool>(&builder, false);
  XlaOp true_arg = ConstantR0<float>(&builder, 56.4f);
  XlaOp false_arg = ConstantR0<float>(&builder, 12.6f);
  Conditional(predicate, true_arg, ceil_branch, false_arg, floor_branch);

  ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
}

// A single operand feeds both a Ceil true branch and a Floor false branch; the
// false predicate yields floor(12.6) == 12.
XLA_TEST_F(ConditionalOpTest, DiffComputationsSameArg) {
  XlaComputation ceil_branch = CreateR0CeilComputation();
  XlaComputation floor_branch = CreateR0FloorComputation();

  XlaBuilder builder(TestName());
  XlaOp predicate = ConstantR0<bool>(&builder, false);
  XlaOp shared_arg = ConstantR0<float>(&builder, 12.6f);
  Conditional(predicate, shared_arg, ceil_branch, shared_arg, floor_branch);

  ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
}

// One Floor computation object serves as both branches, but the branches take
// different operands; the false predicate yields floor(12.6) == 12.
XLA_TEST_F(ConditionalOpTest, SameComputationDiffArgs) {
  XlaBuilder builder(TestName());
  XlaOp predicate = ConstantR0<bool>(&builder, false);
  XlaOp true_arg = ConstantR0<float>(&builder, 56.4f);
  XlaOp false_arg = ConstantR0<float>(&builder, 12.6f);
  XlaComputation floor_branch = CreateR0FloorComputation();
  Conditional(predicate, true_arg, floor_branch, false_arg, floor_branch);

  ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
}

// Both the branch computation and the operand are shared between the two sides
// of the conditional; the result is floor(12.6) == 12.
XLA_TEST_F(ConditionalOpTest, SameComputationSameArg) {
  XlaBuilder builder(TestName());
  XlaOp predicate = ConstantR0<bool>(&builder, false);
  XlaOp shared_arg = ConstantR0<float>(&builder, 12.6f);
  XlaComputation floor_branch = CreateR0FloorComputation();
  Conditional(predicate, shared_arg, floor_branch, shared_arg, floor_branch);

  ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
}

// Each branch gets its own, separately built Floor computation (equal content,
// distinct objects); the result is floor(12.6) == 12.
XLA_TEST_F(ConditionalOpTest, SameComputationDiffInstances) {
  XlaComputation floor_for_true = CreateR0FloorComputation();
  XlaComputation floor_for_false = CreateR0FloorComputation();

  XlaBuilder builder(TestName());
  XlaOp predicate = ConstantR0<bool>(&builder, false);
  XlaOp true_arg = ConstantR0<float>(&builder, 56.4f);
  XlaOp false_arg = ConstantR0<float>(&builder, 12.6f);
  Conditional(predicate, true_arg, floor_for_true, false_arg, floor_for_false);

  ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
}

// Test the case when a call invokes a computation that contains a conditional.
// The inner computation maps (pred, x, y) to pred ? ceil(x) : floor(y); calling
// it with (false, 56.4, 12.6) must yield floor(12.6) == 12.
XLA_TEST_F(ConditionalOpTest, ConditionalWithCall) {
  Shape r0bool = ShapeUtil::MakeShape(PRED, {});
  XlaBuilder inner_builder(TestName() + ".inner_conditional");
  auto pred_cond = Parameter(&inner_builder, 0, r0bool, "param0");
  auto true_operand = Parameter(&inner_builder, 1, r0f32_, "param1");
  auto false_operand = Parameter(&inner_builder, 2, r0f32_, "param2");
  Conditional(pred_cond, true_operand, CreateR0CeilComputation(), false_operand,
              CreateR0FloorComputation());
  auto inner_builder_result = inner_builder.Build();
  // Report a failed inner build as a test failure, as the other tests in this
  // file do, before ConsumeValueOrDie() is reached below.
  EXPECT_IS_OK(inner_builder_result.status());

  XlaBuilder builder(TestName());
  auto pred = ConstantR0<bool>(&builder, false);
  auto operand1 = ConstantR0<float>(&builder, 56.4f);
  auto operand2 = ConstantR0<float>(&builder, 12.6f);
  Call(&builder, inner_builder_result.ConsumeValueOrDie(),
       {pred, operand1, operand2});

  ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
}

// Both branches unpack a tuple of two F32 scalars. With a true predicate the
// Add branch runs: 56 + 12 == 68.
XLA_TEST_F(ConditionalOpTest, Parameters2TrueBranch) {
  XlaComputation add_branch = CreateR0TupleAddComputation();
  XlaComputation sub_branch = CreateR0TupleSubComputation();

  XlaBuilder builder(TestName());
  XlaOp predicate = ConstantR0<bool>(&builder, true);
  XlaOp lhs = ConstantR0<float>(&builder, 56.0f);
  XlaOp rhs = ConstantR0<float>(&builder, 12.0f);
  XlaOp packed = Tuple(&builder, {lhs, rhs});
  Conditional(predicate, packed, add_branch, packed, sub_branch);

  ComputeAndCompareR0<float>(&builder, 68.0f, {}, error_spec_);
}

// Both branches unpack a tuple of two F32 scalars. With a false predicate the
// Sub branch runs: 56 - 12 == 44.
XLA_TEST_F(ConditionalOpTest, Parameters2FalseBranch) {
  XlaComputation add_branch = CreateR0TupleAddComputation();
  XlaComputation sub_branch = CreateR0TupleSubComputation();

  XlaBuilder builder(TestName());
  XlaOp predicate = ConstantR0<bool>(&builder, false);
  XlaOp lhs = ConstantR0<float>(&builder, 56.0f);
  XlaOp rhs = ConstantR0<float>(&builder, 12.0f);
  XlaOp packed = Tuple(&builder, {lhs, rhs});
  Conditional(predicate, packed, add_branch, packed, sub_branch);

  ComputeAndCompareR0<float>(&builder, 44.0f, {}, error_spec_);
}

// Both branches unpack a tuple of two F32[2] arrays. With a true predicate the
// elementwise Add branch runs: {24, 56} + {10, 11} == {34, 67}.
XLA_TEST_F(ConditionalOpTest, Parameters2ArrayTrueBranch) {
  XlaComputation add_branch = CreateR1TupleAddComputation();
  XlaComputation sub_branch = CreateR1TupleSubComputation();

  XlaBuilder builder(TestName());
  XlaOp predicate = ConstantR0<bool>(&builder, true);
  XlaOp lhs = ConstantR1<float>(&builder, {24.0f, 56.0f});
  XlaOp rhs = ConstantR1<float>(&builder, {10.0f, 11.0f});
  XlaOp packed = Tuple(&builder, {lhs, rhs});
  Conditional(predicate, packed, add_branch, packed, sub_branch);

  ComputeAndCompareR1<float>(&builder, {34.0f, 67.0f}, {}, error_spec_);
}

// Both branches unpack a tuple of two F32[2] arrays. With a false predicate the
// elementwise Sub branch runs: {24, 56} - {10, 11} == {14, 45}.
XLA_TEST_F(ConditionalOpTest, Parameters2ArrayFalseBranch) {
  XlaComputation add_branch = CreateR1TupleAddComputation();
  XlaComputation sub_branch = CreateR1TupleSubComputation();

  XlaBuilder builder(TestName());
  XlaOp predicate = ConstantR0<bool>(&builder, false);
  XlaOp lhs = ConstantR1<float>(&builder, {24.0f, 56.0f});
  XlaOp rhs = ConstantR1<float>(&builder, {10.0f, 11.0f});
  XlaOp packed = Tuple(&builder, {lhs, rhs});
  Conditional(predicate, packed, add_branch, packed, sub_branch);

  ComputeAndCompareR1<float>(&builder, {14.0f, 45.0f}, {}, error_spec_);
}

// Each branch returns a two-element tuple of F32 scalars. The false predicate
// selects the elementwise floor of (12.2, 25.6), i.e. (12, 25).
XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) {
  XlaBuilder builder(TestName());
  XlaOp predicate = ConstantR0<bool>(&builder, false);
  XlaOp first = ConstantR0<float>(&builder, 12.2f);
  XlaOp second = ConstantR0<float>(&builder, 25.6f);
  XlaOp pair = Tuple(&builder, {first, second});
  Conditional(predicate, pair, CreateR0TupleCeilComputation(), pair,
              CreateR0TupleFloorComputation());

  auto floor_first = Literal::CreateR0<float>(12.0f);
  auto floor_second = Literal::CreateR0<float>(25.0f);
  auto expected = Literal::MakeTuple({floor_first.get(), floor_second.get()});
  ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
}

// Each branch returns a two-element tuple of F32[2] arrays. The true predicate
// selects the elementwise ceil: {12.2, 15.8} -> {13, 16} and
// {25.6, 29.2} -> {26, 30}.
XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) {
  XlaBuilder builder(TestName());
  XlaOp predicate = ConstantR0<bool>(&builder, true);
  XlaOp first = ConstantR1<float>(&builder, {12.2f, 15.8f});
  XlaOp second = ConstantR1<float>(&builder, {25.6f, 29.2f});
  XlaOp pair = Tuple(&builder, {first, second});
  Conditional(predicate, pair, CreateR1TupleCeilComputation(), pair,
              CreateR1TupleFloorComputation());

  auto ceil_first = Literal::CreateR1<float>({13.0f, 16.0f});
  auto ceil_second = Literal::CreateR1<float>({26.0f, 30.0f});
  auto expected = Literal::MakeTuple({ceil_first.get(), ceil_second.get()});
  ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
}

// Branches take an empty tuple and return a heterogeneous tuple
// (PRED scalar, F32 scalar, F32[2]). The true predicate must select the
// true branch's constants.
XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) {
  XlaBuilder on_true_builder(TestName() + ".true");
  Parameter(&on_true_builder, 0, empty_tuple_, "tuple");
  // Braced-init-list elements are evaluated left to right, so the constants
  // are created in (pred, scalar, array) order.
  Tuple(&on_true_builder,
        {ConstantR0<bool>(&on_true_builder, true),
         ConstantR0<float>(&on_true_builder, 12.2f),
         ConstantR1<float>(&on_true_builder, {12.8f, 14.6f})});
  auto on_true = on_true_builder.Build();
  EXPECT_IS_OK(on_true.status());

  XlaBuilder on_false_builder(TestName() + ".false");
  Parameter(&on_false_builder, 0, empty_tuple_, "tuple");
  Tuple(&on_false_builder,
        {ConstantR0<bool>(&on_false_builder, false),
         ConstantR0<float>(&on_false_builder, 25.6f),
         ConstantR1<float>(&on_false_builder, {26.4f, 32.6f})});
  auto on_false = on_false_builder.Build();
  EXPECT_IS_OK(on_false.status());

  XlaBuilder builder(TestName());
  XlaOp predicate = ConstantR0<bool>(&builder, true);
  XlaOp no_args = Tuple(&builder, {});
  Conditional(predicate, no_args, on_true.ConsumeValueOrDie(), no_args,
              on_false.ConsumeValueOrDie());

  auto expected_pred = Literal::CreateR0<bool>(true);
  auto expected_scalar = Literal::CreateR0<float>(12.2f);
  auto expected_array = Literal::CreateR1<float>({12.8f, 14.6f});
  auto expected = Literal::MakeTuple(
      {expected_pred.get(), expected_scalar.get(), expected_array.get()});
  ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
}

// Branches take an empty tuple and return the nested tuple
// ((F32, F32[2]), (F32[2], F32)). The false predicate must select the false
// branch's constants.
XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) {
  XlaBuilder on_true_builder(TestName() + ".true");
  {
    Parameter(&on_true_builder, 0, empty_tuple_, "tuple");
    XlaOp s0 = ConstantR0<float>(&on_true_builder, 12.2f);
    XlaOp a0 = ConstantR1<float>(&on_true_builder, {12.8f, 14.6f});
    XlaOp a1 = ConstantR1<float>(&on_true_builder, {25.4f, 29.8f});
    XlaOp s1 = ConstantR0<float>(&on_true_builder, 35.6f);
    XlaOp left = Tuple(&on_true_builder, {s0, a0});
    XlaOp right = Tuple(&on_true_builder, {a1, s1});
    Tuple(&on_true_builder, {left, right});
  }
  auto on_true = on_true_builder.Build();
  EXPECT_IS_OK(on_true.status());

  XlaBuilder on_false_builder(TestName() + ".false");
  {
    Parameter(&on_false_builder, 0, empty_tuple_, "tuple");
    XlaOp s0 = ConstantR0<float>(&on_false_builder, 46.6f);
    XlaOp a0 = ConstantR1<float>(&on_false_builder, {54.4f, 58.4f});
    XlaOp a1 = ConstantR1<float>(&on_false_builder, {62.1f, 67.4f});
    XlaOp s1 = ConstantR0<float>(&on_false_builder, 9.3f);
    XlaOp left = Tuple(&on_false_builder, {s0, a0});
    XlaOp right = Tuple(&on_false_builder, {a1, s1});
    Tuple(&on_false_builder, {left, right});
  }
  auto on_false = on_false_builder.Build();
  EXPECT_IS_OK(on_false.status());

  XlaBuilder builder(TestName());
  XlaOp predicate = ConstantR0<bool>(&builder, false);
  XlaOp no_args = Tuple(&builder, {});
  Conditional(predicate, no_args, on_true.ConsumeValueOrDie(), no_args,
              on_false.ConsumeValueOrDie());

  auto left_scalar = Literal::CreateR0<float>(46.6f);
  auto left_array = Literal::CreateR1<float>({54.4f, 58.4f});
  auto right_array = Literal::CreateR1<float>({62.1f, 67.4f});
  auto right_scalar = Literal::CreateR0<float>(9.3f);
  auto left = Literal::MakeTuple({left_scalar.get(), left_array.get()});
  auto right = Literal::MakeTuple({right_array.get(), right_scalar.get()});
  auto expected = Literal::MakeTuple({left.get(), right.get()});
  ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
}

// Test conditional that takes in scalar operands in the form of external
// params: the predicate and both operands are runtime arguments rather than
// constants. With pred == true the result is ceil(56.3) == 57.
XLA_TEST_F(ConditionalOpTest, ScalarOperandsFromExternalParams) {
  XlaBuilder builder(TestName());

  XlaOp pred;
  XlaOp operand1;
  XlaOp operand2;
  auto pred_arg = CreateR0Parameter<bool>(true, 0, "pred", &builder, &pred);
  auto operand1_param =
      CreateR0Parameter<float>(56.3f, 1, "operand1", &builder, &operand1);
  auto operand2_param =
      CreateR0Parameter<float>(12.7f, 2, "operand2", &builder, &operand2);
  Conditional(pred, operand1, CreateR0CeilComputation(), operand2,
              CreateR0FloorComputation());

  ComputeAndCompareR0<float>(
      &builder, 57.0f,
      {pred_arg.get(), operand1_param.get(), operand2_param.get()},
      error_spec_);
}

// Test conditional that takes in array operands in the form of external params:
// the predicate and both F32[2] operands are runtime arguments. With
// pred == false the result is floor({10.2, 11.6}) == {10, 11}.
XLA_TEST_F(ConditionalOpTest, ArrayOperandsFromExternalParams) {
  XlaBuilder builder(TestName());

  XlaOp pred;
  XlaOp operand1;
  XlaOp operand2;
  auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
  auto operand1_param = CreateR1Parameter<float>({24.3f, 56.7f}, 1, "operand1",
                                                 &builder, &operand1);
  auto operand2_param = CreateR1Parameter<float>({10.2f, 11.6f}, 2, "operand2",
                                                 &builder, &operand2);
  Conditional(pred, operand1, CreateR1CeilComputation(), operand2,
              CreateR1FloorComputation());

  ComputeAndCompareR1<float>(
      &builder, {10.0f, 11.0f},
      {pred_arg.get(), operand1_param.get(), operand2_param.get()},
      error_spec_);
}

// The outer conditional's true branch is itself a computation that contains a
// conditional over a packed (pred, x, y) tuple. The outer predicate is true and
// the inner one is false, so the result is floor(12.2) == 12.
XLA_TEST_F(ConditionalOpTest, NestedConditionals) {
  Shape r0bool = ShapeUtil::MakeShape(PRED, {});
  Shape packed_shape = ShapeUtil::MakeTupleShape({r0bool, r0f32_, r0f32_});

  XlaBuilder inner_builder(TestName() + ".inner_conditional");
  XlaOp packed = Parameter(&inner_builder, 0, packed_shape, "param0");
  XlaOp inner_pred = GetTupleElement(packed, 0);
  XlaOp inner_true_arg = GetTupleElement(packed, 1);
  XlaOp inner_false_arg = GetTupleElement(packed, 2);
  Conditional(inner_pred, inner_true_arg, CreateR0CeilComputation(),
              inner_false_arg, CreateR0FloorComputation());
  auto inner = inner_builder.Build();
  EXPECT_IS_OK(inner.status());

  XlaBuilder builder(TestName());
  XlaOp outer_pred = ConstantR0<bool>(&builder, true);
  XlaOp inner_pred_value = ConstantR0<bool>(&builder, false);
  XlaOp inner_true_value = ConstantR0<float>(&builder, 1.1f);
  XlaOp inner_false_value = ConstantR0<float>(&builder, 12.2f);
  XlaOp outer_false_arg = ConstantR0<float>(&builder, 43.3f);
  XlaOp inner_args =
      Tuple(&builder, {inner_pred_value, inner_true_value, inner_false_value});
  Conditional(outer_pred, inner_args, inner.ConsumeValueOrDie(),
              outer_false_arg, CreateR0IdentityComputation());

  ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
}

// A conditional lives inside a computation reached through Call. The packed
// argument (false, 1.1, 12.2) selects the Floor branch, giving 12.
XLA_TEST_F(ConditionalOpTest, ConditionalInNestedComputation) {
  Shape r0bool = ShapeUtil::MakeShape(PRED, {});
  Shape packed_shape = ShapeUtil::MakeTupleShape({r0bool, r0f32_, r0f32_});

  XlaBuilder inner_builder(TestName() + ".inner_conditional");
  XlaOp packed = Parameter(&inner_builder, 0, packed_shape, "param0");
  XlaOp inner_pred = GetTupleElement(packed, 0);
  XlaOp inner_true_arg = GetTupleElement(packed, 1);
  XlaOp inner_false_arg = GetTupleElement(packed, 2);
  Conditional(inner_pred, inner_true_arg, CreateR0CeilComputation(),
              inner_false_arg, CreateR0FloorComputation());
  auto inner = inner_builder.Build();
  EXPECT_IS_OK(inner.status());

  XlaBuilder builder(TestName());
  XlaOp pred_value = ConstantR0<bool>(&builder, false);
  XlaOp true_value = ConstantR0<float>(&builder, 1.1f);
  XlaOp false_value = ConstantR0<float>(&builder, 12.2f);
  XlaOp call_arg = Tuple(&builder, {pred_value, true_value, false_value});
  Call(&builder, inner.ConsumeValueOrDie(), {call_arg});

  ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
}

// The true operand is a tuple of two F32 scalars, but the true computation
// expects a tuple of two F32[2] arrays. Building must fail with a shape error
// that names true_operand and true_computation.
XLA_TEST_F(ConditionalOpTest, ShapeMismatch) {
  XlaBuilder builder(TestName());
  XlaOp predicate = ConstantR0<bool>(&builder, true);
  XlaOp lhs = ConstantR0<float>(&builder, 56.0f);
  XlaOp rhs = ConstantR0<float>(&builder, 12.0f);
  XlaOp scalar_pair = Tuple(&builder, {lhs, rhs});
  XlaComputation r1_add = CreateR1TupleAddComputation();
  XlaComputation r0_sub = CreateR0TupleSubComputation();
  Conditional(predicate, scalar_pair, r1_add, scalar_pair, r0_sub);

  auto build_result = builder.Build();
  EXPECT_FALSE(build_result.ok());
  EXPECT_THAT(build_result.status().error_message(),
              ::testing::HasSubstr("true_operand must match the shape of the "
                                   "only parameter of true_computation"));
}

// Two conditionals run back to back. The first orders the pair as (min, max):
// it forwards when x < y and swaps otherwise. The second swaps again exactly
// when x >= y. Together they map (a, b) back to (a, b) whichever element is
// larger.
XLA_TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) {
  Shape pair_shape = ShapeUtil::MakeTupleShape({r0f32_, r0f32_});

  // Builds a pair -> pair computation that either reverses or preserves the
  // element order.
  auto build_permutation = [&](const string& suffix, const string& param_name,
                               bool reverse) {
    XlaBuilder b(TestName() + suffix);
    XlaOp input = Parameter(&b, 0, pair_shape, param_name);
    XlaOp first = GetTupleElement(input, 0);
    XlaOp second = GetTupleElement(input, 1);
    if (reverse) {
      Tuple(&b, {second, first});
    } else {
      Tuple(&b, {first, second});
    }
    return b.Build().ConsumeValueOrDie();
  };

  XlaComputation swapper = build_permutation(".swapper", "sp0", true);
  XlaComputation forwarder = build_permutation(".forwarder", "fp0", false);

  XlaComputation main;
  {
    XlaBuilder b(TestName() + ".main");
    XlaOp input = Parameter(&b, 0, pair_shape, "mp0");
    XlaOp first = GetTupleElement(input, 0);
    XlaOp second = GetTupleElement(input, 1);
    XlaOp ascending = Lt(first, second);
    XlaOp ordered = Conditional(ascending, input, forwarder, input, swapper);
    XlaOp not_ascending = Ge(first, second);
    Conditional(not_ascending, ordered, swapper, ordered, forwarder);
    main = b.Build().ConsumeValueOrDie();
  }

  // Runs `main` on (a, b) and expects the original pair back.
  auto expect_round_trip = [&](float a, float b) {
    XlaBuilder builder(TestName());
    XlaOp lhs = ConstantR0<float>(&builder, a);
    XlaOp rhs = ConstantR0<float>(&builder, b);
    XlaOp pair = Tuple(&builder, {lhs, rhs});
    Call(&builder, main, {pair});

    auto expected_a = Literal::CreateR0<float>(a);
    auto expected_b = Literal::CreateR0<float>(b);
    auto expected = Literal::MakeTuple({expected_a.get(), expected_b.get()});
    ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
  };

  expect_round_trip(3.11f, 9.4f);
  expect_round_trip(11.24f, 5.55f);
}

}  // namespace
}  // namespace xla