aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-27 17:02:59 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-27 17:05:17 -0700
commit6f8b85d301140ce42c0aa4871750ee0aec758105 (patch)
tree6a7484d391c060cbe90158b31c979e8d3654f2cd /tensorflow/compiler/xla
parentb2644288ebfb8a4cc52231c3ca93a968397c860a (diff)
[XLA] Redesign: implement Tuple and GetTupleElement.
PiperOrigin-RevId: 190698245
Diffstat (limited to 'tensorflow/compiler/xla')
-rw-r--r--tensorflow/compiler/xla/client/xla_client/xla_builder.cc33
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc6
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.h3
-rw-r--r--tensorflow/compiler/xla/tests/BUILD2
-rw-r--r--tensorflow/compiler/xla/tests/client_library_test_base.cc22
-rw-r--r--tensorflow/compiler/xla/tests/client_library_test_base.h6
-rw-r--r--tensorflow/compiler/xla/tests/tuple_test.cc99
7 files changed, 115 insertions, 56 deletions
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
index fcaf393b6b..7d39701b10 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
@@ -491,11 +491,40 @@ XlaOp XlaBuilder::Select(const XlaOp& pred, const XlaOp& on_true,
}
XlaOp XlaBuilder::Tuple(tensorflow::gtl::ArraySlice<XlaOp> elements) {
- return UnimplementedOp();
+ return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+ std::vector<const Shape*> operand_shape_ptrs;
+ std::vector<Shape> operand_shapes;
+ for (const XlaOp& e : elements) {
+ TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(e));
+ operand_shapes.push_back(shape);
+ }
+ c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
+ [](const Shape& shape) { return &shape; });
+ TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
+ ShapeInference::InferVariadicOpShape(
+ HloOpcode::kTuple, operand_shape_ptrs));
+ return AddInstruction(std::move(instr), HloOpcode::kTuple, elements);
+ }());
}
XlaOp XlaBuilder::GetTupleElement(const XlaOp& tuple_data, int64 index) {
- return UnimplementedOp();
+ return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+ TF_ASSIGN_OR_RETURN(const Shape& tuple_shape, GetShape(tuple_data));
+ if (!ShapeUtil::IsTuple(tuple_shape)) {
+ return InvalidArgument(
+ "Operand to GetTupleElement() is not a tuple; got %s",
+ ShapeUtil::HumanString(tuple_shape).c_str());
+ }
+ *instr.mutable_shape() =
+ ShapeUtil::GetTupleElementShape(tuple_shape, index);
+
+ instr.set_tuple_index(index);
+
+ return AddInstruction(std::move(instr), HloOpcode::kGetTupleElement,
+ {tuple_data});
+ }());
}
XlaOp XlaBuilder::Eq(const XlaOp& lhs, const XlaOp& rhs,
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 36456d552d..77e12d3602 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -1070,6 +1070,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
for (const HloInstruction* operand : operands) {
operand_shapes.push_back(&operand->shape());
}
+ return InferVariadicOpShape(opcode, operand_shapes);
+}
+
+/* static */ StatusOr<Shape> ShapeInference::InferVariadicOpShape(
+ HloOpcode opcode,
+ tensorflow::gtl::ArraySlice<const Shape*> operand_shapes) {
return InferVariadicOpShape(OpcodeToVariadicOperation(opcode),
operand_shapes);
}
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h
index 88830e6d25..9da2c99b41 100644
--- a/tensorflow/compiler/xla/service/shape_inference.h
+++ b/tensorflow/compiler/xla/service/shape_inference.h
@@ -85,6 +85,9 @@ class ShapeInference {
tensorflow::gtl::ArraySlice<const Shape*> operand_shapes);
static StatusOr<Shape> InferVariadicOpShape(
HloOpcode opcode,
+ tensorflow::gtl::ArraySlice<const Shape*> operand_shapes);
+ static StatusOr<Shape> InferVariadicOpShape(
+ HloOpcode opcode,
tensorflow::gtl::ArraySlice<const HloInstruction*> operands);
// Infers the shape produced by applying the given mapping computation shape
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 5ab25f2264..2fd97fa38e 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -1011,6 +1011,8 @@ xla_test(
"//tensorflow/compiler/xla/client:computation",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc
index ec95a68ead..4a9faef1dc 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.cc
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc
@@ -441,8 +441,9 @@ void ClientLibraryTestBase::ComputeAndCompareR1U8(
EXPECT_EQ(expected, actual->GetR1U8AsString());
}
+template <typename BuilderT>
void ClientLibraryTestBase::ComputeAndCompareTuple(
- ComputationBuilder* builder, const Literal& expected,
+ BuilderT* builder, const Literal& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
auto actual_status = ExecuteAndTransfer(builder, arguments);
EXPECT_IS_OK(actual_status.status());
@@ -453,8 +454,9 @@ void ClientLibraryTestBase::ComputeAndCompareTuple(
LiteralTestUtil::ExpectEqual(expected, *actual);
}
+template <typename BuilderT>
void ClientLibraryTestBase::ComputeAndCompareTuple(
- ComputationBuilder* builder, const Literal& expected,
+ BuilderT* builder, const Literal& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
auto actual_status = ExecuteAndTransfer(builder, arguments);
EXPECT_IS_OK(actual_status.status());
@@ -619,4 +621,20 @@ template void ClientLibraryTestBase::ComputeAndCompareLiteral(
tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error,
const Shape* shape_with_layout);
+template void ClientLibraryTestBase::ComputeAndCompareTuple(
+ ComputationBuilder* builder, const Literal& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+
+template void ClientLibraryTestBase::ComputeAndCompareTuple(
+ XlaBuilder* builder, const Literal& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+
+template void ClientLibraryTestBase::ComputeAndCompareTuple(
+ ComputationBuilder* builder, const Literal& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error);
+
+template void ClientLibraryTestBase::ComputeAndCompareTuple(
+ XlaBuilder* builder, const Literal& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error);
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h
index 5ff200be03..be90f14c8e 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.h
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.h
@@ -217,11 +217,13 @@ class ClientLibraryTestBase : public ::testing::Test {
// Convenience method for running a built computation, transferring the
// result, and comparing it to the expected tuple literal.
+ template <typename BuilderT>
void ComputeAndCompareTuple(
- ComputationBuilder* builder, const Literal& expected,
+ BuilderT* builder, const Literal& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+ template <typename BuilderT>
void ComputeAndCompareTuple(
- ComputationBuilder* builder, const Literal& expected,
+ BuilderT* builder, const Literal& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error);
// Convenience method for running a built computation and comparing the result
diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc
index fa60af4b6a..098be6d7aa 100644
--- a/tensorflow/compiler/xla/tests/tuple_test.cc
+++ b/tensorflow/compiler/xla/tests/tuple_test.cc
@@ -20,6 +20,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -41,7 +43,7 @@ class TupleTest : public ClientLibraryTestBase {
// Tests a tuple-shaped constant.
XLA_TEST_F(TupleTest, TupleConstant) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
const float constant_scalar = 7.3f;
std::initializer_list<float> constant_vector = {1.1f, 2.0f, 3.3f};
@@ -54,13 +56,13 @@ XLA_TEST_F(TupleTest, TupleConstant) {
Literal::CreateR1<float>(constant_vector).get(),
Literal::CreateR2<float>(constant_matrix).get()});
- auto result = builder.ConstantLiteral(*value);
+ builder.ConstantLiteral(*value);
ComputeAndCompareTuple(&builder, *value, {}, error_spec_);
}
// Tests a tuple made of scalar constants.
XLA_TEST_F(TupleTest, TupleScalarConstant) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
const float constant_scalar1 = 7.3f;
const float constant_scalar2 = 1.2f;
@@ -68,13 +70,13 @@ XLA_TEST_F(TupleTest, TupleScalarConstant) {
Literal::MakeTuple({Literal::CreateR0<float>(constant_scalar1).get(),
Literal::CreateR0<float>(constant_scalar2).get()});
- auto result = builder.ConstantLiteral(*value);
+ builder.ConstantLiteral(*value);
ComputeAndCompareTuple(&builder, *value, {}, error_spec_);
}
// Tests the creation of tuple data.
XLA_TEST_F(TupleTest, TupleCreate) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
const float constant_scalar = 7.3f;
std::initializer_list<float> constant_vector = {1.1f, 2.0f, 3.3f};
@@ -82,9 +84,9 @@ XLA_TEST_F(TupleTest, TupleCreate) {
{1.1f, 2.2f, 3.5f}, // row 0
{4.8f, 5.0f, 6.7f}, // row 1
};
- auto result = builder.Tuple({builder.ConstantR0<float>(constant_scalar),
- builder.ConstantR1<float>(constant_vector),
- builder.ConstantR2<float>(constant_matrix)});
+ builder.Tuple({builder.ConstantR0<float>(constant_scalar),
+ builder.ConstantR1<float>(constant_vector),
+ builder.ConstantR2<float>(constant_matrix)});
auto expected =
Literal::MakeTuple({Literal::CreateR0<float>(constant_scalar).get(),
@@ -95,9 +97,9 @@ XLA_TEST_F(TupleTest, TupleCreate) {
// Tests the creation of tuple data.
XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
- auto result = builder.Tuple(
+ builder.Tuple(
{builder.ConstantR0<float>(7.0), builder.ConstantR1<float>({})});
auto expected = Literal::MakeTuple({Literal::CreateR0<float>(7.0).get(),
@@ -107,15 +109,15 @@ XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) {
// Tests the creation of an empty tuple.
XLA_TEST_F(TupleTest, EmptyTupleCreate) {
- ComputationBuilder builder(client_, TestName());
- auto result = builder.Tuple({});
+ XlaBuilder builder(TestName());
+ builder.Tuple({});
auto expected = Literal::MakeTuple({});
ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
}
// Trivial test for extracting a tuple element with GetTupleElement.
XLA_TEST_F(TupleTest, GetTupleElement) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::initializer_list<float> constant_vector = {1.f, 2.f, 3.f};
std::initializer_list<std::initializer_list<float>> constant_matrix = {
{1.f, 2.f, 3.f}, // row 0
@@ -123,23 +125,23 @@ XLA_TEST_F(TupleTest, GetTupleElement) {
};
auto tuple_data = builder.Tuple({builder.ConstantR1<float>(constant_vector),
builder.ConstantR2<float>(constant_matrix)});
- auto matrix_element = builder.GetTupleElement(tuple_data, 1);
+ builder.GetTupleElement(tuple_data, 1);
ComputeAndCompareR2<float>(&builder, Array2D<float>(constant_matrix), {},
error_spec_);
}
// Trivial test for extracting a tuple element with GetTupleElement.
XLA_TEST_F(TupleTest, GetTupleElementWithZeroElements) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto tuple_data = builder.Tuple(
{builder.ConstantR1<float>({}),
builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 101))});
- auto matrix_element = builder.GetTupleElement(tuple_data, 1);
+ builder.GetTupleElement(tuple_data, 1);
ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 101), {}, error_spec_);
}
XLA_TEST_F(TupleTest, GetTupleElementOfNonTupleFailsGracefully) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto value = builder.ConstantR1<float>({4.5f});
builder.GetTupleElement(value, 1);
auto result_status = builder.Build();
@@ -152,7 +154,7 @@ XLA_TEST_F(TupleTest, GetTupleElementOfNonTupleFailsGracefully) {
// Extracts both elements from a tuple with GetTupleElement and then adds them
// together.
XLA_TEST_F(TupleTest, AddTupleElements) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::initializer_list<float> constant_vector = {1.f, 2.f, 3.f};
std::initializer_list<std::initializer_list<float>> constant_matrix = {
{1.f, 2.f, 3.f}, // row 0
@@ -164,22 +166,22 @@ XLA_TEST_F(TupleTest, AddTupleElements) {
auto matrix_element = builder.GetTupleElement(tuple_data, 1);
auto vector_shape = builder.GetShape(vector_element).ConsumeValueOrDie();
auto matrix_shape = builder.GetShape(matrix_element).ConsumeValueOrDie();
- auto result = builder.Add(matrix_element, vector_element,
- /*broadcast_dimensions=*/{1});
+ builder.Add(matrix_element, vector_element,
+ /*broadcast_dimensions=*/{1});
Array2D<float> expected({
{2.f, 4.f, 6.f}, // row 0
{5.f, 7.f, 9.f}, // row 1
});
- ASSERT_TRUE(ShapeUtil::ShapeIs(*vector_shape, F32, {3}));
- ASSERT_TRUE(ShapeUtil::ShapeIs(*matrix_shape, F32, {/*y=*/2, /*x=*/3}));
+ ASSERT_TRUE(ShapeUtil::ShapeIs(vector_shape, F32, {3}));
+ ASSERT_TRUE(ShapeUtil::ShapeIs(matrix_shape, F32, {/*y=*/2, /*x=*/3}));
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
// Extracts both elements from a tuple and then puts them into a new tuple in
// the opposite order.
XLA_TEST_F(TupleTest, TupleGTEToTuple) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::initializer_list<float> constant_vector = {1.f, 2.f, 3.f};
std::initializer_list<std::initializer_list<float>> constant_matrix = {
{1.f, 2.f, 3.f}, // row 0
@@ -187,8 +189,8 @@ XLA_TEST_F(TupleTest, TupleGTEToTuple) {
};
auto tuple_data = builder.Tuple({builder.ConstantR1<float>(constant_vector),
builder.ConstantR2<float>(constant_matrix)});
- auto new_tuple = builder.Tuple({builder.GetTupleElement(tuple_data, 1),
- builder.GetTupleElement(tuple_data, 0)});
+ builder.Tuple({builder.GetTupleElement(tuple_data, 1),
+ builder.GetTupleElement(tuple_data, 0)});
auto expected =
Literal::MakeTuple({Literal::CreateR2<float>(constant_matrix).get(),
Literal::CreateR1<float>(constant_vector).get()});
@@ -196,8 +198,8 @@ XLA_TEST_F(TupleTest, TupleGTEToTuple) {
}
XLA_TEST_F(TupleTest, SelectBetweenPredTuples) {
- ComputationBuilder b(client_, TestName());
- ComputationDataHandle v1, v2;
+ XlaBuilder b(TestName());
+ XlaOp v1, v2;
for (bool direction : {false, true}) {
std::unique_ptr<GlobalData> v1_data =
@@ -210,7 +212,7 @@ XLA_TEST_F(TupleTest, SelectBetweenPredTuples) {
auto v2_gt = b.Gt(v2, v1); // true
auto v1_v2 = b.Tuple({v1_gt, v2_gt}); // {false, true}
auto v2_v1 = b.Tuple({v2_gt, v1_gt}); // {true, false}
- auto select = b.Select(direction ? v1_gt : v2_gt, v1_v2, v2_v1);
+ b.Select(direction ? v1_gt : v2_gt, v1_v2, v2_v1);
auto expected =
Literal::MakeTuple({Literal::CreateR0<bool>(direction).get(),
Literal::CreateR0<bool>(!direction).get()});
@@ -237,7 +239,7 @@ XLA_TEST_F(TupleTest, TupleGTEToTupleToGTEAdd) {
// \ (tuple10)-- /
// \ / \ /
// -----(GTE 0)-- --(GTE 1)----------
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::initializer_list<float> constant_vector = {1.f, 2.f, 3.f};
std::initializer_list<std::initializer_list<float>> constant_matrix = {
{1.f, 2.f, 3.f}, // row 0
@@ -257,8 +259,8 @@ XLA_TEST_F(TupleTest, TupleGTEToTupleToGTEAdd) {
auto addvectors = builder.Add(vector_from_01, vector_from_10);
auto addmatrices = builder.Add(matrix_from_01, matrix_from_10);
- auto result = builder.Add(addmatrices, addvectors,
- /*broadcast_dimensions=*/{1});
+ builder.Add(addmatrices, addvectors,
+ /*broadcast_dimensions=*/{1});
Array2D<float> expected({
{4.f, 8.f, 12.f}, // row 0
@@ -269,7 +271,7 @@ XLA_TEST_F(TupleTest, TupleGTEToTupleToGTEAdd) {
XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnFalse)) {
// Tests a selection between tuples with "false" path taken.
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};
@@ -278,8 +280,7 @@ XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnFalse)) {
auto tuple21 = builder.Tuple(
{builder.ConstantR1<float>(vec2), builder.ConstantR1<float>(vec1)});
- auto select =
- builder.Select(builder.ConstantR0<bool>(false), tuple12, tuple21);
+ builder.Select(builder.ConstantR0<bool>(false), tuple12, tuple21);
auto expected = Literal::MakeTuple({Literal::CreateR1<float>(vec2).get(),
Literal::CreateR1<float>(vec1).get()});
ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
@@ -314,7 +315,7 @@ XLA_TEST_F(TupleTest, TuplesInAMap) {
XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnTrue)) {
// Tests a selection between tuples with "true" path taken.
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};
@@ -323,8 +324,7 @@ XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnTrue)) {
auto tuple21 = builder.Tuple(
{builder.ConstantR1<float>(vec2), builder.ConstantR1<float>(vec1)});
- auto select =
- builder.Select(builder.ConstantR0<bool>(true), tuple12, tuple21);
+ builder.Select(builder.ConstantR0<bool>(true), tuple12, tuple21);
auto expected = Literal::MakeTuple({Literal::CreateR1<float>(vec1).get(),
Literal::CreateR1<float>(vec2).get()});
ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
@@ -333,7 +333,7 @@ XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnTrue)) {
XLA_TEST_F(TupleTest, SelectBetweenTuplesElementResult) {
// Tests a selection between tuples but the final result is an element of the
// tuple, not the whole tuple.
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};
@@ -344,7 +344,7 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesElementResult) {
auto select =
builder.Select(builder.ConstantR0<bool>(false), tuple12, tuple21);
- auto element = builder.GetTupleElement(select, 0);
+ builder.GetTupleElement(select, 0);
ComputeAndCompareR1<float>(&builder, vec2, {}, error_spec_);
}
@@ -368,7 +368,7 @@ XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesCascaded)) {
// / --(GTE 1)--
// /
// (tuple 21)
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};
@@ -384,8 +384,8 @@ XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesCascaded)) {
builder.Select(builder.GetTupleElement(pred_tuple, 0), tuple12, tuple21);
auto select2 =
builder.Select(builder.GetTupleElement(pred_tuple, 1), tuple21, select1);
- auto result = builder.Add(builder.GetTupleElement(select2, 0),
- builder.GetTupleElement(select2, 1));
+ builder.Add(builder.GetTupleElement(select2, 0),
+ builder.GetTupleElement(select2, 1));
ComputeAndCompareR1<float>(&builder, {3.f, 6.f, 9.f}, {}, error_spec_);
}
@@ -394,7 +394,7 @@ XLA_TEST_F(TupleTest,
DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesReuseConstants)) {
// Similar to SelectBetweenTuples, but the constants are shared between the
// input tuples.
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};
@@ -403,19 +403,18 @@ XLA_TEST_F(TupleTest,
auto tuple12 = builder.Tuple({c1, c2});
auto tuple21 = builder.Tuple({c2, c1});
- auto select =
- builder.Select(builder.ConstantR0<bool>(false), tuple12, tuple21);
+ builder.Select(builder.ConstantR0<bool>(false), tuple12, tuple21);
+
auto expected = Literal::MakeTuple({Literal::CreateR1<float>(vec2).get(),
Literal::CreateR1<float>(vec1).get()});
ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
}
XLA_TEST_F(TupleTest, NestedTuples) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto inner_tuple = builder.Tuple(
{builder.ConstantR1<float>({1.0, 2.0}), builder.ConstantR0<float>(42.0)});
- auto outer_tuple =
- builder.Tuple({inner_tuple, builder.ConstantR1<float>({22.0, 44.0})});
+ builder.Tuple({inner_tuple, builder.ConstantR1<float>({22.0, 44.0})});
auto expected_v1 = Literal::CreateR1<float>({1.0, 2.0});
auto expected_s = Literal::CreateR0<float>(42.0);
@@ -429,7 +428,7 @@ XLA_TEST_F(TupleTest, NestedTuples) {
}
XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
Shape data_shape = ShapeUtil::MakeShape(F32, {3});
Shape inner_tuple_shape = ShapeUtil::MakeTupleShape({data_shape, data_shape});
@@ -460,7 +459,7 @@ XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) {
}
XLA_TEST_F(TupleTest, ComplexTuples) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
{
Shape c64r0 = ShapeUtil::MakeShape(C64, {});
Shape c64r1 = ShapeUtil::MakeShape(C64, {2});