aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-01-08 16:38:01 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-08 16:41:51 -0800
commit8398ece2cf59c6a4596b3bdf79d7e8553b953afa (patch)
treefe74d7f3195b8bcca26fd99420369d7749d707f4
parentfffcacb1d0537a353c4e9303f98d1e9b2e83bf23 (diff)
[XLA] Add tests for conditionals that return tuples and whose predicate operands are external parameters.
PiperOrigin-RevId: 181237427
-rw-r--r--tensorflow/compiler/xla/tests/conditional_test.cc334
1 files changed, 281 insertions, 53 deletions
diff --git a/tensorflow/compiler/xla/tests/conditional_test.cc b/tensorflow/compiler/xla/tests/conditional_test.cc
index c8c4932be8..0016b6cc61 100644
--- a/tensorflow/compiler/xla/tests/conditional_test.cc
+++ b/tensorflow/compiler/xla/tests/conditional_test.cc
@@ -23,7 +23,7 @@ namespace {
class ConditionalOpTest : public ClientLibraryTestBase {
protected:
- Computation CreateR0F32ConstantComputation(float value) {
+ Computation CreateR0ConstantComputation(float value) {
ComputationBuilder builder(client_, "Constant");
builder.Parameter(0, empty_tuple_, "tuple");
builder.ConstantR0<float>(value);
@@ -32,7 +32,7 @@ class ConditionalOpTest : public ClientLibraryTestBase {
return build_status.ConsumeValueOrDie();
}
- Computation CreateR0F32IdentityComputation() {
+ Computation CreateR0IdentityComputation() {
ComputationBuilder builder(client_, "Identity");
builder.Parameter(0, r0f32_, "x");
auto build_status = builder.Build();
@@ -40,25 +40,85 @@ class ConditionalOpTest : public ClientLibraryTestBase {
return build_status.ConsumeValueOrDie();
}
- Computation CreateR0F32CeilComputation() {
+ Computation CreateCeilComputation(const Shape& shape) {
ComputationBuilder builder(client_, "Ceil");
- auto param = builder.Parameter(0, r0f32_, "param");
+ auto param = builder.Parameter(0, shape, "param");
builder.Ceil(param);
auto build_status = builder.Build();
EXPECT_IS_OK(build_status.status());
return build_status.ConsumeValueOrDie();
}
- Computation CreateR0F32FloorComputation() {
- ComputationBuilder builder(client_, "Ceil");
- auto param = builder.Parameter(0, r0f32_, "param");
+ Computation CreateR0CeilComputation() {
+ return CreateCeilComputation(r0f32_);
+ }
+
+ Computation CreateR1CeilComputation() {
+ return CreateCeilComputation(r1s2f32_);
+ }
+
+ Computation CreateFloorComputation(const Shape& shape) {
+ ComputationBuilder builder(client_, "Floor");
+ auto param = builder.Parameter(0, shape, "param");
builder.Floor(param);
auto build_status = builder.Build();
EXPECT_IS_OK(build_status.status());
return build_status.ConsumeValueOrDie();
}
- Computation CreateAddTupleComputation(const string& computation_name,
+ Computation CreateR0FloorComputation() {
+ return CreateFloorComputation(r0f32_);
+ }
+
+ Computation CreateR1FloorComputation() {
+ return CreateFloorComputation(r1s2f32_);
+ }
+
+ Computation CreateTupleCeilComputation(const string& computation_name,
+ const Shape& tuple_shape) {
+ ComputationBuilder builder(client_, computation_name);
+ auto tuple = builder.Parameter(0, tuple_shape, "tuple");
+ auto x = builder.GetTupleElement(tuple, 0);
+ auto y = builder.GetTupleElement(tuple, 1);
+ auto x_ceil = builder.Ceil(x);
+ auto y_ceil = builder.Ceil(y);
+ builder.Tuple({x_ceil, y_ceil});
+ auto build_status = builder.Build();
+ EXPECT_IS_OK(build_status.status());
+ return build_status.ConsumeValueOrDie();
+ }
+
+ Computation CreateR0TupleCeilComputation() {
+ return CreateTupleCeilComputation("CeilR0", tuple_2_r0f32_);
+ }
+
+ Computation CreateR1TupleCeilComputation() {
+ return CreateTupleCeilComputation("CeilR1", tuple_2_r1s2f32_);
+ }
+
+ Computation CreateTupleFloorComputation(const string& computation_name,
+ const Shape& tuple_shape) {
+ ComputationBuilder builder(client_, computation_name);
+ auto tuple = builder.Parameter(0, tuple_shape, "tuple");
+ auto x = builder.GetTupleElement(tuple, 0);
+ auto y = builder.GetTupleElement(tuple, 1);
+ auto x_floor = builder.Floor(x);
+ auto y_floor = builder.Floor(y);
+ builder.Tuple({x_floor, y_floor});
+ auto build_status = builder.Build();
+ EXPECT_IS_OK(build_status.status());
+ return build_status.ConsumeValueOrDie();
+ }
+
+ Computation CreateR0TupleFloorComputation() {
+ return CreateTupleFloorComputation("FloorR0", tuple_2_r0f32_);
+ }
+
+ Computation CreateR1TupleFloorComputation() {
+ return CreateTupleFloorComputation("FloorR1", tuple_2_r1s2f32_);
+ }
+
+ Computation CreateTupleAddComputation(const string& computation_name,
const Shape& tuple_shape) {
ComputationBuilder builder(client_, computation_name);
auto tuple = builder.Parameter(0, tuple_shape, "tuple");
@@ -70,15 +130,15 @@ class ConditionalOpTest : public ClientLibraryTestBase {
return build_status.ConsumeValueOrDie();
}
- Computation CreateAddR0Computation() {
- return CreateAddTupleComputation("AddR0", tuple_2_r0f32_);
+ Computation CreateR0TupleAddComputation() {
+ return CreateTupleAddComputation("AddR0", tuple_2_r0f32_);
}
- Computation CreateAddR1Computation() {
- return CreateAddTupleComputation("AddR1", tuple_2_r1s2f32_);
+ Computation CreateR1TupleAddComputation() {
+ return CreateTupleAddComputation("AddR1", tuple_2_r1s2f32_);
}
- Computation CreateSubTupleComputation(const string& computation_name,
+ Computation CreateTupleSubComputation(const string& computation_name,
const Shape& tuple_shape) {
ComputationBuilder builder(client_, computation_name);
auto tuple = builder.Parameter(0, tuple_shape, "tuple");
@@ -90,15 +150,16 @@ class ConditionalOpTest : public ClientLibraryTestBase {
return build_status.ConsumeValueOrDie();
}
- Computation CreateSubR0Computation() {
- return CreateSubTupleComputation("SubR0", tuple_2_r0f32_);
+ Computation CreateR0TupleSubComputation() {
+ return CreateTupleSubComputation("SubR0", tuple_2_r0f32_);
}
- Computation CreateSubR1Computation() {
- return CreateSubTupleComputation("SubR1", tuple_2_r1s2f32_);
+ Computation CreateR1TupleSubComputation() {
+ return CreateTupleSubComputation("SubR1", tuple_2_r1s2f32_);
}
Shape r0f32_ = ShapeUtil::MakeShape(F32, {});
+ Shape r1s2f32_ = ShapeUtil::MakeShape(F32, {2});
Shape tuple_2_r0f32_ = ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})});
Shape tuple_2_r1s2f32_ = ShapeUtil::MakeTupleShape(
@@ -112,8 +173,8 @@ XLA_TEST_F(ConditionalOpTest, Parameters0) {
ComputationBuilder builder(client_, TestName());
auto pred = builder.ConstantR0<bool>(true);
auto operands = builder.Tuple({});
- auto true_computation = CreateR0F32ConstantComputation(56.0f);
- auto false_computation = CreateR0F32ConstantComputation(12.0f);
+ auto true_computation = CreateR0ConstantComputation(56.0f);
+ auto false_computation = CreateR0ConstantComputation(12.0f);
auto result = builder.Conditional(pred, operands, true_computation, operands,
false_computation);
@@ -126,7 +187,7 @@ XLA_TEST_F(ConditionalOpTest, Parameters1) {
auto pred = builder.ConstantR0<bool>(false);
auto operand1 = builder.ConstantR0<float>(56.0f);
auto operand2 = builder.ConstantR0<float>(12.0f);
- auto identity = CreateR0F32IdentityComputation();
+ auto identity = CreateR0IdentityComputation();
auto result =
builder.Conditional(pred, operand1, identity, operand2, identity);
@@ -140,9 +201,8 @@ XLA_TEST_F(ConditionalOpTest, DiffComputationsDiffArgs) {
auto pred = builder.ConstantR0<bool>(false);
auto operand1 = builder.ConstantR0<float>(56.4f);
auto operand2 = builder.ConstantR0<float>(12.6f);
- auto result =
- builder.Conditional(pred, operand1, CreateR0F32CeilComputation(),
- operand2, CreateR0F32FloorComputation());
+ auto result = builder.Conditional(pred, operand1, CreateR0CeilComputation(),
+ operand2, CreateR0FloorComputation());
ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
}
@@ -153,8 +213,8 @@ XLA_TEST_F(ConditionalOpTest, DiffComputationsSameArg) {
ComputationBuilder builder(client_, TestName());
auto pred = builder.ConstantR0<bool>(false);
auto operand = builder.ConstantR0<float>(12.6f);
- auto result = builder.Conditional(pred, operand, CreateR0F32CeilComputation(),
- operand, CreateR0F32FloorComputation());
+ auto result = builder.Conditional(pred, operand, CreateR0CeilComputation(),
+ operand, CreateR0FloorComputation());
ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
}
@@ -166,7 +226,7 @@ XLA_TEST_F(ConditionalOpTest, SameComputationDiffArgs) {
auto pred = builder.ConstantR0<bool>(false);
auto operand1 = builder.ConstantR0<float>(56.4f);
auto operand2 = builder.ConstantR0<float>(12.6f);
- auto floor = CreateR0F32FloorComputation();
+ auto floor = CreateR0FloorComputation();
auto result = builder.Conditional(pred, operand1, floor, operand2, floor);
ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
@@ -178,7 +238,7 @@ XLA_TEST_F(ConditionalOpTest, SameComputationSameArg) {
ComputationBuilder builder(client_, TestName());
auto pred = builder.ConstantR0<bool>(false);
auto operand = builder.ConstantR0<float>(12.6f);
- auto floor = CreateR0F32FloorComputation();
+ auto floor = CreateR0FloorComputation();
auto result = builder.Conditional(pred, operand, floor, operand, floor);
ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
@@ -191,9 +251,8 @@ XLA_TEST_F(ConditionalOpTest, SameComputationDiffInstances) {
auto pred = builder.ConstantR0<bool>(false);
auto operand1 = builder.ConstantR0<float>(56.4f);
auto operand2 = builder.ConstantR0<float>(12.6f);
- auto result =
- builder.Conditional(pred, operand1, CreateR0F32FloorComputation(),
- operand2, CreateR0F32FloorComputation());
+ auto result = builder.Conditional(pred, operand1, CreateR0FloorComputation(),
+ operand2, CreateR0FloorComputation());
ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
}
@@ -205,9 +264,8 @@ XLA_TEST_F(ConditionalOpTest, ConditionalWithCall) {
auto pred_cond = inner_builder.Parameter(0, r0bool, "param0");
auto true_operand = inner_builder.Parameter(1, r0f32_, "param1");
auto false_operand = inner_builder.Parameter(2, r0f32_, "param2");
- inner_builder.Conditional(pred_cond, true_operand,
- CreateR0F32CeilComputation(), false_operand,
- CreateR0F32FloorComputation());
+ inner_builder.Conditional(pred_cond, true_operand, CreateR0CeilComputation(),
+ false_operand, CreateR0FloorComputation());
auto inner_builder_result = inner_builder.Build();
ComputationBuilder builder(client_, TestName());
@@ -228,8 +286,9 @@ XLA_TEST_F(ConditionalOpTest, Parameters2TrueBranch) {
auto operand1 = builder.ConstantR0<float>(56.0f);
auto operand2 = builder.ConstantR0<float>(12.0f);
auto operands = builder.Tuple({operand1, operand2});
- auto result = builder.Conditional(pred, operands, CreateAddR0Computation(),
- operands, CreateSubR0Computation());
+ auto result =
+ builder.Conditional(pred, operands, CreateR0TupleAddComputation(),
+ operands, CreateR0TupleSubComputation());
ComputeAndCompareR0<float>(&builder, 68.0f, {}, error_spec_);
}
@@ -242,8 +301,9 @@ XLA_TEST_F(ConditionalOpTest, Parameters2FalseBranch) {
auto operand1 = builder.ConstantR0<float>(56.0f);
auto operand2 = builder.ConstantR0<float>(12.0f);
auto operands = builder.Tuple({operand1, operand2});
- auto result = builder.Conditional(pred, operands, CreateAddR0Computation(),
- operands, CreateSubR0Computation());
+ auto result =
+ builder.Conditional(pred, operands, CreateR0TupleAddComputation(),
+ operands, CreateR0TupleSubComputation());
ComputeAndCompareR0<float>(&builder, 44.0f, {}, error_spec_);
}
@@ -256,8 +316,9 @@ XLA_TEST_F(ConditionalOpTest, Parameters2ArrayTrueBranch) {
auto operand1 = builder.ConstantR1<float>({24.0f, 56.0f});
auto operand2 = builder.ConstantR1<float>({10.0f, 11.0f});
auto operands = builder.Tuple({operand1, operand2});
- auto result = builder.Conditional(pred, operands, CreateAddR1Computation(),
- operands, CreateSubR1Computation());
+ auto result =
+ builder.Conditional(pred, operands, CreateR1TupleAddComputation(),
+ operands, CreateR1TupleSubComputation());
ComputeAndCompareR1<float>(&builder, {34.0f, 67.0f}, {}, error_spec_);
}
@@ -270,25 +331,192 @@ XLA_TEST_F(ConditionalOpTest, Parameters2ArrayFalseBranch) {
auto operand1 = builder.ConstantR1<float>({24.0f, 56.0f});
auto operand2 = builder.ConstantR1<float>({10.0f, 11.0f});
auto operands = builder.Tuple({operand1, operand2});
- auto result = builder.Conditional(pred, operands, CreateAddR1Computation(),
- operands, CreateSubR1Computation());
+ auto result =
+ builder.Conditional(pred, operands, CreateR1TupleAddComputation(),
+ operands, CreateR1TupleSubComputation());
ComputeAndCompareR1<float>(&builder, {14.0f, 45.0f}, {}, error_spec_);
}
+// Test true and false computations that return a tuple of scalars.
+XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) {
+ ComputationBuilder builder(client_, TestName());
+ auto pred = builder.ConstantR0<bool>(false);
+ auto operands = builder.Tuple(
+ {builder.ConstantR0<float>(12.2f), builder.ConstantR0<float>(25.6f)});
+ builder.Conditional(pred, operands, CreateR0TupleCeilComputation(), operands,
+ CreateR0TupleFloorComputation());
+
+ ComputeAndCompareTuple(
+ &builder,
+ *Literal::MakeTuple({Literal::CreateR0<float>(12.0f).get(),
+ Literal::CreateR0<float>(25.0f).get()}),
+ {}, error_spec_);
+}
+
+// Test true and false computations that return a tuple of arrays.
+// TODO(b/71715476): Returning tuples from Conditional fails in GPU backend.
+XLA_TEST_F(ConditionalOpTest, DISABLED_ON_GPU(ReturnTupleOfArrays)) {
+ ComputationBuilder builder(client_, TestName());
+ auto pred = builder.ConstantR0<bool>(true);
+ auto operands = builder.Tuple({builder.ConstantR1<float>({12.2f, 15.8f}),
+ builder.ConstantR1<float>({25.6f, 29.2f})});
+ builder.Conditional(pred, operands, CreateR1TupleCeilComputation(), operands,
+ CreateR1TupleFloorComputation());
+
+ ComputeAndCompareTuple(
+ &builder,
+ *Literal::MakeTuple({Literal::CreateR1<float>({13.0f, 16.0f}).get(),
+ Literal::CreateR1<float>({26.0f, 30.0f}).get()}),
+ {}, error_spec_);
+}
+
+// Test true and false computations that return a tuple of a predicate, a
+// scalar, and an array.
+// TODO(b/71715476): Returning tuples from Conditional fails in GPU backend.
+XLA_TEST_F(ConditionalOpTest,
+ DISABLED_ON_GPU(ReturnTupleofPredicateScalarArray)) {
+ ComputationBuilder true_builder(client_, TestName() + ".true");
+ {
+ true_builder.Parameter(0, empty_tuple_, "tuple");
+ auto true_pred = true_builder.ConstantR0<bool>(true);
+ auto true_scalar = true_builder.ConstantR0<float>(12.2f);
+ auto true_array = true_builder.ConstantR1<float>({12.8f, 14.6f});
+ true_builder.Tuple({true_pred, true_scalar, true_array});
+ }
+ auto true_builder_result = true_builder.Build();
+ EXPECT_IS_OK(true_builder_result.status());
+
+ ComputationBuilder false_builder(client_, TestName() + ".false");
+ {
+ false_builder.Parameter(0, empty_tuple_, "tuple");
+ auto false_pred = false_builder.ConstantR0<bool>(false);
+ auto false_scalar = false_builder.ConstantR0<float>(25.6f);
+ auto false_array = false_builder.ConstantR1<float>({26.4f, 32.6f});
+ false_builder.Tuple({false_pred, false_scalar, false_array});
+ }
+ auto false_builder_result = false_builder.Build();
+ EXPECT_IS_OK(false_builder_result.status());
+
+ ComputationBuilder builder(client_, TestName());
+ auto pred = builder.ConstantR0<bool>(true);
+ auto operands = builder.Tuple({});
+ builder.Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(),
+ operands, false_builder_result.ConsumeValueOrDie());
+
+ ComputeAndCompareTuple(
+ &builder,
+ *Literal::MakeTuple({Literal::CreateR0<bool>(true).get(),
+ Literal::CreateR0<float>(12.2f).get(),
+ Literal::CreateR1<float>({12.8f, 14.6f}).get()}),
+ {}, error_spec_);
+}
+
+// Test true and false computations that return a nested tuple.
+// TODO(b/71715476): Returning tuples from Conditional fails in GPU backend.
+XLA_TEST_F(ConditionalOpTest, DISABLED_ON_GPU(ReturnNestedTuple)) {
+ ComputationBuilder true_builder(client_, TestName() + ".true");
+ {
+ true_builder.Parameter(0, empty_tuple_, "tuple");
+ auto true_constant1 = true_builder.ConstantR0<float>(12.2f);
+ auto true_constant2 = true_builder.ConstantR1<float>({12.8f, 14.6f});
+ auto true_constant3 = true_builder.ConstantR1<float>({25.4f, 29.8f});
+ auto true_constant4 = true_builder.ConstantR0<float>(35.6f);
+ true_builder.Tuple({true_builder.Tuple({true_constant1, true_constant2}),
+ true_builder.Tuple({true_constant3, true_constant4})});
+ }
+ auto true_builder_result = true_builder.Build();
+ EXPECT_IS_OK(true_builder_result.status());
+
+ ComputationBuilder false_builder(client_, TestName() + ".false");
+ {
+ false_builder.Parameter(0, empty_tuple_, "tuple");
+ auto false_constant1 = false_builder.ConstantR0<float>(46.6f);
+ auto false_constant2 = false_builder.ConstantR1<float>({54.4f, 58.4f});
+ auto false_constant3 = false_builder.ConstantR1<float>({62.1f, 67.4f});
+ auto false_constant4 = false_builder.ConstantR0<float>(9.3f);
+ false_builder.Tuple(
+ {false_builder.Tuple({false_constant1, false_constant2}),
+ false_builder.Tuple({false_constant3, false_constant4})});
+ }
+ auto false_builder_result = false_builder.Build();
+ EXPECT_IS_OK(false_builder_result.status());
+
+ ComputationBuilder builder(client_, TestName());
+ auto pred = builder.ConstantR0<bool>(false);
+ auto operands = builder.Tuple({});
+ builder.Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(),
+ operands, false_builder_result.ConsumeValueOrDie());
+
+ ComputeAndCompareTuple(
+ &builder,
+ *Literal::MakeTuple(
+ {Literal::MakeTuple({Literal::CreateR0<float>(46.6f).get(),
+ Literal::CreateR1<float>({54.4f, 58.4f}).get()})
+ .get(),
+ Literal::MakeTuple({Literal::CreateR1<float>({62.1f, 67.4f}).get(),
+ Literal::CreateR0<float>(9.3f).get()})
+ .get()}),
+ {}, error_spec_);
+}
+
+// Test conditional that takes in scalar operands in the form of external
+// params.
+XLA_TEST_F(ConditionalOpTest, ScalarOperandsFromExternalParams) {
+ Shape r0bool = ShapeUtil::MakeShape(PRED, {});
+ ComputationBuilder builder(client_, TestName());
+
+ ComputationDataHandle pred, operand1, operand2;
+ auto pred_arg = CreateR0Parameter<bool>(true, 0, "pred", &builder, &pred);
+ auto operand1_param =
+ CreateR0Parameter<float>(56.3f, 1, "operand1", &builder, &operand1);
+ auto operand2_param =
+ CreateR0Parameter<float>(12.7f, 2, "operand2", &builder, &operand2);
+ auto result = builder.Conditional(pred, operand1, CreateR0CeilComputation(),
+ operand2, CreateR0FloorComputation());
+
+ ComputeAndCompareR0<float>(
+ &builder, 57.0f,
+ {pred_arg.get(), operand1_param.get(), operand2_param.get()},
+ error_spec_);
+}
+
+// Test conditional that takes in array operands in the form of external params.
+XLA_TEST_F(ConditionalOpTest, ArrayOperandsFromExternalParams) {
+ Shape r0bool = ShapeUtil::MakeShape(PRED, {});
+ ComputationBuilder builder(client_, TestName());
+
+ ComputationDataHandle pred, operand1, operand2;
+ auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
+ auto operand1_param = CreateR1Parameter<float>({24.3f, 56.7f}, 1, "operand1",
+ &builder, &operand1);
+ auto operand2_param = CreateR1Parameter<float>({10.2f, 11.6f}, 2, "operand2",
+ &builder, &operand2);
+ auto result = builder.Conditional(pred, operand1, CreateR1CeilComputation(),
+ operand2, CreateR1FloorComputation());
+
+ ComputeAndCompareR1<float>(
+ &builder, {10.0f, 11.0f},
+ {pred_arg.get(), operand1_param.get(), operand2_param.get()},
+ error_spec_);
+}
+
// Test the case where one conditional is nested within another.
XLA_TEST_F(ConditionalOpTest, NestedConditionals) {
- Shape r0bool = ShapeUtil::MakeShape(PRED, {});
- Shape tuple_shape = ShapeUtil::MakeTupleShape({r0bool, r0f32_, r0f32_});
ComputationBuilder inner_builder(client_, TestName() + ".inner_conditional");
- auto param0 = inner_builder.Parameter(0, tuple_shape, "param0");
- auto pred_cond = inner_builder.GetTupleElement(param0, 0);
- auto true_operand = inner_builder.GetTupleElement(param0, 1);
- auto false_operand = inner_builder.GetTupleElement(param0, 2);
- inner_builder.Conditional(pred_cond, true_operand,
- CreateR0F32CeilComputation(), false_operand,
- CreateR0F32FloorComputation());
+ {
+ Shape r0bool = ShapeUtil::MakeShape(PRED, {});
+ Shape tuple_shape = ShapeUtil::MakeTupleShape({r0bool, r0f32_, r0f32_});
+ auto param0 = inner_builder.Parameter(0, tuple_shape, "param0");
+ auto pred_cond = inner_builder.GetTupleElement(param0, 0);
+ auto true_operand = inner_builder.GetTupleElement(param0, 1);
+ auto false_operand = inner_builder.GetTupleElement(param0, 2);
+ inner_builder.Conditional(pred_cond, true_operand,
+ CreateR0CeilComputation(), false_operand,
+ CreateR0FloorComputation());
+ }
auto inner_builder_result = inner_builder.Build();
+ EXPECT_IS_OK(inner_builder_result.status());
ComputationBuilder builder(client_, TestName());
auto pred1 = builder.ConstantR0<bool>(true);
@@ -299,7 +527,7 @@ XLA_TEST_F(ConditionalOpTest, NestedConditionals) {
auto tuple_operand = builder.Tuple({pred2, operand1, operand2});
builder.Conditional(pred1, tuple_operand,
inner_builder_result.ConsumeValueOrDie(), operand3,
- CreateR0F32IdentityComputation());
+ CreateR0IdentityComputation());
ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
}
@@ -311,8 +539,8 @@ XLA_TEST_F(ConditionalOpTest, ShapeMismatch) {
auto operand1 = builder.ConstantR0<float>(56.0f);
auto operand2 = builder.ConstantR0<float>(12.0f);
auto operands = builder.Tuple({operand1, operand2});
- builder.Conditional(pred, operands, CreateAddR1Computation(), operands,
- CreateSubR0Computation());
+ builder.Conditional(pred, operands, CreateR1TupleAddComputation(), operands,
+ CreateR0TupleSubComputation());
auto result = builder.Build();
EXPECT_FALSE(result.ok());