path: root/tensorflow/compiler
author    Mark Heffernan <meheff@google.com>    2017-09-21 15:05:33 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2017-09-21 15:08:28 -0700
commit    36647440d2e62cb494e4e6f6d5d9144ceb0b29c7 (patch)
tree      d97c715ef6c79b205442f254679b1ffa03be94e4 /tensorflow/compiler
parent    57498a86c11dfc98dda84dc7318a3c84c85c6791 (diff)
Add methods to convert between Literals and ShapedBuffers to LocalClient. These "conversion" methods copy data between literals on the host and buffers on the device.

Also fix various issues I noticed along the way:
* Move LocalClient tests into open source.
* Add a proper == operator to Literals.
* Add a << overload for streaming Literals to output.
* Add Literal::GetSubliteral methods.
* Remove the unused AllocateBufferOnDevice methods from LocalClient and LocalService.

PiperOrigin-RevId: 169606342
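For orientation, here is a minimal sketch of how the new conversion methods and Literal comparison operators might be used from client code. It is an illustration only: the RoundTripExample function and its client/allocator arguments are assumptions, not part of this change, and error handling is reduced to the TF_ASSIGN_OR_RETURN macro.

    #include <memory>

    #include "tensorflow/compiler/xla/client/local_client.h"
    #include "tensorflow/compiler/xla/literal_util.h"
    #include "tensorflow/compiler/xla/status_macros.h"
    #include "tensorflow/core/platform/logging.h"

    namespace xla {

    // Hypothetical example: round-trip a literal through device memory using
    // the LocalClient conversion methods added in this change.
    Status RoundTripExample(LocalClient* client,
                            DeviceMemoryAllocator* allocator) {
      // Build a host-side literal.
      std::unique_ptr<Literal> literal =
          Literal::CreateR1<float>({1.0f, 2.0f, 3.0f});

      // Copy the literal to the device; the returned ScopedShapedBuffer owns
      // its device memory.
      TF_ASSIGN_OR_RETURN(std::unique_ptr<ScopedShapedBuffer> device_buffer,
                          client->LiteralToShapedBuffer(
                              *literal, allocator,
                              client->default_device_ordinal()));

      // Copy the device data back into a new literal.
      TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> round_tripped,
                          client->ShapedBufferToLiteral(*device_buffer));

      // The new operator== compares compatible shapes and element values
      // (layout is ignored), and operator<< streams Literal::ToString().
      if (*round_tripped != *literal) {
        LOG(ERROR) << "Round trip mismatch: " << *literal << " vs "
                   << *round_tripped;
      }
      return Status::OK();
    }

    }  // namespace xla

The new tests below exercise the same round trip through the LocalClientTestBase helpers (LiteralToScopedShapedBuffer and ShapedBufferToLiteral).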
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r-- tensorflow/compiler/xla/client/local_client.cc | 66
-rw-r--r-- tensorflow/compiler/xla/client/local_client.h | 21
-rw-r--r-- tensorflow/compiler/xla/literal_util.cc | 43
-rw-r--r-- tensorflow/compiler/xla/literal_util.h | 16
-rw-r--r-- tensorflow/compiler/xla/literal_util_test.cc | 134
-rw-r--r-- tensorflow/compiler/xla/service/algebraic_simplifier.cc | 2
-rw-r--r-- tensorflow/compiler/xla/service/hlo_cse.cc | 2
-rw-r--r-- tensorflow/compiler/xla/service/hlo_instruction.cc | 2
-rw-r--r-- tensorflow/compiler/xla/service/local_service.cc | 15
-rw-r--r-- tensorflow/compiler/xla/service/local_service.h | 8
-rw-r--r-- tensorflow/compiler/xla/tests/BUILD | 48
-rw-r--r-- tensorflow/compiler/xla/tests/local_client_allocation_test.cc | 105
-rw-r--r-- tensorflow/compiler/xla/tests/local_client_execute_test.cc | 618
13 files changed, 959 insertions, 121 deletions
diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc
index b36933436c..a0fc230319 100644
--- a/tensorflow/compiler/xla/client/local_client.cc
+++ b/tensorflow/compiler/xla/client/local_client.cc
@@ -166,7 +166,7 @@ tensorflow::Status LocalExecutable::ValidateExecutionOptions(
backend.platform()->Name().c_str());
}
- return tensorflow::Status::OK();
+ return Status::OK();
}
StatusOr<std::unique_ptr<ShapedBuffer>> LocalExecutable::Run(
@@ -225,7 +225,7 @@ tensorflow::Status LocalExecutable::RecordArguments(
TF_RETURN_IF_ERROR(LiteralFromShapedBuffer(*argument, &literal));
*session_module->add_arguments() = literal.ToProto();
}
- return tensorflow::Status::OK();
+ return Status::OK();
}
tensorflow::Status LocalExecutable::RecordResult(
@@ -234,7 +234,7 @@ tensorflow::Status LocalExecutable::RecordResult(
Literal literal(session_module->result());
TF_RETURN_IF_ERROR(LiteralFromShapedBuffer(*result, &literal));
*session_module->mutable_result() = literal.ToProto();
- return tensorflow::Status::OK();
+ return Status::OK();
}
// TODO(dnovillo) Change signature to return StatusOr<Literal>.
@@ -248,14 +248,6 @@ tensorflow::Status LocalExecutable::LiteralFromShapedBuffer(
shaped_buffer.shape(), literal);
}
-StatusOr<std::unique_ptr<GlobalData>> LocalClient::AllocateBufferOnDevice(
- const Shape& shape, int device_ordinal, bool allocate_space_for_deep_copy) {
- TF_ASSIGN_OR_RETURN(GlobalDataHandle handle,
- local_service_->AllocateBufferOnDevice(
- shape, device_ordinal, allocate_space_for_deep_copy));
- return std::unique_ptr<GlobalData>(new GlobalData(local_service_, handle));
-}
-
se::Platform* LocalClient::platform() const {
return local_service_->backend().platform();
}
@@ -297,4 +289,56 @@ StatusOr<std::unique_ptr<LocalExecutable>> LocalClient::Compile(
device_ordinal, options));
}
+// Copy the literal data to the device with the given ordinal and return as a
+// ScopedShapedBuffer. The given memory allocator is used for device memory
+// allocation.
+StatusOr<std::unique_ptr<ScopedShapedBuffer>>
+LocalClient::LiteralToShapedBuffer(const Literal& literal,
+ DeviceMemoryAllocator* allocator,
+ int device_ordinal) {
+ TF_ASSIGN_OR_RETURN(auto scoped_buffer,
+ ScopedShapedBuffer::MakeScopedShapedBuffer(
+ literal.shape(), allocator, device_ordinal));
+ TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
+ backend().stream_executor(device_ordinal));
+ TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
+ literal.shape(), [&](const Shape& subshape, const ShapeIndex& index) {
+ if (ShapeUtil::IsArray(subshape)) {
+ // This is a leaf of the shape. Transfer the literal array data to the
+ // device buffer.
+ return backend().transfer_manager()->TransferLiteralToDevice(
+ executor, literal.GetSubliteral(index),
+ scoped_buffer->mutable_buffer(index));
+ }
+ return Status::OK();
+ }));
+ return std::move(scoped_buffer);
+}
+
+// Copy the data from the device contained in the given ShapedBuffer and
+// return as a Literal.
+StatusOr<std::unique_ptr<Literal>> LocalClient::ShapedBufferToLiteral(
+ const ShapedBuffer& shaped_buffer) {
+ std::unique_ptr<Literal> literal =
+ Literal::CreateFromShape(shaped_buffer.shape());
+ TF_ASSIGN_OR_RETURN(
+ se::StreamExecutor * executor,
+ backend().stream_executor(shaped_buffer.device_ordinal()));
+ TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
+ literal->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
+ if (ShapeUtil::IsArray(subshape)) {
+ // This is a leaf of the shape. Transfer the device buffer into the
+ // literal. The layout of the literal and the device buffer are
+ // necessarily the same so we pass 'subshape' for both device and
+ // literal shapes.
+ return backend().transfer_manager()->TransferLiteralFromDevice(
+ executor, shaped_buffer.buffer(index),
+ /*device_shape=*/subshape,
+ /*literal_shape*/ subshape, &literal->GetSubliteral(index));
+ }
+ return Status::OK();
+ }));
+ return std::move(literal);
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h
index c903cd2711..e98384238a 100644
--- a/tensorflow/compiler/xla/client/local_client.h
+++ b/tensorflow/compiler/xla/client/local_client.h
@@ -158,15 +158,6 @@ class LocalClient : public Client {
LocalClient(const LocalClient&) = delete;
void operator=(const LocalClient&) = delete;
- // Return a handle to a buffer large enough to hold shape, allocated
- // on device_ordinal on the local service. If
- // allocate_space_for_deep_copy, the buffer is large enough to hold
- // all sub-buffers of a tuple shape, otherwise it is only as large
- // as the top-level tuple pointer array.
- StatusOr<std::unique_ptr<GlobalData>> AllocateBufferOnDevice(
- const Shape& shape, int device_ordinal,
- bool allocate_space_for_deep_copy);
-
// Build and return a LocalExecutable object. The executable is compiled using
// the given argument layouts and options.
StatusOr<std::unique_ptr<LocalExecutable>> Compile(
@@ -174,6 +165,18 @@ class LocalClient : public Client {
const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
const ExecutableBuildOptions& options);
+ // Copy the literal data to the device with the given ordinal and return as a
+ // ScopedShapedBuffer. The given memory allocator is used for device memory
+ // allocation.
+ StatusOr<std::unique_ptr<ScopedShapedBuffer>> LiteralToShapedBuffer(
+ const Literal& literal, DeviceMemoryAllocator* allocator,
+ int device_ordinal);
+
+ // Copy the data from the device contained in the given ShapedBuffer and
+ // return as a Literal.
+ StatusOr<std::unique_ptr<Literal>> ShapedBufferToLiteral(
+ const ShapedBuffer& shaped_buffer);
+
// Returns the platform that the underlying service targets.
perftools::gputools::Platform* platform() const;
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc
index e6787361d4..b867cf42d1 100644
--- a/tensorflow/compiler/xla/literal_util.cc
+++ b/tensorflow/compiler/xla/literal_util.cc
@@ -36,6 +36,11 @@ limitations under the License.
namespace xla {
+std::ostream& operator<<(std::ostream& out, const Literal& literal) {
+ out << literal.ToString();
+ return out;
+}
+
Literal::StrideConfig::StrideConfig(
const Shape& source_shape, const Shape& dest_shape,
tensorflow::gtl::ArraySlice<int64> dimensions)
@@ -927,16 +932,16 @@ bool EqualElements(const Literal& literal1, const Literal& literal2,
} // namespace
-bool Literal::Equal(const Literal& literal2) const {
- if (!ShapeUtil::Compatible(shape(), literal2.shape())) {
+bool Literal::operator==(const Literal& other) const {
+ if (!ShapeUtil::Compatible(shape(), other.shape())) {
return false;
}
if (ShapeUtil::IsTuple(shape())) {
// Because the shapes are compatible, they must have the same number of
// tuple elements.
- CHECK_EQ(tuple_literals_size(), literal2.tuple_literals_size());
+ CHECK_EQ(tuple_literals_size(), other.tuple_literals_size());
for (int i = 0; i < tuple_literals_size(); ++i) {
- if (!tuple_literals(i).Equal(literal2.tuple_literals(i))) {
+ if (tuple_literals(i) != other.tuple_literals(i)) {
return false;
}
}
@@ -945,23 +950,23 @@ bool Literal::Equal(const Literal& literal2) const {
std::vector<int64> multi_index(ShapeUtil::Rank(shape()), 0);
switch (shape().element_type()) {
case PRED:
- return EqualElements<bool>(*this, literal2, 0, &multi_index);
+ return EqualElements<bool>(*this, other, 0, &multi_index);
case U8:
- return EqualElements<uint8>(*this, literal2, 0, &multi_index);
+ return EqualElements<uint8>(*this, other, 0, &multi_index);
case S32:
- return EqualElements<int32>(*this, literal2, 0, &multi_index);
+ return EqualElements<int32>(*this, other, 0, &multi_index);
case S64:
- return EqualElements<int64>(*this, literal2, 0, &multi_index);
+ return EqualElements<int64>(*this, other, 0, &multi_index);
case U32:
- return EqualElements<uint32>(*this, literal2, 0, &multi_index);
+ return EqualElements<uint32>(*this, other, 0, &multi_index);
case U64:
- return EqualElements<uint64>(*this, literal2, 0, &multi_index);
+ return EqualElements<uint64>(*this, other, 0, &multi_index);
case F32:
- return EqualElements<float>(*this, literal2, 0, &multi_index);
+ return EqualElements<float>(*this, other, 0, &multi_index);
case F64:
- return EqualElements<double>(*this, literal2, 0, &multi_index);
+ return EqualElements<double>(*this, other, 0, &multi_index);
case F16:
- return EqualElements<half>(*this, literal2, 0, &multi_index);
+ return EqualElements<half>(*this, other, 0, &multi_index);
default:
LOG(FATAL) << "Unimplemented: Literal::Equal for type "
<< PrimitiveType_Name(shape().element_type());
@@ -1400,4 +1405,16 @@ void Literal::CopyFromProto(const LiteralProto& literal_proto) {
}
}
+const Literal& Literal::GetSubliteral(const ShapeIndex& index) const {
+ return const_cast<Literal*>(this)->GetSubliteral(index);
+}
+
+Literal& Literal::GetSubliteral(const ShapeIndex& index) {
+ Literal* subliteral = this;
+ for (int64 i : index) {
+ subliteral = &subliteral->tuple_literals_.at(i);
+ }
+ return *subliteral;
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h
index 858ddc297b..e8cee732d4 100644
--- a/tensorflow/compiler/xla/literal_util.h
+++ b/tensorflow/compiler/xla/literal_util.h
@@ -22,6 +22,7 @@ limitations under the License.
#include <initializer_list>
#include <iterator>
#include <memory>
+#include <ostream>
#include <string>
#include <type_traits>
#include <vector>
@@ -66,6 +67,11 @@ class Literal {
Literal& operator=(const Literal& other) = default;
Literal& operator=(Literal&&) = default;
+ // Literals are equal if they have compatible shapes and the same data
+ // values. Layout is not checked.
+ bool operator==(const Literal& other) const;
+ bool operator!=(const Literal& other) const { return !(*this == other); }
+
LiteralProto ToProto() const;
bool has_shape() const {
@@ -77,6 +83,10 @@ class Literal {
string DebugString() const { return ToProto().DebugString(); }
string ShortDebugString() const { return ToProto().ShortDebugString(); }
+ // Return the nested literal at the given shape index.
+ const Literal& GetSubliteral(const ShapeIndex& index) const;
+ Literal& GetSubliteral(const ShapeIndex& index);
+
void Clear() {
shape_.Clear();
u8s_.clear();
@@ -518,10 +528,6 @@ class Literal {
template <typename NativeT>
void Resize(int64 num_elements, NativeT value);
- // Returns true if this literal has the same shape and value as the given
- // literal. Layout is not considered in the comparison.
- bool Equal(const Literal& literal2) const;
-
// Returns whether every element in this literal is equal to value.
//
// value is an int8 because we expect this to be called with small
@@ -597,6 +603,8 @@ class Literal {
std::vector<Literal> tuple_literals_;
};
+std::ostream& operator<<(std::ostream& out, const Literal& literal);
+
// Declarations of template specializations for GetArraySlice and
// GetMutableArraySlice. The specializations map native type to XLA primitive
// type.
diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc
index 61ceac4f9a..e7dedd0821 100644
--- a/tensorflow/compiler/xla/literal_util_test.cc
+++ b/tensorflow/compiler/xla/literal_util_test.cc
@@ -257,37 +257,37 @@ TEST_F(LiteralUtilTest, EachCellR2F32) {
}
TEST_F(LiteralUtilTest, ScalarEquality) {
- // Test Literal::Equal with scalars.
+ // Test equality with scalars.
auto f32_42 = Literal::CreateR0<float>(42.0);
auto f32_42_clone = Literal::CreateR0<float>(42.0);
- EXPECT_TRUE(f32_42->Equal(*f32_42));
- EXPECT_TRUE(f32_42->Equal(*f32_42_clone));
+ EXPECT_EQ(*f32_42, *f32_42);
+ EXPECT_EQ(*f32_42, *f32_42_clone);
auto f32_123 = Literal::CreateR0<float>(123.0);
- EXPECT_FALSE(f32_42->Equal(*f32_123));
+ EXPECT_NE(*f32_42, *f32_123);
auto f64_42 = Literal::CreateR0<double>(42.0);
- EXPECT_FALSE(f32_42->Equal(*f64_42));
+ EXPECT_NE(*f32_42, *f64_42);
}
TEST_F(LiteralUtilTest, NonScalarEquality) {
- // Test Literal::Equal with nonscalars.
+ // Test equality with nonscalars.
auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
auto matrix_clone = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
auto matrix_different = Literal::CreateR2<float>({{4.0, 3.0}, {1.0, 2.0}});
auto vector_literal = Literal::CreateR1<float>({1.0, 2.0, 3.0, 4.0});
auto scalar = Literal::CreateR0<float>(1.0);
- EXPECT_TRUE(matrix->Equal(*matrix));
- EXPECT_TRUE(matrix->Equal(*matrix_clone));
- EXPECT_FALSE(matrix->Equal(*matrix_different));
- EXPECT_FALSE(matrix->Equal(*vector_literal));
- EXPECT_FALSE(matrix->Equal(*scalar));
+ EXPECT_EQ(*matrix, *matrix);
+ EXPECT_EQ(*matrix, *matrix_clone);
+ EXPECT_NE(*matrix, *matrix_different);
+ EXPECT_NE(*matrix, *vector_literal);
+ EXPECT_NE(*matrix, *scalar);
}
TEST_F(LiteralUtilTest, DifferentLayoutEquality) {
- // Test Literal::Equal with literals which have different layouts.
+ // Test equality with literals which have different layouts.
auto colmajor = MakeUnique<Literal>();
*colmajor->mutable_shape() = ShapeUtil::MakeShape(F32, {2, 2});
*colmajor->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
@@ -306,11 +306,11 @@ TEST_F(LiteralUtilTest, DifferentLayoutEquality) {
rowmajor->Set<float>({1, 0}, 3.0);
rowmajor->Set<float>({1, 1}, 4.0);
- EXPECT_TRUE(rowmajor->Equal(*colmajor));
+ EXPECT_EQ(*rowmajor, *colmajor);
}
TEST_F(LiteralUtilTest, TupleEquality) {
- // Test Literal::Equal with tuples.
+ // Test equality with tuples.
auto scalar = Literal::CreateR0<float>(1.0);
auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
auto tuple1 = Literal::MakeTuple({scalar.get(), matrix.get()});
@@ -319,16 +319,16 @@ TEST_F(LiteralUtilTest, TupleEquality) {
// tuple, the other is a clone of the element in the original tuple.
auto scalar_clone = Literal::CreateR0<float>(1.0);
auto tuple2 = Literal::MakeTuple({scalar_clone.get(), matrix.get()});
- EXPECT_TRUE(tuple1->Equal(*tuple2));
+ EXPECT_EQ(*tuple1, *tuple2);
// Tuple with elements reversed.
auto reversed_tuple = Literal::MakeTuple({matrix.get(), scalar.get()});
- EXPECT_FALSE(tuple1->Equal(*reversed_tuple));
+ EXPECT_NE(*tuple1, *reversed_tuple);
// Tuple with different value.
auto scalar_42 = Literal::CreateR0<float>(42.0);
auto different_tuple = Literal::MakeTuple({scalar_42.get(), matrix.get()});
- EXPECT_FALSE(tuple1->Equal(*different_tuple));
+ EXPECT_NE(*tuple1, *different_tuple);
}
TEST_F(LiteralUtilTest, IsAllTuple) {
@@ -348,7 +348,7 @@ TEST_F(LiteralUtilTest, CreateFromShapeTuple) {
auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()});
auto x = Literal::CreateFromShape(tuple->shape());
- EXPECT_TRUE(tuple->Equal(*x));
+ EXPECT_EQ(*tuple, *x);
}
TEST_F(LiteralUtilTest, IsAll) {
@@ -439,17 +439,17 @@ TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) {
auto data01 = data->Relayout(layout01);
EXPECT_TRUE(LayoutUtil::Equal(data01->shape().layout(), layout01));
- EXPECT_TRUE(data->Equal(*data01));
+ EXPECT_EQ(*data, *data01);
auto data10 = data->Relayout(layout10);
EXPECT_TRUE(LayoutUtil::Equal(data10->shape().layout(), layout10));
- EXPECT_TRUE(data->Equal(*data10));
+ EXPECT_EQ(*data, *data10);
}
TEST_F(LiteralUtilTest, ReshapeR0) {
auto original = Literal::CreateR0<float>(1.7f);
auto reshape = original->Reshape(/*shape=*/{}).ConsumeValueOrDie();
- EXPECT_TRUE(original->Equal(*reshape));
+ EXPECT_EQ(*original, *reshape);
}
TEST_F(LiteralUtilTest, ReshapeR4) {
@@ -469,7 +469,7 @@ TEST_F(LiteralUtilTest, ReshapeR4) {
// clang-format on
auto reshape = original->Reshape({3, 4, 2}).ConsumeValueOrDie();
- EXPECT_TRUE(expected->Equal(*reshape));
+ EXPECT_EQ(*expected, *reshape);
}
TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) {
@@ -489,13 +489,13 @@ TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) {
// clang-format on
auto reshape = original->Reshape({3, 4, 2}).ConsumeValueOrDie();
- EXPECT_TRUE(expected->Equal(*reshape));
+ EXPECT_EQ(*expected, *reshape);
}
TEST_F(LiteralUtilTest, TransposeR0) {
auto original = Literal::CreateR0<float>(1.7f);
auto reshape = original->Transpose(/*permutation=*/{});
- EXPECT_TRUE(original->Equal(*reshape));
+ EXPECT_EQ(*original, *reshape);
}
TEST_F(LiteralUtilTest, TransposeR4) {
@@ -521,13 +521,11 @@ TEST_F(LiteralUtilTest, TestR4RelayoutEquivalence) {
// target layout in the first place.
auto dim0minor_relaid_to_dim0major =
literal_r4_2x2x3x3_dim0minor_->Relayout(layout_r4_dim0major_);
- EXPECT_TRUE(
- literal_r4_2x2x3x3_dim0major_->Equal(*dim0minor_relaid_to_dim0major));
+ EXPECT_EQ(*literal_r4_2x2x3x3_dim0major_, *dim0minor_relaid_to_dim0major);
auto dim0major_relaid_to_dim0minor =
literal_r4_2x2x3x3_dim0major_->Relayout(layout_r4_dim0minor_);
- EXPECT_TRUE(
- literal_r4_2x2x3x3_dim0minor_->Equal(*dim0major_relaid_to_dim0minor));
+ EXPECT_EQ(*literal_r4_2x2x3x3_dim0minor_, *dim0major_relaid_to_dim0minor);
}
TEST_F(LiteralUtilTest, TestR2LinearLayout) {
@@ -596,14 +594,14 @@ TEST_F(LiteralUtilTest, TestR3LinearLayout) {
TEST_F(LiteralUtilTest, SliceR0S32) {
auto input = Literal::CreateR0<int32>(1);
auto result = input->Slice({}, {});
- EXPECT_TRUE(input->Equal(*result));
+ EXPECT_EQ(*input, *result);
}
TEST_F(LiteralUtilTest, SliceR1F32) {
auto input = Literal::CreateR1<float>({1.0, 2.0, 3.0, 4.0, 5.0});
auto result = input->Slice({3}, {4});
auto expected = Literal::CreateR1<float>({4.0});
- EXPECT_TRUE(expected->Equal(*result));
+ EXPECT_EQ(*expected, *result);
}
TEST_F(LiteralUtilTest, SliceR2U32) {
@@ -611,49 +609,49 @@ TEST_F(LiteralUtilTest, SliceR2U32) {
Literal::CreateR2<uint32>({{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
auto result = input_3x4->Slice({0, 2}, {2, 4});
auto expected = Literal::CreateR2<uint32>({{3, 4}, {7, 8}});
- EXPECT_TRUE(expected->Equal(*result));
+ EXPECT_EQ(*expected, *result);
}
TEST_F(LiteralUtilTest, SliceR3U32Full) {
auto input_2x3x2 = Literal::CreateR3<uint32>(
{{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}});
auto result = input_2x3x2->Slice({0, 0, 0}, {2, 3, 2});
- EXPECT_TRUE(input_2x3x2->Equal(*result));
+ EXPECT_EQ(*input_2x3x2, *result);
}
TEST_F(LiteralUtilTest, PopulateR1S64) {
Literal output;
output.PopulateR1<int64>({77});
auto expected = Literal::CreateR1<int64>({77});
- EXPECT_TRUE(output.Equal(*expected));
+ EXPECT_EQ(output, *expected);
}
TEST_F(LiteralUtilTest, PopulateR2U64) {
Literal output;
output.PopulateR1<uint64>({{77, 88}});
auto expected = Literal::CreateR1<uint64>({{77, 88}});
- EXPECT_TRUE(output.Equal(*expected));
+ EXPECT_EQ(output, *expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR0F32) {
Literal output;
output.PopulateWithValue<float>(2.5f, {});
auto expected = Literal::CreateR0<float>(2.5f);
- EXPECT_TRUE(output.Equal(*expected));
+ EXPECT_EQ(output, *expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR1S64) {
Literal output;
output.PopulateWithValue<int64>(-7, {3});
auto expected = Literal::CreateR1<int64>({-7, -7, -7});
- EXPECT_TRUE(output.Equal(*expected));
+ EXPECT_EQ(output, *expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR2U64) {
Literal output;
output.PopulateWithValue<uint64>(42, {2, 2});
auto expected = Literal::CreateR2<uint64>({{42, 42}, {42, 42}});
- EXPECT_TRUE(output.Equal(*expected));
+ EXPECT_EQ(output, *expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR0F16) {
@@ -661,7 +659,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR0F16) {
half h(0.25f);
output.PopulateWithValue<half>(h, {});
auto expected = Literal::CreateR0<half>(h);
- EXPECT_TRUE(output.Equal(*expected));
+ EXPECT_EQ(output, *expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR1F16) {
@@ -669,7 +667,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR1F16) {
half h(0.5f);
output.PopulateWithValue<half>(h, {3});
auto expected = Literal::CreateR1<half>({h, h, h});
- EXPECT_TRUE(output.Equal(*expected));
+ EXPECT_EQ(output, *expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR2F16) {
@@ -677,7 +675,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2F16) {
half h(2.0f);
output.PopulateWithValue<half>(h, {2, 2});
auto expected = Literal::CreateR2<half>({{h, h}, {h, h}});
- EXPECT_TRUE(output.Equal(*expected));
+ EXPECT_EQ(output, *expected);
}
TEST_F(LiteralUtilTest, ReplicateR2U32) {
@@ -688,7 +686,7 @@ TEST_F(LiteralUtilTest, ReplicateR2U32) {
{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}});
- EXPECT_TRUE(output->Equal(*expected));
+ EXPECT_EQ(*output, *expected);
}
TEST_F(LiteralUtilTest, Copy) {
@@ -741,7 +739,7 @@ TEST_F(LiteralUtilTest, CopyScalars) {
auto zero = Literal::CreateR0<uint32>(0);
auto nine = Literal::CreateR0<uint32>(9);
TF_EXPECT_OK(zero->Copy(*nine, {}, {}, {}));
- EXPECT_TRUE(zero->Equal(*nine));
+ EXPECT_EQ(*zero, *nine);
auto vect = Literal::CreateR1<uint32>({3, 4, 9, 12, 5, 17, 21});
TF_EXPECT_OK(zero->Copy(*vect, {5}, {}, {}));
@@ -761,7 +759,7 @@ TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) {
auto nine = Literal::CreateR1<float>({9});
TF_EXPECT_OK(nine->Copy(*empty, {0}, {0}, {0}));
- EXPECT_TRUE(nine->Equal(*const_nine));
+ EXPECT_EQ(*nine, *const_nine);
}
{
@@ -770,7 +768,7 @@ TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) {
auto nine = Literal::CreateR1<float>({9});
TF_EXPECT_OK(empty->Copy(*nine, {0}, {0}, {0}));
- EXPECT_TRUE(empty->Equal(*const_empty));
+ EXPECT_EQ(*empty, *const_empty);
}
}
@@ -863,7 +861,7 @@ TEST_F(LiteralUtilTest, ConvertR4) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> converted,
original->Convert(U32));
- EXPECT_TRUE(expected->Equal(*converted));
+ EXPECT_EQ(*expected, *converted);
}
TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
@@ -925,43 +923,43 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
std::unique_ptr<Literal> conv;
conv = s8->Convert(U32).ConsumeValueOrDie();
- EXPECT_TRUE(conv->Equal(*u32));
+ EXPECT_EQ(*conv, *u32);
conv = s8->Convert(S32).ConsumeValueOrDie();
- EXPECT_TRUE(conv->Equal(*s32));
+ EXPECT_EQ(*conv, *s32);
conv = s8->Convert(U64).ConsumeValueOrDie();
- EXPECT_TRUE(conv->Equal(*u64));
+ EXPECT_EQ(*conv, *u64);
conv = s8->Convert(S64).ConsumeValueOrDie();
- EXPECT_TRUE(conv->Equal(*s64));
+ EXPECT_EQ(*conv, *s64);
conv = s8->Convert(PRED).ConsumeValueOrDie();
- EXPECT_TRUE(conv->Equal(*pred));
+ EXPECT_EQ(*conv, *pred);
conv = pred->Convert(S32).ConsumeValueOrDie();
- EXPECT_TRUE(conv->Equal(*int32_pred));
+ EXPECT_EQ(*conv, *int32_pred);
conv = f32->Convert(S32).ConsumeValueOrDie();
- EXPECT_TRUE(conv->Equal(*s32));
+ EXPECT_EQ(*conv, *s32);
conv = f64->Convert(S32).ConsumeValueOrDie();
- EXPECT_TRUE(conv->Equal(*s32));
+ EXPECT_EQ(*conv, *s32);
conv = s32->Convert(F32).ConsumeValueOrDie();
- EXPECT_TRUE(conv->Equal(*f32));
+ EXPECT_EQ(*conv, *f32);
conv = f32->Convert(F16).ConsumeValueOrDie();
- EXPECT_TRUE(conv->Equal(*f16));
+ EXPECT_EQ(*conv, *f16);
conv = f64->Convert(F16).ConsumeValueOrDie();
- EXPECT_TRUE(conv->Equal(*f16));
+ EXPECT_EQ(*conv, *f16);
conv = s32->Convert(F16).ConsumeValueOrDie();
- EXPECT_TRUE(conv->Equal(*f16));
+ EXPECT_EQ(*conv, *f16);
conv = u32->Convert(F16).ConsumeValueOrDie();
- EXPECT_TRUE(conv->Equal(*f16));
+ EXPECT_EQ(*conv, *f16);
EXPECT_EQ(s32->Convert(TUPLE).status().code(),
tensorflow::error::INVALID_ARGUMENT);
@@ -1045,5 +1043,25 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) {
ASSERT_EQ(h1, r[3]);
}
+TEST_F(LiteralUtilTest, Subliterals) {
+ auto scalar = Literal::CreateR0<float>(1.0);
+ auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()});
+ auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()});
+
+ EXPECT_EQ(&scalar->GetSubliteral(/*index=*/{}), scalar.get());
+ EXPECT_EQ(&matrix->GetSubliteral(/*index=*/{}), matrix.get());
+ EXPECT_EQ(&tuple->GetSubliteral(/*index=*/{}), tuple.get());
+ EXPECT_EQ(&nested_tuple->GetSubliteral(/*index=*/{}), nested_tuple.get());
+
+ EXPECT_EQ(tuple->GetSubliteral(/*index=*/{0}), *scalar);
+ EXPECT_EQ(tuple->GetSubliteral(/*index=*/{1}), *matrix);
+
+ EXPECT_EQ(nested_tuple->GetSubliteral(/*index=*/{0}), *tuple);
+ EXPECT_EQ(nested_tuple->GetSubliteral(/*index=*/{0, 0}), *scalar);
+ EXPECT_EQ(nested_tuple->GetSubliteral(/*index=*/{0, 1}), *matrix);
+ EXPECT_EQ(nested_tuple->GetSubliteral(/*index=*/{1}), *scalar);
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 74f8e3143d..f7551bfb6c 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -1376,7 +1376,7 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
// try to get more fancy about proving equivalence in cases beyond that.
if (pad_value->opcode() != HloOpcode::kConstant ||
reduce_init_value->opcode() != HloOpcode::kConstant ||
- !pad_value->literal().Equal(reduce_init_value->literal())) {
+ pad_value->literal() != reduce_init_value->literal()) {
VLOG(10) << "Not folding pad into reduce-window due to different pad "
"values.";
return Status::OK();
diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc
index 0fef89a06d..cdccacdd2d 100644
--- a/tensorflow/compiler/xla/service/hlo_cse.cc
+++ b/tensorflow/compiler/xla/service/hlo_cse.cc
@@ -68,7 +68,7 @@ bool CombineConstants(HloComputation* computation, bool is_layout_sensitive) {
auto range = constants.equal_range(shape_string);
HloInstruction* match = nullptr;
for (auto it = range.first; it != range.second; ++it) {
- if (instruction->literal().Equal(it->second->literal())) {
+ if (instruction->literal() == it->second->literal()) {
match = it->second;
break;
}
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 3366b83fdd..a883429341 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -1293,7 +1293,7 @@ bool HloInstruction::IdenticalSlowPath(
// A constant is defined by the value in the literal.
case HloOpcode::kConstant:
- return literal().Equal(other.literal());
+ return literal() == other.literal();
// A convert result is determined by the primitive type that the operand is
// converted into.
diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc
index 1eb4edbe3e..3235081f83 100644
--- a/tensorflow/compiler/xla/service/local_service.cc
+++ b/tensorflow/compiler/xla/service/local_service.cc
@@ -88,21 +88,6 @@ int64 RequiredSpace(const Shape& shape, bool allocate_space_for_deep_copy,
}
} // namespace
-StatusOr<GlobalDataHandle> LocalService::AllocateBufferOnDevice(
- const Shape& shape, int device_ordinal, bool allocate_space_for_deep_copy) {
- int64 allocation_size = RequiredSpace(shape, allocate_space_for_deep_copy,
- execute_backend_->transfer_manager());
-
- TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase allocation,
- execute_backend_->memory_allocator()->Allocate(
- device_ordinal, allocation_size));
-
- return allocation_tracker_.Register(
- execute_backend_.get(), device_ordinal, allocation, shape,
- tensorflow::strings::StrCat("AllocateBufferOnDevice of size ",
- allocation_size));
-}
-
StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
const ComputationHandle& computation,
const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h
index c90943f3c0..f2bfb960f4 100644
--- a/tensorflow/compiler/xla/service/local_service.h
+++ b/tensorflow/compiler/xla/service/local_service.h
@@ -39,14 +39,6 @@ class LocalService : public Service {
static StatusOr<std::unique_ptr<LocalService>> NewService(
const ServiceOptions& options);
- // Return a handle to a buffer large enough to hold shape, allocated
- // on device_ordinal. If allocate_space_for_deep_copy, the buffer is
- // large enough to hold all sub-buffers of a tuple shape, otherwise
- // it is only as large as the top-level tuple pointer array.
- StatusOr<GlobalDataHandle> AllocateBufferOnDevice(
- const Shape& shape, int device_ordinal,
- bool allocate_space_for_deep_copy);
-
// Builds an Executable with the given argument layouts and options. If
// result_layout is non-null, then the executable is compiled to produce a
// result of the given layout.
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 33537877ea..e45b839afd 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -1439,6 +1439,54 @@ tf_cc_test(
],
)
+xla_test(
+ name = "local_client_allocation_test",
+ srcs = ["local_client_allocation_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/service:local_service",
+ "//tensorflow/compiler/xla/service:shaped_buffer",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/compiler/xla/tests:local_client_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "local_client_execute_test",
+ srcs = ["local_client_execute_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:client_library",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/service:device_memory_allocator",
+ "//tensorflow/compiler/xla/service:local_service",
+ "//tensorflow/compiler/xla/service:platform_util",
+ "//tensorflow/compiler/xla/service:shaped_buffer",
+ "//tensorflow/compiler/xla/service:transfer_manager",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/compiler/xla/tests:local_client_test_base",
+ "//tensorflow/compiler/xla/tests:test_utils",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_no_cuda",
+ "//tensorflow/core:test",
+ ],
+)
+
tf_cc_test(
name = "hlo_metadata_test",
srcs = [
diff --git a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc
new file mode 100644
index 0000000000..6897f0291a
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc
@@ -0,0 +1,105 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/local_service.h"
+#include "tensorflow/compiler/xla/service/shaped_buffer.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/local_client_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class LocalClientAllocationTest : public LocalClientTestBase {
+ protected:
+ ErrorSpec error_spec_{0.0001};
+};
+
+XLA_TEST_F(LocalClientAllocationTest, AddVectors) {
+ ComputationBuilder builder(local_client_, TestName());
+ auto x = builder.ConstantR1<float>({0.0f, 1.0f, 2.0f});
+ auto y = builder.ConstantR1<float>({2.0f, 3.0f, 4.0f});
+ builder.Add(x, y);
+
+ TestAllocator* allocator = GetOrCreateAllocator(local_client_->platform());
+
+ auto x_array = LiteralToScopedShapedBuffer(
+ *Literal::CreateR1<float>({0.0f, 1.0f, 2.0f}));
+
+ int64 allocation_count_before = allocator_->allocation_count();
+
+ // Override the allocator via 'options'. Tests that allocation and
+ // deallocation happen on the right allocator.
+ ExecutableRunOptions options;
+ options.set_allocator(allocator);
+ std::unique_ptr<ScopedShapedBuffer> result =
+ ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {},
+ DefaultExecutableBuildOptions(), options);
+
+ LiteralTestUtil::ExpectR1Near<float>(
+ {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(*result), error_spec_);
+
+ // At least one allocation should have been performed when executing the
+ // computation.
+ EXPECT_GT(allocator_->allocation_count(), allocation_count_before);
+
+ // Deallocate result and verify that deallocate was called once.
+ int64 deallocation_count_before = allocator_->deallocation_count();
+ result = nullptr;
+ EXPECT_EQ(deallocation_count_before + 1, allocator_->deallocation_count());
+}
+
+XLA_TEST_F(LocalClientAllocationTest, RunOnDevices) {
+ // Run a computation on every device on the system. Verify that allocation
+ // occurs on the proper device.
+ ComputationBuilder builder(local_client_, TestName());
+ auto x = builder.ConstantR1<float>({0.0f, 1.0f, 2.0f});
+ auto y = builder.ConstantR1<float>({2.0f, 3.0f, 4.0f});
+ builder.Add(x, y);
+ auto computation = builder.Build().ConsumeValueOrDie();
+
+ TestAllocator* allocator = GetOrCreateAllocator(local_client_->platform());
+ for (int d = 0; d < local_client_->device_count(); ++d) {
+ if (!local_client_->device_ordinal_supported(d)) {
+ continue;
+ }
+
+ int64 device_allocation_count_before = allocator->allocation_count(d);
+ int64 allocation_count_before = allocator->allocation_count();
+
+ auto result = ExecuteLocallyOrDie(
+ computation, {}, ExecutableBuildOptions().set_device_ordinal(d),
+ ExecutableRunOptions().set_device_ordinal(d).set_allocator(allocator));
+ LiteralTestUtil::ExpectR1Near<float>(
+ {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(*result), error_spec_);
+
+ // At least one allocation should have been performed when executing the
+ // computation.
+ EXPECT_GT(allocator->allocation_count(), allocation_count_before);
+ EXPECT_GT(allocator->allocation_count(d), device_allocation_count_before);
+ }
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc
new file mode 100644
index 0000000000..ef2592e292
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc
@@ -0,0 +1,618 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <initializer_list>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
+#include "tensorflow/compiler/xla/service/local_service.h"
+#include "tensorflow/compiler/xla/service/platform_util.h"
+#include "tensorflow/compiler/xla/service/shaped_buffer.h"
+#include "tensorflow/compiler/xla/service/transfer_manager.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/local_client_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace se = ::perftools::gputools;
+
+namespace xla {
+namespace {
+
+using ::testing::ContainsRegex;
+
+class LocalClientExecuteTest : public LocalClientTestBase {
+ protected:
+ ErrorSpec error_spec_{0.0001};
+};
+
+XLA_TEST_F(LocalClientExecuteTest, Constant) {
+ ComputationBuilder builder(local_client_, TestName());
+ auto y = builder.ConstantR0<float>(123.0f);
+
+ std::unique_ptr<ScopedShapedBuffer> result =
+ ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {});
+
+ LiteralTestUtil::ExpectR0Near<float>(123.f, *ShapedBufferToLiteral(*result),
+ error_spec_);
+}
+
+XLA_TEST_F(LocalClientExecuteTest, AddScalars) {
+ ComputationBuilder builder(local_client_, TestName());
+ auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto y = builder.ConstantR0<float>(123.0f);
+ builder.Add(x, y);
+
+ auto x_value = LiteralToScopedShapedBuffer(*Literal::CreateR0<float>(42.0f));
+ std::unique_ptr<ScopedShapedBuffer> result =
+ ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {x_value.get()});
+
+ LiteralTestUtil::ExpectR0Near<float>(165.f, *ShapedBufferToLiteral(*result),
+ error_spec_);
+}
+
+XLA_TEST_F(LocalClientExecuteTest, AddZeroElementVectors) {
+ ComputationBuilder builder(local_client_, TestName());
+ auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {0}), "x");
+ auto y = builder.ConstantR1<float>({});
+ builder.Add(x, y);
+
+ auto x_array = LiteralToScopedShapedBuffer(*Literal::CreateR1<float>({}));
+ std::unique_ptr<ScopedShapedBuffer> result =
+ ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {x_array.get()});
+
+ LiteralTestUtil::ExpectR1Near<float>({}, *ShapedBufferToLiteral(*result),
+ error_spec_);
+}
+
+XLA_TEST_F(LocalClientExecuteTest, AddVectors) {
+ ComputationBuilder builder(local_client_, TestName());
+ auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3}), "x");
+ auto y = builder.ConstantR1<float>({2.0f, 3.0f, 4.0f});
+ builder.Add(x, y);
+
+ auto x_array = LiteralToScopedShapedBuffer(
+ *Literal::CreateR1<float>({0.0f, 1.0f, 2.0f}));
+ std::unique_ptr<ScopedShapedBuffer> result =
+ ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {x_array.get()});
+
+ LiteralTestUtil::ExpectR1Near<float>(
+ {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(*result), error_spec_);
+}
+
+XLA_TEST_F(LocalClientExecuteTest, AddVectorsWithProfile) {
+ ComputationBuilder builder(local_client_, TestName());
+ auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3}), "x");
+ auto y = builder.ConstantR1<float>({2.0f, 3.0f, 4.0f});
+ builder.Add(x, y);
+
+ auto x_array = LiteralToScopedShapedBuffer(
+ *Literal::CreateR1<float>({0.0f, 1.0f, 2.0f}));
+ ExecutionProfile profile;
+ std::unique_ptr<ScopedShapedBuffer> result = ExecuteLocallyOrDie(
+ builder.Build().ValueOrDie(), {x_array.get()},
+ DefaultExecutableBuildOptions(),
+ DefaultExecutableRunOptions().set_execution_profile(&profile));
+
+ LiteralTestUtil::ExpectR1Near<float>(
+ {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(*result), error_spec_);
+ EXPECT_GT(profile.compute_and_transfer_time_ns(), 0);
+}
+
+XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) {
+ ComputationBuilder builder(local_client_, TestName());
+ auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x");
+ auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2}), "y");
+ builder.Add(x, y);
+ auto computation = builder.Build().ConsumeValueOrDie();
+
+ // Create x as a col-major array.
+ auto x_array = LiteralToScopedShapedBuffer(
+ *test_utils::CreateR2LiteralWithLayout({{1.0f, 2.0f}, {3.0f, 4.0f}},
+ /*minor_to_major=*/{0, 1}));
+ EXPECT_TRUE(LayoutUtil::Equal(x_array->shape().layout(),
+ LayoutUtil::MakeLayout({0, 1})));
+
+ // Create y as a row-major array.
+ auto y_array = LiteralToScopedShapedBuffer(
+ *test_utils::CreateR2LiteralWithLayout({{10.0f, 20.0f}, {30.0f, 40.0f}},
+ /*minor_to_major=*/{1, 0}));
+ EXPECT_TRUE(LayoutUtil::Equal(y_array->shape().layout(),
+ LayoutUtil::MakeLayout({1, 0})));
+
+ std::unique_ptr<ScopedShapedBuffer> result_colmaj =
+ ExecuteLocallyOrDie(computation, {x_array.get(), y_array.get()});
+ LiteralTestUtil::ExpectR2Near<float>({{11.0f, 22.0f}, {33.0f, 44.0f}},
+ *ShapedBufferToLiteral(*result_colmaj),
+ error_spec_);
+
+ // Run with the parameter values in a different order.
+ std::unique_ptr<ScopedShapedBuffer> result_param_swap =
+ ExecuteLocallyOrDie(computation, {y_array.get(), x_array.get()});
+ LiteralTestUtil::ExpectR2Near<float>(
+ {{11.0f, 22.0f}, {33.0f, 44.0f}},
+ *ShapedBufferToLiteral(*result_param_swap), error_spec_);
+}
+
+XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) {
+ ComputationBuilder builder(local_client_, TestName());
+ auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x");
+ auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2}), "y");
+ builder.Add(x, y);
+ auto computation = builder.Build().ConsumeValueOrDie();
+
+ auto x_array = LiteralToScopedShapedBuffer(
+ *Literal::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
+ auto y_array = LiteralToScopedShapedBuffer(
+ *Literal::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));
+
+ // Run with col-major result layout.
+ std::unique_ptr<ScopedShapedBuffer> result_colmaj = ExecuteLocallyOrDie(
+ computation, {x_array.get(), y_array.get()},
+ DefaultExecutableBuildOptions().set_result_layout(
+ ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2}, {0, 1})),
+ DefaultExecutableRunOptions());
+ EXPECT_TRUE(LayoutUtil::Equal(result_colmaj->shape().layout(),
+ LayoutUtil::MakeLayout({0, 1})));
+ LiteralTestUtil::ExpectR2Near<float>({{11.0f, 22.0f}, {33.0f, 44.0f}},
+ *ShapedBufferToLiteral(*result_colmaj),
+ error_spec_);
+
+ // Run with row-major result layout.
+ std::unique_ptr<ScopedShapedBuffer> result_rowmaj = ExecuteLocallyOrDie(
+ computation, {x_array.get(), y_array.get()},
+ DefaultExecutableBuildOptions().set_result_layout(
+ ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2}, {1, 0})),
+ DefaultExecutableRunOptions());
+ EXPECT_TRUE(LayoutUtil::Equal(result_rowmaj->shape().layout(),
+ LayoutUtil::MakeLayout({1, 0})));
+ LiteralTestUtil::ExpectR2Near<float>({{11.0f, 22.0f}, {33.0f, 44.0f}},
+ *ShapedBufferToLiteral(*result_rowmaj),
+ error_spec_);
+}
+
+XLA_TEST_F(LocalClientExecuteTest, TupleResult) {
+ ComputationBuilder builder(local_client_, TestName());
+ auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x");
+ auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2}), "y");
+ builder.Tuple({x, y, x});
+ auto computation = builder.Build().ConsumeValueOrDie();
+
+ auto x_array = LiteralToScopedShapedBuffer(
+ *Literal::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
+ auto y_array = LiteralToScopedShapedBuffer(
+ *Literal::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));
+
+ std::unique_ptr<ScopedShapedBuffer> result =
+ ExecuteLocallyOrDie(computation, {x_array.get(), y_array.get()});
+
+ EXPECT_TRUE(ShapeUtil::IsTuple(result->shape()));
+ EXPECT_EQ(3, ShapeUtil::TupleElementCount(result->shape()));
+
+ std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(*result);
+ LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
+ result_literal->tuple_literals(0));
+ LiteralTestUtil::ExpectR2Equal<float>({{10.0f, 20.0f}, {30.0f, 40.0f}},
+ result_literal->tuple_literals(1));
+ LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
+ result_literal->tuple_literals(2));
+}
+
+XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) {
+ ComputationBuilder builder(local_client_, TestName());
+ auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x");
+ auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2}), "y");
+ auto inner_tuple = builder.Tuple({x, y, x});
+ builder.Tuple({inner_tuple, x});
+ auto computation = builder.Build().ConsumeValueOrDie();
+
+ auto x_array = LiteralToScopedShapedBuffer(
+ *Literal::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
+ auto y_array = LiteralToScopedShapedBuffer(
+ *Literal::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));
+
+ std::unique_ptr<ScopedShapedBuffer> result =
+ ExecuteLocallyOrDie(computation, {x_array.get(), y_array.get()});
+
+ EXPECT_TRUE(ShapeUtil::IsTuple(result->shape()));
+ EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->shape()));
+
+ std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(*result);
+ LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
+ result_literal->tuple_literals(1));
+ const Literal& inner_tuple_literal = result_literal->tuple_literals(0);
+ LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
+ inner_tuple_literal.tuple_literals(0));
+ LiteralTestUtil::ExpectR2Equal<float>({{10.0f, 20.0f}, {30.0f, 40.0f}},
+ inner_tuple_literal.tuple_literals(1));
+ LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
+ inner_tuple_literal.tuple_literals(2));
+}
+
+XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) {
+ // Verify setting the result layout of a computation with a tuple output.
+ ComputationBuilder builder(local_client_, TestName());
+ auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x");
+ auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2}), "y");
+ builder.Tuple({x, y});
+
+ auto array = LiteralToScopedShapedBuffer(
+ *Literal::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
+
+ ExecutableBuildOptions options = DefaultExecutableBuildOptions();
+ Shape shape_with_layout = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2},
+ /*minor_to_major=*/{0, 1}),
+ ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2},
+ /*minor_to_major=*/{1, 0})});
+ options.set_result_layout(shape_with_layout);
+ std::unique_ptr<ScopedShapedBuffer> result = ExecuteLocallyOrDie(
+ builder.Build().ValueOrDie(), {array.get(), array.get()}, options,
+ DefaultExecutableRunOptions());
+
+ std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(*result);
+ LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
+ result_literal->tuple_literals(0));
+ LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
+ result_literal->tuple_literals(1));
+}
+
+XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) {
+ // Test passing in an invalid number of arguments.
+ ComputationBuilder builder(local_client_, TestName());
+ auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3}), "x");
+ auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {3}), "y");
+ builder.Add(x, y);
+
+ auto x_array = LiteralToScopedShapedBuffer(
+ *Literal::CreateR1<float>({1.0f, 2.0f, 3.0f}));
+ auto execute_status =
+ ExecuteLocally(builder.Build().ValueOrDie(), {x_array.get()});
+
+ EXPECT_FALSE(execute_status.ok());
+ EXPECT_THAT(execute_status.status().error_message(),
+ ContainsRegex("invalid number of arguments"));
+}
+
+XLA_TEST_F(LocalClientExecuteTest, IncorrectArgumentShape) {
+ // Test passing in an argument with the wrong shape.
+ ComputationBuilder builder(local_client_, TestName());
+ auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3}), "x");
+ builder.Neg(x);
+
+ auto x_array = LiteralToScopedShapedBuffer(
+ *Literal::CreateR2<float>({{0.0f, 1.0f}, {2.0f, 3.0f}}));
+ auto execute_status =
+ ExecuteLocally(builder.Build().ValueOrDie(), {x_array.get()});
+
+ EXPECT_FALSE(execute_status.ok());
+ EXPECT_THAT(execute_status.status().error_message(),
+ ContainsRegex("invalid argument shape"))
+ << execute_status.status();
+}
+
+XLA_TEST_F(LocalClientExecuteTest, InvalidResultLayout) {
+ // Test passing in an invalid result layout parameter.
+ ComputationBuilder builder(local_client_, TestName());
+ auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x");
+ builder.Neg(x);
+
+ auto x_array = LiteralToScopedShapedBuffer(
+ *Literal::CreateR2<float>({{0.0f, 1.0f}, {2.0f, 3.0f}}));
+ auto execute_status = ExecuteLocally(
+ builder.Build().ValueOrDie(), {x_array.get()},
+ DefaultExecutableBuildOptions().set_result_layout(
+ ShapeUtil::MakeShapeWithLayout(F32,
+ /*dimensions=*/{1, 2, 3, 4},
+ /*minor_to_major=*/{0, 1, 2, 3})),
+ DefaultExecutableRunOptions());
+
+ EXPECT_FALSE(execute_status.ok());
+ EXPECT_THAT(execute_status.status().error_message(),
+ ContainsRegex("not compatible with result shape"))
+ << execute_status.status();
+}
+
+XLA_TEST_F(LocalClientExecuteTest, RunOnAllDeviceOrdinals) {
+ // Try to run a trivial computation on every device on the system. If a
+ // specific device is not supported, check that the right error is returned.
+ ComputationBuilder builder(local_client_, TestName());
+ builder.ConstantR0<float>(42.0f);
+ auto computation = builder.Build().ConsumeValueOrDie();
+ for (int d = 0; d < local_client_->device_count(); ++d) {
+ if (!local_client_->device_ordinal_supported(d)) {
+ auto execute_status =
+ ExecuteLocally(computation, {},
+ DefaultExecutableBuildOptions().set_device_ordinal(d),
+ DefaultExecutableRunOptions().set_device_ordinal(d));
+ EXPECT_FALSE(execute_status.ok());
+ EXPECT_THAT(execute_status.status().error_message(),
+ ContainsRegex("device .* not supported"));
+ } else {
+ auto result = ExecuteLocallyOrDie(
+ computation, {},
+ DefaultExecutableBuildOptions().set_device_ordinal(d),
+ DefaultExecutableRunOptions().set_device_ordinal(d));
+ EXPECT_EQ(d, result->device_ordinal());
+ LiteralTestUtil::ExpectR0Equal<float>(42.0f,
+ *ShapedBufferToLiteral(*result));
+ }
+ }
+}
+
+XLA_TEST_F(LocalClientExecuteTest, InvalidDeviceOrdinalValues) {
+ // Try running computations on devices with device ordinal values which do not
+ // exist.
+ ComputationBuilder builder(local_client_, TestName());
+ builder.ConstantR0<float>(42.0f);
+ auto computation = builder.Build().ConsumeValueOrDie();
+
+ auto execute_status =
+ ExecuteLocally(computation, {},
+ DefaultExecutableBuildOptions().set_device_ordinal(
+ local_client_->device_count()),
+ DefaultExecutableRunOptions().set_device_ordinal(
+ local_client_->device_count()));
+ EXPECT_FALSE(execute_status.ok());
+ EXPECT_THAT(execute_status.status().error_message(),
+ ContainsRegex("Invalid device ordinal value"));
+}
+
+XLA_TEST_F(LocalClientExecuteTest, RunOnStream) {
+ // Run a computation on a specific stream on each device on the system.
+ ComputationBuilder builder(local_client_, TestName());
+ builder.ConstantR0<float>(42.0f);
+ auto computation = builder.Build().ConsumeValueOrDie();
+
+ for (int d = 0; d < local_client_->device_count(); ++d) {
+ if (!local_client_->device_ordinal_supported(d)) {
+ continue;
+ }
+ se::StreamExecutor* executor =
+ local_client_->platform()->ExecutorForDevice(d).ValueOrDie();
+ se::Stream stream(executor);
+ stream.Init();
+
+ auto result =
+ ExecuteLocallyOrDie(computation, {}, DefaultExecutableBuildOptions(),
+ DefaultExecutableRunOptions().set_stream(&stream));
+ // As a check, verify that the computation ran on the device associated
+ // with the stream. This is a weak check, but stronger verification is hard.
+ EXPECT_EQ(d, result->device_ordinal());
+ LiteralTestUtil::ExpectR0Equal<float>(42.0f,
+ *ShapedBufferToLiteral(*result));
+ }
+}
+
+// Disable this test on CPU because we're using the CPU as the platform
+// which does not match the service platform.
+XLA_TEST_F(LocalClientExecuteTest,
+ DISABLED_ON_CPU(RunOnStreamForWrongPlatform)) {
+ // Try to run a computation on a stream for a platform (CPU) which does not
+ // match the platform of the service (!= CPU).
+ se::Platform* wrong_platform =
+ se::MultiPlatformManager::PlatformWithId(se::host::kHostPlatformId)
+ .ValueOrDie();
+ se::Stream wrong_stream(wrong_platform->ExecutorForDevice(0).ValueOrDie());
+ wrong_stream.Init();
+
+ ComputationBuilder builder(local_client_, TestName());
+ builder.ConstantR0<float>(42.0f);
+ auto execute_status = ExecuteLocally(
+ builder.Build().ValueOrDie(), {}, DefaultExecutableBuildOptions(),
+ DefaultExecutableRunOptions().set_stream(&wrong_stream));
+ EXPECT_FALSE(execute_status.ok());
+ EXPECT_THAT(execute_status.status().error_message(),
+ ContainsRegex("stream is for platform .*, but service targets"));
+}
+
+XLA_TEST_F(LocalClientExecuteTest,
+ DISABLED_ON_CPU(AllocatorDoesNotMatchPlatform)) {
+ se::Platform* wrong_platform =
+ se::MultiPlatformManager::PlatformWithId(se::host::kHostPlatformId)
+ .ValueOrDie();
+ TestAllocator allocator(wrong_platform);
+
+ ComputationBuilder builder(local_client_, TestName());
+ auto y = builder.ConstantR0<float>(123.0f);
+
+ auto execute_status = ExecuteLocally(
+ builder.Build().ValueOrDie(), {}, DefaultExecutableBuildOptions(),
+ DefaultExecutableRunOptions().set_allocator(&allocator));
+ EXPECT_FALSE(execute_status.ok());
+ EXPECT_THAT(execute_status.status().error_message(),
+ ContainsRegex("allocator platform .* does not match service"));
+}
+
+XLA_TEST_F(LocalClientExecuteTest, RunOnUninitializedStream) {
+ // Try to run a computation on a stream that has not been initialized.
+ ComputationBuilder builder(local_client_, TestName());
+ builder.ConstantR0<float>(42.0f);
+
+ LOG(INFO) << "default device = " << local_client_->default_device_ordinal();
+ se::StreamExecutor* executor =
+ local_client_->platform()
+ ->ExecutorForDevice(local_client_->default_device_ordinal())
+ .ValueOrDie();
+ se::Stream stream(executor);
+ // Don't call stream.Init().
+
+ auto execute_status = ExecuteLocally(
+ builder.Build().ValueOrDie(), {}, DefaultExecutableBuildOptions(),
+ DefaultExecutableRunOptions().set_stream(&stream));
+ EXPECT_FALSE(execute_status.ok());
+ EXPECT_THAT(execute_status.status().error_message(),
+ ContainsRegex("stream is uninitialized or in an error state"));
+}
+
+XLA_TEST_F(LocalClientExecuteTest, SelectBetweenTuples) {
+ ComputationBuilder builder(local_client_, TestName());
+
+ std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
+ std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};
+ auto tuple12 = builder.Tuple(
+ {builder.ConstantR1<float>(vec1), builder.ConstantR1<float>(vec2)});
+ auto tuple21 = builder.Tuple(
+ {builder.ConstantR1<float>(vec2), builder.ConstantR1<float>(vec1)});
+ builder.Select(builder.ConstantR0<bool>(false), tuple12, tuple21);
+
+ std::unique_ptr<ScopedShapedBuffer> result =
+ ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {});
+ std::unique_ptr<Literal> tuple_literal = ShapedBufferToLiteral(*result);
+ LiteralTestUtil::ExpectR1Equal<float>({2.0f, 4.0f, 6.0f},
+ tuple_literal->tuple_literals(0));
+ LiteralTestUtil::ExpectR1Equal<float>({1.0f, 2.0f, 3.0f},
+ tuple_literal->tuple_literals(1));
+}
+
+XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) {
+ ComputationBuilder builder(local_client_, TestName());
+ auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3}), "x");
+ auto y = builder.ConstantR1<float>({2.0f, 3.0f, 4.0f});
+ builder.Add(x, y);
+
+ Shape argument_layout =
+ ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{3}, {0});
+ auto executable_status =
+ local_client_->Compile(builder.Build().ValueOrDie(), {&argument_layout},
+ ExecutableBuildOptions());
+ ASSERT_IS_OK(executable_status);
+ std::unique_ptr<LocalExecutable> executable =
+ executable_status.ConsumeValueOrDie();
+
+ auto x_array = LiteralToScopedShapedBuffer(
+ *Literal::CreateR1<float>({0.0f, 1.0f, 2.0f}));
+ std::unique_ptr<ScopedShapedBuffer> result = ShapedBufferToScopedShapedBuffer(
+ executable->Run({x_array.get()}, DefaultExecutableRunOptions())
+ .ConsumeValueOrDie(),
+ allocator_);
+
+ LiteralTestUtil::ExpectR1Near<float>(
+ {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(*result), error_spec_);
+}
+
+XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion) {
+ // Test copying Literals to the device as ShapedBuffers, then copying them
+ // back again to Literals.
+ auto test_to_device_and_back = [this](const Literal& literal) {
+ TF_ASSERT_OK_AND_ASSIGN(
+ auto shaped_buffer,
+ local_client_->LiteralToShapedBuffer(
+ literal, allocator_, local_client_->default_device_ordinal()));
+ TF_ASSERT_OK_AND_ASSIGN(
+ auto transferred_literal,
+ local_client_->ShapedBufferToLiteral(*shaped_buffer));
+ EXPECT_EQ(literal, *transferred_literal);
+ };
+
+ // Array shapes.
+ test_to_device_and_back(*Literal::CreateR0<float>(42.0));
+ test_to_device_and_back(*Literal::CreateR0<bool>(true));
+ test_to_device_and_back(*Literal::CreateR1<float>({1.0, 42.0, 744.4}));
+ test_to_device_and_back(
+ *Literal::CreateR2<double>({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}}));
+ test_to_device_and_back(*Literal::CreateR2<int32>({{2, 1}, {4444, 56}}));
+
+ // Null shape (empty tuple).
+ test_to_device_and_back(*Literal::MakeTuple({}));
+
+ // Non-nested tuples.
+ test_to_device_and_back(
+ *Literal::MakeTuple({Literal::CreateR0<float>(12223.0).get()}));
+ test_to_device_and_back(
+ *Literal::MakeTuple({Literal::CreateR1<float>({1.0, -42.0}).get(),
+ Literal::CreateR0<float>(123456.0).get()}));
+
+ // Nested tuple.
+ test_to_device_and_back(*Literal::MakeTuple(
+ {Literal::MakeTuple({Literal::CreateR1<float>({1.0, -42.0}).get(),
+ Literal::CreateR0<float>(123456.0).get()})
+ .get(),
+ Literal::CreateR0<bool>(false).get()}));
+}
+
+// Benchmark that measures the overhead of the LocalClient API when running a
+// trivial computation
+void BM_LocalClientOverhead(int num_iters) {
+ tensorflow::testing::StopTiming();
+
+ se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie();
+ auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie();
+ StreamExecutorMemoryAllocator allocator(platform, executors);
+ LocalClient* client =
+ ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie();
+ auto* transfer_manager =
+ TransferManager::GetForPlatform(platform).ValueOrDie();
+ int device_ordinal = client->default_device_ordinal();
+
+ // Use a tiny add operation as the computation.
+ ComputationBuilder builder(client, "Add");
+ auto shape = ShapeUtil::MakeShape(F32, {2, 3});
+ auto x = builder.Parameter(0, shape, "x");
+ builder.Add(x, x);
+ auto computation = builder.Build().ConsumeValueOrDie();
+
+ auto buffer = ScopedShapedBuffer::MakeScopedShapedBuffer(shape, &allocator, 0)
+ .ConsumeValueOrDie();
+ auto literal = Literal::CreateR2<float>({{0, 0, 0}, {0, 0, 0}});
+ ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice(
+ executors[device_ordinal], *literal, buffer->mutable_buffer({})));
+
+ const int kWarmups = 2;
+
+ auto executable_status = client->Compile(computation, {&buffer->shape()},
+ ExecutableBuildOptions());
+ ASSERT_IS_OK(executable_status);
+ std::unique_ptr<LocalExecutable> executable =
+ executable_status.ConsumeValueOrDie();
+
+ se::Stream stream(executors[client->default_device_ordinal()]);
+ stream.Init();
+
+ ExecutableRunOptions run_options;
+ run_options.set_allocator(&allocator).set_stream(&stream);
+
+ for (int i = 0; i < kWarmups; ++i) {
+ auto result = executable->Run({buffer.get()}, run_options);
+ ASSERT_IS_OK(result);
+ }
+
+ tensorflow::testing::StartTiming();
+ for (int i = 0; i < num_iters; ++i) {
+ auto result = executable->Run({buffer.get()}, run_options);
+ ASSERT_IS_OK(result);
+ }
+}
+
+BENCHMARK(BM_LocalClientOverhead);
+
+} // namespace
+} // namespace xla