diff options
author | 2017-09-21 15:05:33 -0700 | |
---|---|---|
committer | 2017-09-21 15:08:28 -0700 | |
commit | 36647440d2e62cb494e4e6f6d5d9144ceb0b29c7 (patch) | |
tree | d97c715ef6c79b205442f254679b1ffa03be94e4 /tensorflow/compiler | |
parent | 57498a86c11dfc98dda84dc7318a3c84c85c6791 (diff) |
Add methods to convert between Literals and ShapedBuffers to LocalClient. These "conversion" methods copy the data to/from the device into/from literals.
Also fix various issues I noticed along the way:
* Move LocalClient tests into open source.
* Add proper == operator to Literals
* Add << overload for streaming Literals to output.
* Add Literal::GetSubliteral methods.
* Remove unused AllocateBufferOnDevice methods from LocalClient and LocalService.
PiperOrigin-RevId: 169606342
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r-- | tensorflow/compiler/xla/client/local_client.cc | 66 | ||||
-rw-r--r-- | tensorflow/compiler/xla/client/local_client.h | 21 | ||||
-rw-r--r-- | tensorflow/compiler/xla/literal_util.cc | 43 | ||||
-rw-r--r-- | tensorflow/compiler/xla/literal_util.h | 16 | ||||
-rw-r--r-- | tensorflow/compiler/xla/literal_util_test.cc | 134 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/algebraic_simplifier.cc | 2 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_cse.cc | 2 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.cc | 2 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/local_service.cc | 15 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/local_service.h | 8 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/BUILD | 48 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/local_client_allocation_test.cc | 105 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/local_client_execute_test.cc | 618 |
13 files changed, 959 insertions, 121 deletions
diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index b36933436c..a0fc230319 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -166,7 +166,7 @@ tensorflow::Status LocalExecutable::ValidateExecutionOptions( backend.platform()->Name().c_str()); } - return tensorflow::Status::OK(); + return Status::OK(); } StatusOr<std::unique_ptr<ShapedBuffer>> LocalExecutable::Run( @@ -225,7 +225,7 @@ tensorflow::Status LocalExecutable::RecordArguments( TF_RETURN_IF_ERROR(LiteralFromShapedBuffer(*argument, &literal)); *session_module->add_arguments() = literal.ToProto(); } - return tensorflow::Status::OK(); + return Status::OK(); } tensorflow::Status LocalExecutable::RecordResult( @@ -234,7 +234,7 @@ tensorflow::Status LocalExecutable::RecordResult( Literal literal(session_module->result()); TF_RETURN_IF_ERROR(LiteralFromShapedBuffer(*result, &literal)); *session_module->mutable_result() = literal.ToProto(); - return tensorflow::Status::OK(); + return Status::OK(); } // TODO(dnovillo) Change signature to return StatusOr<Literal>. @@ -248,14 +248,6 @@ tensorflow::Status LocalExecutable::LiteralFromShapedBuffer( shaped_buffer.shape(), literal); } -StatusOr<std::unique_ptr<GlobalData>> LocalClient::AllocateBufferOnDevice( - const Shape& shape, int device_ordinal, bool allocate_space_for_deep_copy) { - TF_ASSIGN_OR_RETURN(GlobalDataHandle handle, - local_service_->AllocateBufferOnDevice( - shape, device_ordinal, allocate_space_for_deep_copy)); - return std::unique_ptr<GlobalData>(new GlobalData(local_service_, handle)); -} - se::Platform* LocalClient::platform() const { return local_service_->backend().platform(); } @@ -297,4 +289,56 @@ StatusOr<std::unique_ptr<LocalExecutable>> LocalClient::Compile( device_ordinal, options)); } +// Copy the literal data to the device with the given ordinal and return as a +// ScopedShapedBuffer. 
The given memory allocator is used for device memory +// allocation. +StatusOr<std::unique_ptr<ScopedShapedBuffer>> +LocalClient::LiteralToShapedBuffer(const Literal& literal, + DeviceMemoryAllocator* allocator, + int device_ordinal) { + TF_ASSIGN_OR_RETURN(auto scoped_buffer, + ScopedShapedBuffer::MakeScopedShapedBuffer( + literal.shape(), allocator, device_ordinal)); + TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, + backend().stream_executor(device_ordinal)); + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + literal.shape(), [&](const Shape& subshape, const ShapeIndex& index) { + if (ShapeUtil::IsArray(subshape)) { + // This is a leaf of the shape. Transfer the literal array data to the + // device buffer. + return backend().transfer_manager()->TransferLiteralToDevice( + executor, literal.GetSubliteral(index), + scoped_buffer->mutable_buffer(index)); + } + return Status::OK(); + })); + return std::move(scoped_buffer); +} + +// Copy the data from the device contained in the given ShapedBuffer and +// return as a Literal. +StatusOr<std::unique_ptr<Literal>> LocalClient::ShapedBufferToLiteral( + const ShapedBuffer& shaped_buffer) { + std::unique_ptr<Literal> literal = + Literal::CreateFromShape(shaped_buffer.shape()); + TF_ASSIGN_OR_RETURN( + se::StreamExecutor * executor, + backend().stream_executor(shaped_buffer.device_ordinal())); + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + literal->shape(), [&](const Shape& subshape, const ShapeIndex& index) { + if (ShapeUtil::IsArray(subshape)) { + // This is a leaf of the shape. Transfer the device buffer into the + // literal. The layout of the literal and the device buffer are + // necessarily the same so we pass 'subshape' for both device and + // literal shapes. 
+ return backend().transfer_manager()->TransferLiteralFromDevice( + executor, shaped_buffer.buffer(index), + /*device_shape=*/subshape, + /*literal_shape*/ subshape, &literal->GetSubliteral(index)); + } + return Status::OK(); + })); + return std::move(literal); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index c903cd2711..e98384238a 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -158,15 +158,6 @@ class LocalClient : public Client { LocalClient(const LocalClient&) = delete; void operator=(const LocalClient&) = delete; - // Return a handle to a buffer large enough to hold shape, allocated - // on device_ordinal on the local service. If - // allocate_space_for_deep_copy, the buffer is large enough to hold - // all sub-buffers of a tuple shape, otherwise it is only as large - // as the top-level tuple pointer array. - StatusOr<std::unique_ptr<GlobalData>> AllocateBufferOnDevice( - const Shape& shape, int device_ordinal, - bool allocate_space_for_deep_copy); - // Build and return a LocalExecutable object. The executable is compiled using // the given argument layouts and options. StatusOr<std::unique_ptr<LocalExecutable>> Compile( @@ -174,6 +165,18 @@ class LocalClient : public Client { const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts, const ExecutableBuildOptions& options); + // Copy the literal data to the device with the given ordinal and return as a + // ScopedShapedBuffer. The given memory allocator is used for device memory + // allocation. + StatusOr<std::unique_ptr<ScopedShapedBuffer>> LiteralToShapedBuffer( + const Literal& literal, DeviceMemoryAllocator* allocator, + int device_ordinal); + + // Copy the data from the device contained in the given ShapedBuffer and + // return as a Literal. 
+ StatusOr<std::unique_ptr<Literal>> ShapedBufferToLiteral( + const ShapedBuffer& shaped_buffer); + // Returns the platform that the underlying service targets. perftools::gputools::Platform* platform() const; diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index e6787361d4..b867cf42d1 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -36,6 +36,11 @@ limitations under the License. namespace xla { +std::ostream& operator<<(std::ostream& out, const Literal& literal) { + out << literal.ToString(); + return out; +} + Literal::StrideConfig::StrideConfig( const Shape& source_shape, const Shape& dest_shape, tensorflow::gtl::ArraySlice<int64> dimensions) @@ -927,16 +932,16 @@ bool EqualElements(const Literal& literal1, const Literal& literal2, } // namespace -bool Literal::Equal(const Literal& literal2) const { - if (!ShapeUtil::Compatible(shape(), literal2.shape())) { +bool Literal::operator==(const Literal& other) const { + if (!ShapeUtil::Compatible(shape(), other.shape())) { return false; } if (ShapeUtil::IsTuple(shape())) { // Because the shapes are compatible, they must have the same number of // tuple elements. 
- CHECK_EQ(tuple_literals_size(), literal2.tuple_literals_size()); + CHECK_EQ(tuple_literals_size(), other.tuple_literals_size()); for (int i = 0; i < tuple_literals_size(); ++i) { - if (!tuple_literals(i).Equal(literal2.tuple_literals(i))) { + if (tuple_literals(i) != other.tuple_literals(i)) { return false; } } @@ -945,23 +950,23 @@ bool Literal::Equal(const Literal& literal2) const { std::vector<int64> multi_index(ShapeUtil::Rank(shape()), 0); switch (shape().element_type()) { case PRED: - return EqualElements<bool>(*this, literal2, 0, &multi_index); + return EqualElements<bool>(*this, other, 0, &multi_index); case U8: - return EqualElements<uint8>(*this, literal2, 0, &multi_index); + return EqualElements<uint8>(*this, other, 0, &multi_index); case S32: - return EqualElements<int32>(*this, literal2, 0, &multi_index); + return EqualElements<int32>(*this, other, 0, &multi_index); case S64: - return EqualElements<int64>(*this, literal2, 0, &multi_index); + return EqualElements<int64>(*this, other, 0, &multi_index); case U32: - return EqualElements<uint32>(*this, literal2, 0, &multi_index); + return EqualElements<uint32>(*this, other, 0, &multi_index); case U64: - return EqualElements<uint64>(*this, literal2, 0, &multi_index); + return EqualElements<uint64>(*this, other, 0, &multi_index); case F32: - return EqualElements<float>(*this, literal2, 0, &multi_index); + return EqualElements<float>(*this, other, 0, &multi_index); case F64: - return EqualElements<double>(*this, literal2, 0, &multi_index); + return EqualElements<double>(*this, other, 0, &multi_index); case F16: - return EqualElements<half>(*this, literal2, 0, &multi_index); + return EqualElements<half>(*this, other, 0, &multi_index); default: LOG(FATAL) << "Unimplemented: Literal::Equal for type " << PrimitiveType_Name(shape().element_type()); @@ -1400,4 +1405,16 @@ void Literal::CopyFromProto(const LiteralProto& literal_proto) { } } +const Literal& Literal::GetSubliteral(const ShapeIndex& index) const { + 
return const_cast<Literal*>(this)->GetSubliteral(index); +} + +Literal& Literal::GetSubliteral(const ShapeIndex& index) { + Literal* subliteral = this; + for (int64 i : index) { + subliteral = &subliteral->tuple_literals_.at(i); + } + return *subliteral; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index 858ddc297b..e8cee732d4 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -22,6 +22,7 @@ limitations under the License. #include <initializer_list> #include <iterator> #include <memory> +#include <ostream> #include <string> #include <type_traits> #include <vector> @@ -66,6 +67,11 @@ class Literal { Literal& operator=(const Literal& other) = default; Literal& operator=(Literal&&) = default; + // Literals are equal if they have compatible shapes and the same data + // values. Layout is not checked. + bool operator==(const Literal& other) const; + bool operator!=(const Literal& other) const { return !(*this == other); } + LiteralProto ToProto() const; bool has_shape() const { @@ -77,6 +83,10 @@ class Literal { string DebugString() const { return ToProto().DebugString(); } string ShortDebugString() const { return ToProto().ShortDebugString(); } + // Return the nested literal at the given shape index. + const Literal& GetSubliteral(const ShapeIndex& index) const; + Literal& GetSubliteral(const ShapeIndex& index); + void Clear() { shape_.Clear(); u8s_.clear(); @@ -518,10 +528,6 @@ class Literal { template <typename NativeT> void Resize(int64 num_elements, NativeT value); - // Returns true if this literal has the same shape and value as the given - // literal. Layout is not considered in the comparison. - bool Equal(const Literal& literal2) const; - // Returns whether every element in this literal is equal to value. 
// // value is an int8 because we expect this to be called with small @@ -597,6 +603,8 @@ class Literal { std::vector<Literal> tuple_literals_; }; +std::ostream& operator<<(std::ostream& out, const Literal& literal); + // Declarations of template specializations for GetArraySlice and // GetMutableArraySlice. The specializations map native type to XLA primitive // type. diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc index 61ceac4f9a..e7dedd0821 100644 --- a/tensorflow/compiler/xla/literal_util_test.cc +++ b/tensorflow/compiler/xla/literal_util_test.cc @@ -257,37 +257,37 @@ TEST_F(LiteralUtilTest, EachCellR2F32) { } TEST_F(LiteralUtilTest, ScalarEquality) { - // Test Literal::Equal with scalars. + // Test equality with scalars. auto f32_42 = Literal::CreateR0<float>(42.0); auto f32_42_clone = Literal::CreateR0<float>(42.0); - EXPECT_TRUE(f32_42->Equal(*f32_42)); - EXPECT_TRUE(f32_42->Equal(*f32_42_clone)); + EXPECT_EQ(*f32_42, *f32_42); + EXPECT_EQ(*f32_42, *f32_42_clone); auto f32_123 = Literal::CreateR0<float>(123.0); - EXPECT_FALSE(f32_42->Equal(*f32_123)); + EXPECT_NE(*f32_42, *f32_123); auto f64_42 = Literal::CreateR0<double>(42.0); - EXPECT_FALSE(f32_42->Equal(*f64_42)); + EXPECT_NE(*f32_42, *f64_42); } TEST_F(LiteralUtilTest, NonScalarEquality) { - // Test Literal::Equal with nonscalars. + // Test equality with nonscalars. 
auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}); auto matrix_clone = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}); auto matrix_different = Literal::CreateR2<float>({{4.0, 3.0}, {1.0, 2.0}}); auto vector_literal = Literal::CreateR1<float>({1.0, 2.0, 3.0, 4.0}); auto scalar = Literal::CreateR0<float>(1.0); - EXPECT_TRUE(matrix->Equal(*matrix)); - EXPECT_TRUE(matrix->Equal(*matrix_clone)); - EXPECT_FALSE(matrix->Equal(*matrix_different)); - EXPECT_FALSE(matrix->Equal(*vector_literal)); - EXPECT_FALSE(matrix->Equal(*scalar)); + EXPECT_EQ(*matrix, *matrix); + EXPECT_EQ(*matrix, *matrix_clone); + EXPECT_NE(*matrix, *matrix_different); + EXPECT_NE(*matrix, *vector_literal); + EXPECT_NE(*matrix, *scalar); } TEST_F(LiteralUtilTest, DifferentLayoutEquality) { - // Test Literal::Equal with literals which have different layouts. + // Test equality with literals which have different layouts. auto colmajor = MakeUnique<Literal>(); *colmajor->mutable_shape() = ShapeUtil::MakeShape(F32, {2, 2}); *colmajor->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); @@ -306,11 +306,11 @@ TEST_F(LiteralUtilTest, DifferentLayoutEquality) { rowmajor->Set<float>({1, 0}, 3.0); rowmajor->Set<float>({1, 1}, 4.0); - EXPECT_TRUE(rowmajor->Equal(*colmajor)); + EXPECT_EQ(*rowmajor, *colmajor); } TEST_F(LiteralUtilTest, TupleEquality) { - // Test Literal::Equal with tuples. + // Test equality with tuples. auto scalar = Literal::CreateR0<float>(1.0); auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}); auto tuple1 = Literal::MakeTuple({scalar.get(), matrix.get()}); @@ -319,16 +319,16 @@ TEST_F(LiteralUtilTest, TupleEquality) { // tuple, the other is a clone of the element in the original tuple. auto scalar_clone = Literal::CreateR0<float>(1.0); auto tuple2 = Literal::MakeTuple({scalar_clone.get(), matrix.get()}); - EXPECT_TRUE(tuple1->Equal(*tuple2)); + EXPECT_EQ(*tuple1, *tuple2); // Tuple with elements reversed. 
auto reversed_tuple = Literal::MakeTuple({matrix.get(), scalar.get()}); - EXPECT_FALSE(tuple1->Equal(*reversed_tuple)); + EXPECT_NE(*tuple1, *reversed_tuple); // Tuple with different value. auto scalar_42 = Literal::CreateR0<float>(42.0); auto different_tuple = Literal::MakeTuple({scalar_42.get(), matrix.get()}); - EXPECT_FALSE(tuple1->Equal(*different_tuple)); + EXPECT_NE(*tuple1, *different_tuple); } TEST_F(LiteralUtilTest, IsAllTuple) { @@ -348,7 +348,7 @@ TEST_F(LiteralUtilTest, CreateFromShapeTuple) { auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); auto x = Literal::CreateFromShape(tuple->shape()); - EXPECT_TRUE(tuple->Equal(*x)); + EXPECT_EQ(*tuple, *x); } TEST_F(LiteralUtilTest, IsAll) { @@ -439,17 +439,17 @@ TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) { auto data01 = data->Relayout(layout01); EXPECT_TRUE(LayoutUtil::Equal(data01->shape().layout(), layout01)); - EXPECT_TRUE(data->Equal(*data01)); + EXPECT_EQ(*data, *data01); auto data10 = data->Relayout(layout10); EXPECT_TRUE(LayoutUtil::Equal(data10->shape().layout(), layout10)); - EXPECT_TRUE(data->Equal(*data10)); + EXPECT_EQ(*data, *data10); } TEST_F(LiteralUtilTest, ReshapeR0) { auto original = Literal::CreateR0<float>(1.7f); auto reshape = original->Reshape(/*shape=*/{}).ConsumeValueOrDie(); - EXPECT_TRUE(original->Equal(*reshape)); + EXPECT_EQ(*original, *reshape); } TEST_F(LiteralUtilTest, ReshapeR4) { @@ -469,7 +469,7 @@ TEST_F(LiteralUtilTest, ReshapeR4) { // clang-format on auto reshape = original->Reshape({3, 4, 2}).ConsumeValueOrDie(); - EXPECT_TRUE(expected->Equal(*reshape)); + EXPECT_EQ(*expected, *reshape); } TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) { @@ -489,13 +489,13 @@ TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) { // clang-format on auto reshape = original->Reshape({3, 4, 2}).ConsumeValueOrDie(); - EXPECT_TRUE(expected->Equal(*reshape)); + EXPECT_EQ(*expected, *reshape); } TEST_F(LiteralUtilTest, TransposeR0) { auto original = Literal::CreateR0<float>(1.7f); auto 
reshape = original->Transpose(/*permutation=*/{}); - EXPECT_TRUE(original->Equal(*reshape)); + EXPECT_EQ(*original, *reshape); } TEST_F(LiteralUtilTest, TransposeR4) { @@ -521,13 +521,11 @@ TEST_F(LiteralUtilTest, TestR4RelayoutEquivalence) { // target layout in the first place. auto dim0minor_relaid_to_dim0major = literal_r4_2x2x3x3_dim0minor_->Relayout(layout_r4_dim0major_); - EXPECT_TRUE( - literal_r4_2x2x3x3_dim0major_->Equal(*dim0minor_relaid_to_dim0major)); + EXPECT_EQ(*literal_r4_2x2x3x3_dim0major_, *dim0minor_relaid_to_dim0major); auto dim0major_relaid_to_dim0minor = literal_r4_2x2x3x3_dim0major_->Relayout(layout_r4_dim0minor_); - EXPECT_TRUE( - literal_r4_2x2x3x3_dim0minor_->Equal(*dim0major_relaid_to_dim0minor)); + EXPECT_EQ(*literal_r4_2x2x3x3_dim0minor_, *dim0major_relaid_to_dim0minor); } TEST_F(LiteralUtilTest, TestR2LinearLayout) { @@ -596,14 +594,14 @@ TEST_F(LiteralUtilTest, TestR3LinearLayout) { TEST_F(LiteralUtilTest, SliceR0S32) { auto input = Literal::CreateR0<int32>(1); auto result = input->Slice({}, {}); - EXPECT_TRUE(input->Equal(*result)); + EXPECT_EQ(*input, *result); } TEST_F(LiteralUtilTest, SliceR1F32) { auto input = Literal::CreateR1<float>({1.0, 2.0, 3.0, 4.0, 5.0}); auto result = input->Slice({3}, {4}); auto expected = Literal::CreateR1<float>({4.0}); - EXPECT_TRUE(expected->Equal(*result)); + EXPECT_EQ(*expected, *result); } TEST_F(LiteralUtilTest, SliceR2U32) { @@ -611,49 +609,49 @@ TEST_F(LiteralUtilTest, SliceR2U32) { Literal::CreateR2<uint32>({{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}); auto result = input_3x4->Slice({0, 2}, {2, 4}); auto expected = Literal::CreateR2<uint32>({{3, 4}, {7, 8}}); - EXPECT_TRUE(expected->Equal(*result)); + EXPECT_EQ(*expected, *result); } TEST_F(LiteralUtilTest, SliceR3U32Full) { auto input_2x3x2 = Literal::CreateR3<uint32>( {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}); auto result = input_2x3x2->Slice({0, 0, 0}, {2, 3, 2}); - EXPECT_TRUE(input_2x3x2->Equal(*result)); + 
EXPECT_EQ(*input_2x3x2, *result); } TEST_F(LiteralUtilTest, PopulateR1S64) { Literal output; output.PopulateR1<int64>({77}); auto expected = Literal::CreateR1<int64>({77}); - EXPECT_TRUE(output.Equal(*expected)); + EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateR2U64) { Literal output; output.PopulateR1<uint64>({{77, 88}}); auto expected = Literal::CreateR1<uint64>({{77, 88}}); - EXPECT_TRUE(output.Equal(*expected)); + EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateWithValueR0F32) { Literal output; output.PopulateWithValue<float>(2.5f, {}); auto expected = Literal::CreateR0<float>(2.5f); - EXPECT_TRUE(output.Equal(*expected)); + EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateWithValueR1S64) { Literal output; output.PopulateWithValue<int64>(-7, {3}); auto expected = Literal::CreateR1<int64>({-7, -7, -7}); - EXPECT_TRUE(output.Equal(*expected)); + EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateWithValueR2U64) { Literal output; output.PopulateWithValue<uint64>(42, {2, 2}); auto expected = Literal::CreateR2<uint64>({{42, 42}, {42, 42}}); - EXPECT_TRUE(output.Equal(*expected)); + EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateWithValueR0F16) { @@ -661,7 +659,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR0F16) { half h(0.25f); output.PopulateWithValue<half>(h, {}); auto expected = Literal::CreateR0<half>(h); - EXPECT_TRUE(output.Equal(*expected)); + EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateWithValueR1F16) { @@ -669,7 +667,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR1F16) { half h(0.5f); output.PopulateWithValue<half>(h, {3}); auto expected = Literal::CreateR1<half>({h, h, h}); - EXPECT_TRUE(output.Equal(*expected)); + EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateWithValueR2F16) { @@ -677,7 +675,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2F16) { half h(2.0f); output.PopulateWithValue<half>(h, {2, 2}); auto expected = 
Literal::CreateR2<half>({{h, h}, {h, h}}); - EXPECT_TRUE(output.Equal(*expected)); + EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, ReplicateR2U32) { @@ -688,7 +686,7 @@ TEST_F(LiteralUtilTest, ReplicateR2U32) { {{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}}); - EXPECT_TRUE(output->Equal(*expected)); + EXPECT_EQ(*output, *expected); } TEST_F(LiteralUtilTest, Copy) { @@ -741,7 +739,7 @@ TEST_F(LiteralUtilTest, CopyScalars) { auto zero = Literal::CreateR0<uint32>(0); auto nine = Literal::CreateR0<uint32>(9); TF_EXPECT_OK(zero->Copy(*nine, {}, {}, {})); - EXPECT_TRUE(zero->Equal(*nine)); + EXPECT_EQ(*zero, *nine); auto vect = Literal::CreateR1<uint32>({3, 4, 9, 12, 5, 17, 21}); TF_EXPECT_OK(zero->Copy(*vect, {5}, {}, {})); @@ -761,7 +759,7 @@ TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) { auto nine = Literal::CreateR1<float>({9}); TF_EXPECT_OK(nine->Copy(*empty, {0}, {0}, {0})); - EXPECT_TRUE(nine->Equal(*const_nine)); + EXPECT_EQ(*nine, *const_nine); } { @@ -770,7 +768,7 @@ TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) { auto nine = Literal::CreateR1<float>({9}); TF_EXPECT_OK(empty->Copy(*nine, {0}, {0}, {0})); - EXPECT_TRUE(empty->Equal(*const_empty)); + EXPECT_EQ(*empty, *const_empty); } } @@ -863,7 +861,7 @@ TEST_F(LiteralUtilTest, ConvertR4) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> converted, original->Convert(U32)); - EXPECT_TRUE(expected->Equal(*converted)); + EXPECT_EQ(*expected, *converted); } TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { @@ -925,43 +923,43 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { std::unique_ptr<Literal> conv; conv = s8->Convert(U32).ConsumeValueOrDie(); - EXPECT_TRUE(conv->Equal(*u32)); + EXPECT_EQ(*conv, *u32); conv = s8->Convert(S32).ConsumeValueOrDie(); - EXPECT_TRUE(conv->Equal(*s32)); + EXPECT_EQ(*conv, *s32); conv = s8->Convert(U64).ConsumeValueOrDie(); - EXPECT_TRUE(conv->Equal(*u64)); + 
EXPECT_EQ(*conv, *u64); conv = s8->Convert(S64).ConsumeValueOrDie(); - EXPECT_TRUE(conv->Equal(*s64)); + EXPECT_EQ(*conv, *s64); conv = s8->Convert(PRED).ConsumeValueOrDie(); - EXPECT_TRUE(conv->Equal(*pred)); + EXPECT_EQ(*conv, *pred); conv = pred->Convert(S32).ConsumeValueOrDie(); - EXPECT_TRUE(conv->Equal(*int32_pred)); + EXPECT_EQ(*conv, *int32_pred); conv = f32->Convert(S32).ConsumeValueOrDie(); - EXPECT_TRUE(conv->Equal(*s32)); + EXPECT_EQ(*conv, *s32); conv = f64->Convert(S32).ConsumeValueOrDie(); - EXPECT_TRUE(conv->Equal(*s32)); + EXPECT_EQ(*conv, *s32); conv = s32->Convert(F32).ConsumeValueOrDie(); - EXPECT_TRUE(conv->Equal(*f32)); + EXPECT_EQ(*conv, *f32); conv = f32->Convert(F16).ConsumeValueOrDie(); - EXPECT_TRUE(conv->Equal(*f16)); + EXPECT_EQ(*conv, *f16); conv = f64->Convert(F16).ConsumeValueOrDie(); - EXPECT_TRUE(conv->Equal(*f16)); + EXPECT_EQ(*conv, *f16); conv = s32->Convert(F16).ConsumeValueOrDie(); - EXPECT_TRUE(conv->Equal(*f16)); + EXPECT_EQ(*conv, *f16); conv = u32->Convert(F16).ConsumeValueOrDie(); - EXPECT_TRUE(conv->Equal(*f16)); + EXPECT_EQ(*conv, *f16); EXPECT_EQ(s32->Convert(TUPLE).status().code(), tensorflow::error::INVALID_ARGUMENT); @@ -1045,5 +1043,25 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) { ASSERT_EQ(h1, r[3]); } +TEST_F(LiteralUtilTest, Subliterals) { + auto scalar = Literal::CreateR0<float>(1.0); + auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}); + auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); + auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()}); + + EXPECT_EQ(&scalar->GetSubliteral(/*index=*/{}), scalar.get()); + EXPECT_EQ(&matrix->GetSubliteral(/*index=*/{}), matrix.get()); + EXPECT_EQ(&tuple->GetSubliteral(/*index=*/{}), tuple.get()); + EXPECT_EQ(&nested_tuple->GetSubliteral(/*index=*/{}), nested_tuple.get()); + + EXPECT_EQ(tuple->GetSubliteral(/*index=*/{0}), *scalar); + EXPECT_EQ(tuple->GetSubliteral(/*index=*/{1}), *matrix); + + 
EXPECT_EQ(nested_tuple->GetSubliteral(/*index=*/{0}), *tuple); + EXPECT_EQ(nested_tuple->GetSubliteral(/*index=*/{0, 0}), *scalar); + EXPECT_EQ(nested_tuple->GetSubliteral(/*index=*/{0, 1}), *matrix); + EXPECT_EQ(nested_tuple->GetSubliteral(/*index=*/{1}), *scalar); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 74f8e3143d..f7551bfb6c 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -1376,7 +1376,7 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( // try to get more fancy about proving equivalence in cases beyond that. if (pad_value->opcode() != HloOpcode::kConstant || reduce_init_value->opcode() != HloOpcode::kConstant || - !pad_value->literal().Equal(reduce_init_value->literal())) { + pad_value->literal() != reduce_init_value->literal()) { VLOG(10) << "Not folding pad into reduce-window due to different pad " "values."; return Status::OK(); diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc index 0fef89a06d..cdccacdd2d 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.cc +++ b/tensorflow/compiler/xla/service/hlo_cse.cc @@ -68,7 +68,7 @@ bool CombineConstants(HloComputation* computation, bool is_layout_sensitive) { auto range = constants.equal_range(shape_string); HloInstruction* match = nullptr; for (auto it = range.first; it != range.second; ++it) { - if (instruction->literal().Equal(it->second->literal())) { + if (instruction->literal() == it->second->literal()) { match = it->second; break; } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 3366b83fdd..a883429341 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -1293,7 +1293,7 @@ bool 
HloInstruction::IdenticalSlowPath( // A constant is defined by the value in the literal. case HloOpcode::kConstant: - return literal().Equal(other.literal()); + return literal() == other.literal(); // A convert result is determined by the primitive type that the operand is // converted into. diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 1eb4edbe3e..3235081f83 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -88,21 +88,6 @@ int64 RequiredSpace(const Shape& shape, bool allocate_space_for_deep_copy, } } // namespace -StatusOr<GlobalDataHandle> LocalService::AllocateBufferOnDevice( - const Shape& shape, int device_ordinal, bool allocate_space_for_deep_copy) { - int64 allocation_size = RequiredSpace(shape, allocate_space_for_deep_copy, - execute_backend_->transfer_manager()); - - TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase allocation, - execute_backend_->memory_allocator()->Allocate( - device_ordinal, allocation_size)); - - return allocation_tracker_.Register( - execute_backend_.get(), device_ordinal, allocation, shape, - tensorflow::strings::StrCat("AllocateBufferOnDevice of size ", - allocation_size)); -} - StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable( const ComputationHandle& computation, const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts, diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h index c90943f3c0..f2bfb960f4 100644 --- a/tensorflow/compiler/xla/service/local_service.h +++ b/tensorflow/compiler/xla/service/local_service.h @@ -39,14 +39,6 @@ class LocalService : public Service { static StatusOr<std::unique_ptr<LocalService>> NewService( const ServiceOptions& options); - // Return a handle to a buffer large enough to hold shape, allocated - // on device_ordinal. 
If allocate_space_for_deep_copy, the buffer is - // large enough to hold all sub-buffers of a tuple shape, otherwise - // it is only as large as the top-level tuple pointer array. - StatusOr<GlobalDataHandle> AllocateBufferOnDevice( - const Shape& shape, int device_ordinal, - bool allocate_space_for_deep_copy); - // Builds an Executable with the given argument layouts and options. If // result_layout is non-null, then the executable is compiled to produce a // result of the given layout. diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 33537877ea..e45b839afd 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -1439,6 +1439,54 @@ tf_cc_test( ], ) +xla_test( + name = "local_client_allocation_test", + srcs = ["local_client_allocation_test.cc"], + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/service:local_service", + "//tensorflow/compiler/xla/service:shaped_buffer", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:local_client_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "local_client_execute_test", + srcs = ["local_client_execute_test.cc"], + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + 
"//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/service:device_memory_allocator", + "//tensorflow/compiler/xla/service:local_service", + "//tensorflow/compiler/xla/service:platform_util", + "//tensorflow/compiler/xla/service:shaped_buffer", + "//tensorflow/compiler/xla/service:transfer_manager", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:local_client_test_base", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core:test", + ], +) + tf_cc_test( name = "hlo_metadata_test", srcs = [ diff --git a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc new file mode 100644 index 0000000000..6897f0291a --- /dev/null +++ b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc @@ -0,0 +1,105 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include <memory> + +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/local_service.h" +#include "tensorflow/compiler/xla/service/shaped_buffer.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/local_client_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class LocalClientAllocationTest : public LocalClientTestBase { + protected: + ErrorSpec error_spec_{0.0001}; +}; + +XLA_TEST_F(LocalClientAllocationTest, AddVectors) { + ComputationBuilder builder(local_client_, TestName()); + auto x = builder.ConstantR1<float>({0.0f, 1.0f, 2.0f}); + auto y = builder.ConstantR1<float>({2.0f, 3.0f, 4.0f}); + builder.Add(x, y); + + TestAllocator* allocator = GetOrCreateAllocator(local_client_->platform()); + + auto x_array = LiteralToScopedShapedBuffer( + *Literal::CreateR1<float>({0.0f, 1.0f, 2.0f})); + + int64 allocation_count_before = allocator_->allocation_count(); + + // Override the allocator via 'options'. Tests that allocation and + // deallocation happen on the right allocator. + ExecutableRunOptions options; + options.set_allocator(allocator); + std::unique_ptr<ScopedShapedBuffer> result = + ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}, + DefaultExecutableBuildOptions(), options); + + LiteralTestUtil::ExpectR1Near<float>( + {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(*result), error_spec_); + + // At least one allocation should have been performed when executing the + // computation. 
+ EXPECT_GT(allocator_->allocation_count(), allocation_count_before); + + // Deallocate result and verify that deallocate was called once. + int64 deallocation_count_before = allocator_->deallocation_count(); + result = nullptr; + EXPECT_EQ(deallocation_count_before + 1, allocator_->deallocation_count()); +} + +XLA_TEST_F(LocalClientAllocationTest, RunOnDevices) { + // Run a computation on every device on the system. Verify that allocation + // occurs on the proper device. + ComputationBuilder builder(local_client_, TestName()); + auto x = builder.ConstantR1<float>({0.0f, 1.0f, 2.0f}); + auto y = builder.ConstantR1<float>({2.0f, 3.0f, 4.0f}); + builder.Add(x, y); + auto computation = builder.Build().ConsumeValueOrDie(); + + TestAllocator* allocator = GetOrCreateAllocator(local_client_->platform()); + for (int d = 0; d < local_client_->device_count(); ++d) { + if (!local_client_->device_ordinal_supported(d)) { + continue; + } + + int64 device_allocation_count_before = allocator->allocation_count(d); + int64 allocation_count_before = allocator->allocation_count(); + + auto result = ExecuteLocallyOrDie( + computation, {}, ExecutableBuildOptions().set_device_ordinal(d), + ExecutableRunOptions().set_device_ordinal(d).set_allocator(allocator)); + LiteralTestUtil::ExpectR1Near<float>( + {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(*result), error_spec_); + + // At least one allocation should have been performed when executing the + // computation. + EXPECT_GT(allocator->allocation_count(), allocation_count_before); + EXPECT_GT(allocator->allocation_count(d), device_allocation_count_before); + } +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc new file mode 100644 index 0000000000..ef2592e292 --- /dev/null +++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc @@ -0,0 +1,618 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <initializer_list> +#include <memory> +#include <vector> + +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/local_service.h" +#include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/compiler/xla/service/shaped_buffer.h" +#include "tensorflow/compiler/xla/service/transfer_manager.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/local_client_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/test.h" +#include 
"tensorflow/core/platform/test_benchmark.h" + +namespace se = ::perftools::gputools; + +namespace xla { +namespace { + +using ::testing::ContainsRegex; + +class LocalClientExecuteTest : public LocalClientTestBase { + protected: + ErrorSpec error_spec_{0.0001}; +}; + +XLA_TEST_F(LocalClientExecuteTest, Constant) { + ComputationBuilder builder(local_client_, TestName()); + auto y = builder.ConstantR0<float>(123.0f); + + std::unique_ptr<ScopedShapedBuffer> result = + ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); + + LiteralTestUtil::ExpectR0Near<float>(123.f, *ShapedBufferToLiteral(*result), + error_spec_); +} + +XLA_TEST_F(LocalClientExecuteTest, AddScalars) { + ComputationBuilder builder(local_client_, TestName()); + auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = builder.ConstantR0<float>(123.0f); + builder.Add(x, y); + + auto x_value = LiteralToScopedShapedBuffer(*Literal::CreateR0<float>(42.0f)); + std::unique_ptr<ScopedShapedBuffer> result = + ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {x_value.get()}); + + LiteralTestUtil::ExpectR0Near<float>(165.f, *ShapedBufferToLiteral(*result), + error_spec_); +} + +XLA_TEST_F(LocalClientExecuteTest, AddZeroElementVectors) { + ComputationBuilder builder(local_client_, TestName()); + auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {0}), "x"); + auto y = builder.ConstantR1<float>({}); + builder.Add(x, y); + + auto x_array = LiteralToScopedShapedBuffer(*Literal::CreateR1<float>({})); + std::unique_ptr<ScopedShapedBuffer> result = + ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {x_array.get()}); + + LiteralTestUtil::ExpectR1Near<float>({}, *ShapedBufferToLiteral(*result), + error_spec_); +} + +XLA_TEST_F(LocalClientExecuteTest, AddVectors) { + ComputationBuilder builder(local_client_, TestName()); + auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3}), "x"); + auto y = builder.ConstantR1<float>({2.0f, 3.0f, 4.0f}); + builder.Add(x, y); + + auto x_array = 
LiteralToScopedShapedBuffer( + *Literal::CreateR1<float>({0.0f, 1.0f, 2.0f})); + std::unique_ptr<ScopedShapedBuffer> result = + ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {x_array.get()}); + + LiteralTestUtil::ExpectR1Near<float>( + {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(*result), error_spec_); +} + +XLA_TEST_F(LocalClientExecuteTest, AddVectorsWithProfile) { + ComputationBuilder builder(local_client_, TestName()); + auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3}), "x"); + auto y = builder.ConstantR1<float>({2.0f, 3.0f, 4.0f}); + builder.Add(x, y); + + auto x_array = LiteralToScopedShapedBuffer( + *Literal::CreateR1<float>({0.0f, 1.0f, 2.0f})); + ExecutionProfile profile; + std::unique_ptr<ScopedShapedBuffer> result = ExecuteLocallyOrDie( + builder.Build().ValueOrDie(), {x_array.get()}, + DefaultExecutableBuildOptions(), + DefaultExecutableRunOptions().set_execution_profile(&profile)); + + LiteralTestUtil::ExpectR1Near<float>( + {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(*result), error_spec_); + EXPECT_GT(profile.compute_and_transfer_time_ns(), 0); +} + +XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) { + ComputationBuilder builder(local_client_, TestName()); + auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); + auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2}), "y"); + builder.Add(x, y); + auto computation = builder.Build().ConsumeValueOrDie(); + + // Create x as a col-major array. + auto x_array = LiteralToScopedShapedBuffer( + *test_utils::CreateR2LiteralWithLayout({{1.0f, 2.0f}, {3.0f, 4.0f}}, + /*minor_to_major=*/{0, 1})); + EXPECT_TRUE(LayoutUtil::Equal(x_array->shape().layout(), + LayoutUtil::MakeLayout({0, 1}))); + + // Create y as a row-major array. 
+ auto y_array = LiteralToScopedShapedBuffer( + *test_utils::CreateR2LiteralWithLayout({{10.0f, 20.0f}, {30.0f, 40.0f}}, + /*minor_to_major=*/{1, 0})); + EXPECT_TRUE(LayoutUtil::Equal(y_array->shape().layout(), + LayoutUtil::MakeLayout({1, 0}))); + + std::unique_ptr<ScopedShapedBuffer> result_colmaj = + ExecuteLocallyOrDie(computation, {x_array.get(), y_array.get()}); + LiteralTestUtil::ExpectR2Near<float>({{11.0f, 22.0f}, {33.0f, 44.0f}}, + *ShapedBufferToLiteral(*result_colmaj), + error_spec_); + + // Run with the parameter values in a different order. + std::unique_ptr<ScopedShapedBuffer> result_param_swap = + ExecuteLocallyOrDie(computation, {y_array.get(), x_array.get()}); + LiteralTestUtil::ExpectR2Near<float>( + {{11.0f, 22.0f}, {33.0f, 44.0f}}, + *ShapedBufferToLiteral(*result_param_swap), error_spec_); +} + +XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) { + ComputationBuilder builder(local_client_, TestName()); + auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); + auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2}), "y"); + builder.Add(x, y); + auto computation = builder.Build().ConsumeValueOrDie(); + + auto x_array = LiteralToScopedShapedBuffer( + *Literal::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}})); + auto y_array = LiteralToScopedShapedBuffer( + *Literal::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}})); + + // Run with col-major result layout. + std::unique_ptr<ScopedShapedBuffer> result_colmaj = ExecuteLocallyOrDie( + computation, {x_array.get(), y_array.get()}, + DefaultExecutableBuildOptions().set_result_layout( + ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2}, {0, 1})), + DefaultExecutableRunOptions()); + EXPECT_TRUE(LayoutUtil::Equal(result_colmaj->shape().layout(), + LayoutUtil::MakeLayout({0, 1}))); + LiteralTestUtil::ExpectR2Near<float>({{11.0f, 22.0f}, {33.0f, 44.0f}}, + *ShapedBufferToLiteral(*result_colmaj), + error_spec_); + + // Run with row-major result layout. 
+ std::unique_ptr<ScopedShapedBuffer> result_rowmaj = ExecuteLocallyOrDie( + computation, {x_array.get(), y_array.get()}, + DefaultExecutableBuildOptions().set_result_layout( + ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2}, {1, 0})), + DefaultExecutableRunOptions()); + EXPECT_TRUE(LayoutUtil::Equal(result_rowmaj->shape().layout(), + LayoutUtil::MakeLayout({1, 0}))); + LiteralTestUtil::ExpectR2Near<float>({{11.0f, 22.0f}, {33.0f, 44.0f}}, + *ShapedBufferToLiteral(*result_rowmaj), + error_spec_); +} + +XLA_TEST_F(LocalClientExecuteTest, TupleResult) { + ComputationBuilder builder(local_client_, TestName()); + auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); + auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2}), "y"); + builder.Tuple({x, y, x}); + auto computation = builder.Build().ConsumeValueOrDie(); + + auto x_array = LiteralToScopedShapedBuffer( + *Literal::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}})); + auto y_array = LiteralToScopedShapedBuffer( + *Literal::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}})); + + std::unique_ptr<ScopedShapedBuffer> result = + ExecuteLocallyOrDie(computation, {x_array.get(), y_array.get()}); + + EXPECT_TRUE(ShapeUtil::IsTuple(result->shape())); + EXPECT_EQ(3, ShapeUtil::TupleElementCount(result->shape())); + + std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(*result); + LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}, + result_literal->tuple_literals(0)); + LiteralTestUtil::ExpectR2Equal<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}, + result_literal->tuple_literals(1)); + LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}, + result_literal->tuple_literals(2)); +} + +XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) { + ComputationBuilder builder(local_client_, TestName()); + auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); + auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2}), "y"); + auto inner_tuple 
= builder.Tuple({x, y, x}); + builder.Tuple({inner_tuple, x}); + auto computation = builder.Build().ConsumeValueOrDie(); + + auto x_array = LiteralToScopedShapedBuffer( + *Literal::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}})); + auto y_array = LiteralToScopedShapedBuffer( + *Literal::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}})); + + std::unique_ptr<ScopedShapedBuffer> result = + ExecuteLocallyOrDie(computation, {x_array.get(), y_array.get()}); + + EXPECT_TRUE(ShapeUtil::IsTuple(result->shape())); + EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->shape())); + + std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(*result); + LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}, + result_literal->tuple_literals(1)); + const Literal& inner_tuple_literal = result_literal->tuple_literals(0); + LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}, + inner_tuple_literal.tuple_literals(0)); + LiteralTestUtil::ExpectR2Equal<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}, + inner_tuple_literal.tuple_literals(1)); + LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}, + inner_tuple_literal.tuple_literals(2)); +} + +XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) { + // Verify setting the result layout of a computation with a tuple output. 
+ ComputationBuilder builder(local_client_, TestName()); + auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); + auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2}), "y"); + builder.Tuple({x, y}); + + auto array = LiteralToScopedShapedBuffer( + *Literal::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}})); + + ExecutableBuildOptions options = DefaultExecutableBuildOptions(); + Shape shape_with_layout = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2}, + /*minor_to_major=*/{0, 1}), + ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2}, + /*minor_to_major=*/{1, 0})}); + options.set_result_layout(shape_with_layout); + std::unique_ptr<ScopedShapedBuffer> result = ExecuteLocallyOrDie( + builder.Build().ValueOrDie(), {array.get(), array.get()}, options, + DefaultExecutableRunOptions()); + + std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(*result); + LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}, + result_literal->tuple_literals(0)); + LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}, + result_literal->tuple_literals(1)); +} + +XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) { + // Test passing in an invalid number of arguments. + ComputationBuilder builder(local_client_, TestName()); + auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3}), "x"); + auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {3}), "y"); + builder.Add(x, y); + + auto x_array = LiteralToScopedShapedBuffer( + *Literal::CreateR1<float>({1.0f, 2.0f, 3.0f})); + auto execute_status = + ExecuteLocally(builder.Build().ValueOrDie(), {x_array.get()}); + + EXPECT_FALSE(execute_status.ok()); + EXPECT_THAT(execute_status.status().error_message(), + ContainsRegex("invalid number of arguments")); +} + +XLA_TEST_F(LocalClientExecuteTest, IncorrectArgumentShape) { + // Test passing in an argument with the wrong shape. 
+ ComputationBuilder builder(local_client_, TestName()); + auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3}), "x"); + builder.Neg(x); + + auto x_array = LiteralToScopedShapedBuffer( + *Literal::CreateR2<float>({{0.0f, 1.0f}, {2.0f, 3.0f}})); + auto execute_status = + ExecuteLocally(builder.Build().ValueOrDie(), {x_array.get()}); + + EXPECT_FALSE(execute_status.ok()); + EXPECT_THAT(execute_status.status().error_message(), + ContainsRegex("invalid argument shape")) + << execute_status.status(); +} + +XLA_TEST_F(LocalClientExecuteTest, InvalidResultLayout) { + // Test passing in an invalid result layout parameter. + ComputationBuilder builder(local_client_, TestName()); + auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); + builder.Neg(x); + + auto x_array = LiteralToScopedShapedBuffer( + *Literal::CreateR2<float>({{0.0f, 1.0f}, {2.0f, 3.0f}})); + auto execute_status = ExecuteLocally( + builder.Build().ValueOrDie(), {x_array.get()}, + DefaultExecutableBuildOptions().set_result_layout( + ShapeUtil::MakeShapeWithLayout(F32, + /*dimensions=*/{1, 2, 3, 4}, + /*minor_to_major=*/{0, 1, 2, 3})), + DefaultExecutableRunOptions()); + + EXPECT_FALSE(execute_status.ok()); + EXPECT_THAT(execute_status.status().error_message(), + ContainsRegex("not compatible with result shape")) + << execute_status.status(); +} + +XLA_TEST_F(LocalClientExecuteTest, RunOnAllDeviceOrdinals) { + // Try to run a trivial computation on every device on the system. If a + // specific device is not supported, check that the right error is returned. 
+ ComputationBuilder builder(local_client_, TestName()); + builder.ConstantR0<float>(42.0f); + auto computation = builder.Build().ConsumeValueOrDie(); + for (int d = 0; d < local_client_->device_count(); ++d) { + if (!local_client_->device_ordinal_supported(d)) { + auto execute_status = + ExecuteLocally(computation, {}, + DefaultExecutableBuildOptions().set_device_ordinal(d), + DefaultExecutableRunOptions().set_device_ordinal(d)); + EXPECT_FALSE(execute_status.ok()); + EXPECT_THAT(execute_status.status().error_message(), + ContainsRegex("device .* not supported")); + } else { + auto result = ExecuteLocallyOrDie( + computation, {}, + DefaultExecutableBuildOptions().set_device_ordinal(d), + DefaultExecutableRunOptions().set_device_ordinal(d)); + EXPECT_EQ(d, result->device_ordinal()); + LiteralTestUtil::ExpectR0Equal<float>(42.0f, + *ShapedBufferToLiteral(*result)); + } + } +} + +XLA_TEST_F(LocalClientExecuteTest, InvalidDeviceOrdinalValues) { + // Try running computations on devices with device ordinal values which do not + // exist. + ComputationBuilder builder(local_client_, TestName()); + builder.ConstantR0<float>(42.0f); + auto computation = builder.Build().ConsumeValueOrDie(); + + auto execute_status = + ExecuteLocally(computation, {}, + DefaultExecutableBuildOptions().set_device_ordinal( + local_client_->device_count()), + DefaultExecutableRunOptions().set_device_ordinal( + local_client_->device_count())); + EXPECT_FALSE(execute_status.ok()); + EXPECT_THAT(execute_status.status().error_message(), + ContainsRegex("Invalid device ordinal value")); +} + +XLA_TEST_F(LocalClientExecuteTest, RunOnStream) { + // Run a computation on a specific stream on each device on the system. 
+ ComputationBuilder builder(local_client_, TestName()); + builder.ConstantR0<float>(42.0f); + auto computation = builder.Build().ConsumeValueOrDie(); + + for (int d = 0; d < local_client_->device_count(); ++d) { + if (!local_client_->device_ordinal_supported(d)) { + continue; + } + se::StreamExecutor* executor = + local_client_->platform()->ExecutorForDevice(d).ValueOrDie(); + se::Stream stream(executor); + stream.Init(); + + auto result = + ExecuteLocallyOrDie(computation, {}, DefaultExecutableBuildOptions(), + DefaultExecutableRunOptions().set_stream(&stream)); + // As a check to verify that the computation ran on the device associated + // with the stream. This is a weak check, but stronger verification is hard. + EXPECT_EQ(d, result->device_ordinal()); + LiteralTestUtil::ExpectR0Equal<float>(42.0f, + *ShapedBufferToLiteral(*result)); + } +} + +// Disable this test on CPU because we're using the CPU as the platform +// which does not match the service platform. +XLA_TEST_F(LocalClientExecuteTest, + DISABLED_ON_CPU(RunOnStreamForWrongPlatform)) { + // Try to run a computation on a stream for a platform (CPU) which does not + // match the platform of the service (!= CPU).
+ se::Platform* wrong_platform = + se::MultiPlatformManager::PlatformWithId(se::host::kHostPlatformId) + .ValueOrDie(); + se::Stream wrong_stream(wrong_platform->ExecutorForDevice(0).ValueOrDie()); + wrong_stream.Init(); + + ComputationBuilder builder(local_client_, TestName()); + builder.ConstantR0<float>(42.0f); + auto execute_status = ExecuteLocally( + builder.Build().ValueOrDie(), {}, DefaultExecutableBuildOptions(), + DefaultExecutableRunOptions().set_stream(&wrong_stream)); + EXPECT_FALSE(execute_status.ok()); + EXPECT_THAT(execute_status.status().error_message(), + ContainsRegex("stream is for platform .*, but service targets")); +} + +XLA_TEST_F(LocalClientExecuteTest, + DISABLED_ON_CPU(AllocatorDoesNotMatchPlatform)) { + se::Platform* wrong_platform = + se::MultiPlatformManager::PlatformWithId(se::host::kHostPlatformId) + .ValueOrDie(); + TestAllocator allocator(wrong_platform); + + ComputationBuilder builder(local_client_, TestName()); + auto y = builder.ConstantR0<float>(123.0f); + + auto execute_status = ExecuteLocally( + builder.Build().ValueOrDie(), {}, DefaultExecutableBuildOptions(), + DefaultExecutableRunOptions().set_allocator(&allocator)); + EXPECT_FALSE(execute_status.ok()); + EXPECT_THAT(execute_status.status().error_message(), + ContainsRegex("allocator platform .* does not match service")); +} + +XLA_TEST_F(LocalClientExecuteTest, RunOnUninitializedStream) { + // Try to run a computation on a stream that has not been initialized. + ComputationBuilder builder(local_client_, TestName()); + builder.ConstantR0<float>(42.0f); + + LOG(INFO) << "default device = " << local_client_->default_device_ordinal(); + se::StreamExecutor* executor = + local_client_->platform() + ->ExecutorForDevice(local_client_->default_device_ordinal()) + .ValueOrDie(); + se::Stream stream(executor); + // Don't call stream.Init(). 
+ + auto execute_status = ExecuteLocally( + builder.Build().ValueOrDie(), {}, DefaultExecutableBuildOptions(), + DefaultExecutableRunOptions().set_stream(&stream)); + EXPECT_FALSE(execute_status.ok()); + EXPECT_THAT(execute_status.status().error_message(), + ContainsRegex("stream is uninitialized or in an error state")); +} + +XLA_TEST_F(LocalClientExecuteTest, SelectBetweenTuples) { + ComputationBuilder builder(local_client_, TestName()); + + std::initializer_list<float> vec1 = {1.f, 2.f, 3.f}; + std::initializer_list<float> vec2 = {2.f, 4.f, 6.f}; + auto tuple12 = builder.Tuple( + {builder.ConstantR1<float>(vec1), builder.ConstantR1<float>(vec2)}); + auto tuple21 = builder.Tuple( + {builder.ConstantR1<float>(vec2), builder.ConstantR1<float>(vec1)}); + builder.Select(builder.ConstantR0<bool>(false), tuple12, tuple21); + + std::unique_ptr<ScopedShapedBuffer> result = + ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); + std::unique_ptr<Literal> tuple_literal = ShapedBufferToLiteral(*result); + LiteralTestUtil::ExpectR1Equal<float>({2.0f, 4.0f, 6.0f}, + tuple_literal->tuple_literals(0)); + LiteralTestUtil::ExpectR1Equal<float>({1.0f, 2.0f, 3.0f}, + tuple_literal->tuple_literals(1)); +} + +XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) { + ComputationBuilder builder(local_client_, TestName()); + auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3}), "x"); + auto y = builder.ConstantR1<float>({2.0f, 3.0f, 4.0f}); + builder.Add(x, y); + + Shape argument_layout = + ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{3}, {0}); + auto executable_status = + local_client_->Compile(builder.Build().ValueOrDie(), {&argument_layout}, + ExecutableBuildOptions()); + ASSERT_IS_OK(executable_status); + std::unique_ptr<LocalExecutable> executable = + executable_status.ConsumeValueOrDie(); + + auto x_array = LiteralToScopedShapedBuffer( + *Literal::CreateR1<float>({0.0f, 1.0f, 2.0f})); + std::unique_ptr<ScopedShapedBuffer> result = ShapedBufferToScopedShapedBuffer( 
+ executable->Run({x_array.get()}, DefaultExecutableRunOptions()) + .ConsumeValueOrDie(), + allocator_); + + LiteralTestUtil::ExpectR1Near<float>( + {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(*result), error_spec_); +} + +XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion) { + // Test copying Literals to the device as ShapedBuffers, then copying them + // back again to Literals. + auto test_to_device_and_back = [this](const Literal& literal) { + TF_ASSERT_OK_AND_ASSIGN( + auto shaped_buffer, + local_client_->LiteralToShapedBuffer( + literal, allocator_, local_client_->default_device_ordinal())); + TF_ASSERT_OK_AND_ASSIGN( + auto transferred_literal, + local_client_->ShapedBufferToLiteral(*shaped_buffer)); + EXPECT_EQ(literal, *transferred_literal); + }; + + // Array shapes. + test_to_device_and_back(*Literal::CreateR0<float>(42.0)); + test_to_device_and_back(*Literal::CreateR0<bool>(true)); + test_to_device_and_back(*Literal::CreateR1<float>({1.0, 42.0, 744.4})); + test_to_device_and_back( + *Literal::CreateR2<double>({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}})); + test_to_device_and_back(*Literal::CreateR2<int32>({{2, 1}, {4444, 56}})); + + // Null shape (empty tuple). + test_to_device_and_back(*Literal::MakeTuple({})); + + // Non-nested tuples. + test_to_device_and_back( + *Literal::MakeTuple({Literal::CreateR0<float>(12223.0).get()})); + test_to_device_and_back( + *Literal::MakeTuple({Literal::CreateR1<float>({1.0, -42.0}).get(), + Literal::CreateR0<float>(123456.0).get()})); + + // Nested tuple. 
+ test_to_device_and_back(*Literal::MakeTuple( + {Literal::MakeTuple({Literal::CreateR1<float>({1.0, -42.0}).get(), + Literal::CreateR0<float>(123456.0).get()}) + .get(), + Literal::CreateR0<bool>(false).get()})); +} + +// Benchmark that measures the overhead of the LocalClient API when running a +// trivial computation +void BM_LocalClientOverhead(int num_iters) { + tensorflow::testing::StopTiming(); + + se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie(); + auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie(); + StreamExecutorMemoryAllocator allocator(platform, executors); + LocalClient* client = + ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie(); + auto* transfer_manager = + TransferManager::GetForPlatform(platform).ValueOrDie(); + int device_ordinal = client->default_device_ordinal(); + + // Use a tiny add operation as the computation. + ComputationBuilder builder(client, "Add"); + auto shape = ShapeUtil::MakeShape(F32, {2, 3}); + auto x = builder.Parameter(0, shape, "x"); + builder.Add(x, x); + auto computation = builder.Build().ConsumeValueOrDie(); + + auto buffer = ScopedShapedBuffer::MakeScopedShapedBuffer(shape, &allocator, 0) + .ConsumeValueOrDie(); + auto literal = Literal::CreateR2<float>({{0, 0, 0}, {0, 0, 0}}); + ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( + executors[device_ordinal], *literal, buffer->mutable_buffer({}))); + + const int kWarmups = 2; + + auto executable_status = client->Compile(computation, {&buffer->shape()}, + ExecutableBuildOptions()); + ASSERT_IS_OK(executable_status); + std::unique_ptr<LocalExecutable> executable = + executable_status.ConsumeValueOrDie(); + + se::Stream stream(executors[client->default_device_ordinal()]); + stream.Init(); + + ExecutableRunOptions run_options; + run_options.set_allocator(&allocator).set_stream(&stream); + + for (int i = 0; i < kWarmups; ++i) { + auto result = executable->Run({buffer.get()}, run_options); + ASSERT_IS_OK(result); + } 
+ + tensorflow::testing::StartTiming(); + for (int i = 0; i < num_iters; ++i) { + auto result = executable->Run({buffer.get()}, run_options); + ASSERT_IS_OK(result); + } +} + +BENCHMARK(BM_LocalClientOverhead); + +} // namespace +} // namespace xla |