author    A. Unique TensorFlower <gardener@tensorflow.org> 2018-06-19 06:01:43 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-06-19 06:04:39 -0700
commit    bae4a271c036e6ede7cab6f4328b0a7966ef9fd4 (patch)
tree      bdf720b23271704f4f10bdfe6a752fe0b6cc3ec4
parent    707ac111cfed90f35c37417d8c79ab7cbcba152a (diff)
Change TransferManager to transfer literals via se::Stream rather than a bare se::StreamExecutor: add asynchronous variants (TransferLiteralToDeviceAsync and a callback-based TransferLiteralFromDevice) and reimplement the synchronous entry points as substream-based wrappers around them.
PiperOrigin-RevId: 201161803
 tensorflow/compiler/jit/xla_device_context.cc               |   8
 tensorflow/compiler/xla/client/local_client.cc              |  20
 tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc |   5
 tensorflow/compiler/xla/service/executable.h                |   3
 tensorflow/compiler/xla/service/generic_transfer_manager.cc |  45
 tensorflow/compiler/xla/service/generic_transfer_manager.h  |  16
 tensorflow/compiler/xla/service/hlo_runner.cc               |  14
 tensorflow/compiler/xla/service/interpreter/executable.cc   |   8
 tensorflow/compiler/xla/service/interpreter/executor.cc     |   2
 tensorflow/compiler/xla/service/service.cc                  |  42
 tensorflow/compiler/xla/service/transfer_manager.cc         | 139
 tensorflow/compiler/xla/service/transfer_manager.h          |  71
 tensorflow/compiler/xla/shape_util.cc                       |   8
 tensorflow/compiler/xla/shape_util.h                        |   3
 tensorflow/compiler/xla/tests/BUILD                         |   1
 tensorflow/compiler/xla/tests/dynamic_ops_test.cc           |   4
 tensorflow/compiler/xla/tests/local_client_execute_test.cc  | 100
 tensorflow/compiler/xla/tests/transfer_manager_test.cc      | 258
 tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc       |  10
 tensorflow/compiler/xla/tests/xla_internal_test_main.cc     |   1
 20 files changed, 520 insertions(+), 238 deletions(-)
diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc
index 71e63b110b..37005479dc 100644
--- a/tensorflow/compiler/jit/xla_device_context.cc
+++ b/tensorflow/compiler/jit/xla_device_context.cc
@@ -74,7 +74,7 @@ Status XlaTransferManager::TransferLiteralToDevice(
XlaTensor::FromTensor(device_tensor)->shaped_buffer();
VLOG(1) << "Transfer to device as literal: " << literal.ToString() << " "
<< shaped_buffer.ToString();
- return transfer_manager_->TransferLiteralToDevice(stream_->parent(), literal,
+ return transfer_manager_->TransferLiteralToDevice(stream_, literal,
shaped_buffer);
}
@@ -83,9 +83,9 @@ Status XlaTransferManager::TransferLiteralFromDevice(
const xla::ShapedBuffer& shaped_buffer =
XlaTensor::FromTensor(&device_tensor)->shaped_buffer();
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Literal> literal,
- transfer_manager_->TransferLiteralFromDevice(
- stream_->parent(), shaped_buffer));
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<xla::Literal> literal,
+ transfer_manager_->TransferLiteralFromDevice(stream_, shaped_buffer));
VLOG(1) << "Transfer from device as literal: " << literal->ToString() << " "
<< shaped_buffer.ToString();
Tensor tensor;
diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc
index ae0308020d..cf07910c4a 100644
--- a/tensorflow/compiler/xla/client/local_client.cc
+++ b/tensorflow/compiler/xla/client/local_client.cc
@@ -230,10 +230,9 @@ Status LocalExecutable::RecordResult(const ShapedBuffer* result,
StatusOr<std::unique_ptr<Literal>> LocalExecutable::LiteralFromShapedBuffer(
const ShapedBuffer& shaped_buffer) {
- TF_ASSIGN_OR_RETURN(
- se::StreamExecutor * executor,
- backend_->stream_executor(shaped_buffer.device_ordinal()));
- return backend_->transfer_manager()->TransferLiteralFromDevice(executor,
+ TF_ASSIGN_OR_RETURN(auto stream,
+ backend_->BorrowStream(shaped_buffer.device_ordinal()));
+ return backend_->transfer_manager()->TransferLiteralFromDevice(stream.get(),
shaped_buffer);
}
@@ -288,19 +287,18 @@ StatusOr<ScopedShapedBuffer> LocalClient::LiteralToShapedBuffer(
TF_ASSIGN_OR_RETURN(auto scoped_buffer,
backend().transfer_manager()->AllocateScopedShapedBuffer(
literal.shape(), allocator, device_ordinal));
- TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
- backend().stream_executor(device_ordinal));
+ TF_ASSIGN_OR_RETURN(auto stream,
+ mutable_backend()->BorrowStream(device_ordinal));
TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice(
- executor, literal, scoped_buffer));
+ stream.get(), literal, scoped_buffer));
return std::move(scoped_buffer);
}
StatusOr<std::unique_ptr<Literal>> LocalClient::ShapedBufferToLiteral(
const ShapedBuffer& shaped_buffer) {
- TF_ASSIGN_OR_RETURN(
- se::StreamExecutor * executor,
- backend().stream_executor(shaped_buffer.device_ordinal()));
- return backend().transfer_manager()->TransferLiteralFromDevice(executor,
+ TF_ASSIGN_OR_RETURN(auto stream, mutable_backend()->BorrowStream(
+ shaped_buffer.device_ordinal()));
+ return backend().transfer_manager()->TransferLiteralFromDevice(stream.get(),
shaped_buffer);
}
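
The pattern above recurs throughout this change: instead of fetching a bare se::StreamExecutor, callers borrow a stream from the backend's pool and pass it to the transfer manager. A minimal sketch of the post-change calling convention, assuming a backend and literal are in scope (the helper name is hypothetical):

    // Hypothetical helper, not part of this patch: a synchronous
    // host-to-device copy using the stream-based API.
    StatusOr<ScopedShapedBuffer> CopyLiteralToDevice(Backend* backend,
                                                     const Literal& literal,
                                                     int device_ordinal) {
      TF_ASSIGN_OR_RETURN(
          auto buffer,
          backend->transfer_manager()->AllocateScopedShapedBuffer(
              literal.shape(), backend->memory_allocator(), device_ordinal));
      // BorrowStream yields a smart pointer that returns the stream to the
      // backend's pool when it goes out of scope.
      TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal));
      TF_RETURN_IF_ERROR(backend->transfer_manager()->TransferLiteralToDevice(
          stream.get(), literal, buffer));
      return std::move(buffer);
    }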
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
index d97802ee45..b877b29581 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
@@ -160,9 +160,8 @@ CpuTransferManager::TransferBufferToInfeedInternal(se::StreamExecutor* executor,
int32 size_32 = static_cast<int32>(size);
CpuInfeedBuffer* queued_buffer = new CpuInfeedBuffer(size_32);
- Status s =
- TransferBufferToDevice(executor, /*size=*/size,
- /*source=*/source, queued_buffer->device_memory());
+ Status s = executor->SynchronousMemcpyH2D(
+ /*host_src=*/source, /*size=*/size, queued_buffer->device_memory());
if (!s.ok()) {
queued_buffer->Done(s);
diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h
index dc1f26ea65..1a91aca9d1 100644
--- a/tensorflow/compiler/xla/service/executable.h
+++ b/tensorflow/compiler/xla/service/executable.h
@@ -88,8 +88,7 @@ class Executable {
// called explicitly for other (async, for example) variants after the stream
// has completed.
virtual Status PopulateExecutionProfile(
- HloExecutionProfile* hlo_execution_profile,
- se::StreamExecutor* executor) {
+ HloExecutionProfile* hlo_execution_profile, se::Stream* stream) {
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
index d9f62c21c4..85e28a0dfe 100644
--- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
@@ -43,7 +43,7 @@ se::Platform::Id GenericTransferManager::PlatformId() const {
}
Status GenericTransferManager::WriteSingleTupleIndexTable(
- se::StreamExecutor* executor,
+ se::Stream* stream,
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> elements,
const Shape& shape, se::DeviceMemoryBase* region) {
TF_RET_CHECK(elements.size() == ShapeUtil::TupleElementCount(shape));
@@ -52,12 +52,24 @@ Status GenericTransferManager::WriteSingleTupleIndexTable(
for (const se::DeviceMemoryBase& element : elements) {
element_pointers.push_back(element.opaque());
}
- return TransferBufferToDevice(executor, GetByteSizeRequirement(shape),
- element_pointers.data(), region);
+ TF_RETURN_IF_ERROR(TransferBufferToDevice(
+ stream, GetByteSizeRequirement(shape), element_pointers.data(), region));
+ // Ensure the buffer is transferred before we destroy element_pointers.
+ return stream->BlockHostUntilDone();
+}
+
+void GenericTransferManager::TransferLiteralFromDevice(
+ se::Stream* stream, const ShapedBuffer& device_buffer,
+ std::function<void(StatusOr<std::unique_ptr<Literal>>)> done) {
+ Status status = stream->BlockHostUntilDone();
+ if (!status.ok()) {
+ return done(status);
+ }
+ done(TransferLiteralFromDeviceInternal(stream->parent(), device_buffer));
}
StatusOr<std::unique_ptr<Literal>>
-GenericTransferManager::TransferLiteralFromDevice(
+GenericTransferManager::TransferLiteralFromDeviceInternal(
se::StreamExecutor* executor, const ShapedBuffer& device_buffer) {
VLOG(2) << "transferring literal from device ordinal "
<< executor->device_ordinal() << "; device buffer: " << device_buffer;
@@ -75,8 +87,7 @@ GenericTransferManager::TransferLiteralFromDevice(
device_buffer.on_host_shape(),
[&](const Shape& subshape, const ShapeIndex& index) -> Status {
if (ShapeUtil::IsArray(subshape)) {
- TF_RETURN_IF_ERROR(TransferBufferFromDevice(
- executor,
+ TF_RETURN_IF_ERROR(executor->SynchronousMemcpyD2H(
/*source=*/device_buffer.buffer(index),
/*size=*/GetByteSizeRequirement(subshape),
/*destination=*/
@@ -88,8 +99,8 @@ GenericTransferManager::TransferLiteralFromDevice(
return std::move(literal);
}
-Status GenericTransferManager::TransferLiteralToDevice(
- se::StreamExecutor* executor, const LiteralSlice& literal,
+Status GenericTransferManager::TransferLiteralToDeviceAsync(
+ se::Stream* stream, const LiteralSlice& literal,
const ShapedBuffer& device_buffer) {
const Shape& shape = literal.shape();
VLOG(2) << "transferring literal shape to device: "
@@ -103,9 +114,10 @@ Status GenericTransferManager::TransferLiteralToDevice(
TF_RET_CHECK(
ShapeUtil::Compatible(literal.shape(), device_buffer.on_host_shape()));
- TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal());
+ TF_RET_CHECK(stream->parent()->device_ordinal() ==
+ device_buffer.device_ordinal());
- TF_RETURN_IF_ERROR(WriteTupleIndexTables(executor, device_buffer));
+ TF_RETURN_IF_ERROR(WriteTupleIndexTables(stream, device_buffer));
return ShapeUtil::ForEachSubshapeWithStatus(
device_buffer.on_host_shape(),
@@ -121,16 +133,21 @@ Status GenericTransferManager::TransferLiteralToDevice(
if (LayoutUtil::Equal(device_subshape.layout(),
subliteral.shape().layout())) {
source = subliteral.untyped_data();
+ return TransferBufferToDevice(
+ stream,
+ /*size=*/GetByteSizeRequirement(device_subshape), source,
+ &device_memory);
} else {
// Relayout data before transferring.
relayed_out_literal = subliteral.Relayout(device_subshape.layout(),
/*shape_index=*/{});
source = relayed_out_literal->untyped_data();
+ TF_RETURN_IF_ERROR(TransferBufferToDevice(
+ stream,
+ /*size=*/GetByteSizeRequirement(device_subshape), source,
+ &device_memory));
+ return stream->BlockHostUntilDone();
}
- return TransferBufferToDevice(
- executor,
- /*size=*/GetByteSizeRequirement(device_subshape), source,
- &device_memory);
}
return Status::OK();
});
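
The BlockHostUntilDone calls added above matter because TransferBufferToDevice now only enqueues the copy: any host-side staging buffer (element_pointers, relayed_out_literal) must stay alive until the stream has consumed it. A sketch of the hazard, with hypothetical names:

    // Sketch only: ThenMemcpy reads `staging` asynchronously, so the wait is
    // what keeps this local vector alive long enough.
    Status WriteFromTemporary(se::Stream* stream, se::DeviceMemoryBase* dst) {
      std::vector<uint64> staging = BuildHostWords();  // hypothetical helper
      stream->ThenMemcpy(dst, staging.data(), staging.size() * sizeof(uint64));
      // Without this, `staging` is destroyed at function exit while the copy
      // may still be in flight.
      return stream->BlockHostUntilDone();
    }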
diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h
index 3da9570ef7..d216fe7d29 100644
--- a/tensorflow/compiler/xla/service/generic_transfer_manager.h
+++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h
@@ -41,12 +41,13 @@ class GenericTransferManager : public TransferManager {
se::Platform::Id PlatformId() const override;
- StatusOr<std::unique_ptr<Literal>> TransferLiteralFromDevice(
- se::StreamExecutor* executor, const ShapedBuffer& device_buffer) override;
+ void TransferLiteralFromDevice(
+ se::Stream* stream, const ShapedBuffer& device_buffer,
+ std::function<void(StatusOr<std::unique_ptr<Literal>>)> done) override;
- Status TransferLiteralToDevice(se::StreamExecutor* executor,
- const LiteralSlice& literal,
- const ShapedBuffer& device_buffer) override;
+ Status TransferLiteralToDeviceAsync(
+ se::Stream* stream, const LiteralSlice& literal,
+ const ShapedBuffer& device_buffer) override;
Status TransferLiteralToInfeed(se::StreamExecutor* executor,
const LiteralSlice& literal) override;
@@ -64,11 +65,14 @@ class GenericTransferManager : public TransferManager {
const void* source) override;
Status WriteSingleTupleIndexTable(
- se::StreamExecutor* executor,
+ se::Stream* stream,
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> elements,
const Shape& shape, se::DeviceMemoryBase* region) override;
private:
+ StatusOr<std::unique_ptr<Literal>> TransferLiteralFromDeviceInternal(
+ se::StreamExecutor* executor, const ShapedBuffer& device_buffer);
+
// The platform this transfer manager targets.
const se::Platform::Id platform_id_;
diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc
index e1f9d8efd4..4f0569f405 100644
--- a/tensorflow/compiler/xla/service/hlo_runner.cc
+++ b/tensorflow/compiler/xla/service/hlo_runner.cc
@@ -98,8 +98,10 @@ StatusOr<ScopedShapedBuffer> HloRunner::TransferLiteralToDevice(
backend().transfer_manager()->AllocateScopedShapedBuffer(
literal.shape(), backend().memory_allocator(),
backend().default_device_ordinal()));
+ TF_ASSIGN_OR_RETURN(
+ auto stream, backend().BorrowStream(backend().default_stream_executor()));
TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice(
- backend().default_stream_executor(), literal, buffer));
+ stream.get(), literal, buffer));
return std::move(buffer);
}
@@ -127,8 +129,10 @@ StatusOr<std::vector<ScopedShapedBuffer>> HloRunner::TransferLiteralsToDevice(
StatusOr<std::unique_ptr<Literal>> HloRunner::TransferLiteralFromDevice(
const ShapedBuffer& buffer) {
- return backend().transfer_manager()->TransferLiteralFromDevice(
- backend().default_stream_executor(), buffer);
+ TF_ASSIGN_OR_RETURN(
+ auto stream, backend().BorrowStream(backend().default_stream_executor()));
+ return backend().transfer_manager()->TransferLiteralFromDevice(stream.get(),
+ buffer);
}
StatusOr<std::unique_ptr<Literal>> HloRunner::Execute(
@@ -237,7 +241,7 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated(
backend().transfer_manager()->AllocateScopedShapedBuffer(
argument->shape(), backend().memory_allocator(), device));
TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice(
- executor, *argument, argument_buffer));
+ streams.back().get(), *argument, argument_buffer));
argument_buffers.push_back(std::move(argument_buffer));
argument_buffer_ptrs[index++] = &argument_buffers.back();
}
@@ -307,7 +311,7 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated(
for (int64 i = 0; i < options.num_replicas; ++i) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
backend().transfer_manager()->TransferLiteralFromDevice(
- streams[i]->parent(), results[i]));
+ streams[i].get(), results[i]));
exec_results.push_back(std::move(literal));
}
return std::move(exec_results);
diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc
index 029e71058a..9816acf650 100644
--- a/tensorflow/compiler/xla/service/interpreter/executable.cc
+++ b/tensorflow/compiler/xla/service/interpreter/executable.cc
@@ -75,9 +75,9 @@ StatusOr<ScopedShapedBuffer> InterpreterExecutable::ExecuteOnStream(
// consumes.
std::vector<std::unique_ptr<Literal>> arg_literals;
for (int64 p = 0; p < computation->num_parameters(); ++p) {
- TF_ASSIGN_OR_RETURN(
- std::unique_ptr<Literal> arg_literal,
- transfer_manager->TransferLiteralFromDevice(executor, *arguments[p]));
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> arg_literal,
+ transfer_manager->TransferLiteralFromDevice(
+ run_options->stream(), *arguments[p]));
arg_literals.push_back(std::move(arg_literal));
}
@@ -96,7 +96,7 @@ StatusOr<ScopedShapedBuffer> InterpreterExecutable::ExecuteOnStream(
result_literal->shape(), run_options->allocator(),
executor->device_ordinal()));
TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice(
- executor, *result_literal, result));
+ run_options->stream(), *result_literal, result));
uint64 end_micros = tensorflow::Env::Default()->NowMicros();
diff --git a/tensorflow/compiler/xla/service/interpreter/executor.cc b/tensorflow/compiler/xla/service/interpreter/executor.cc
index 97e9fa2c8e..4fb67bd0b7 100644
--- a/tensorflow/compiler/xla/service/interpreter/executor.cc
+++ b/tensorflow/compiler/xla/service/interpreter/executor.cc
@@ -53,6 +53,7 @@ bool XlaInterpreterExecutor::Memcpy(Stream *stream, void *host_dst,
AsExecutorStream(stream)->EnqueueTask([this, host_dst, dev_src, size]() {
port::Status ok = SynchronousMemcpy(host_dst, dev_src, size);
});
+ AsExecutorStream(stream)->BlockUntilDone();
return true;
}
@@ -61,6 +62,7 @@ bool XlaInterpreterExecutor::Memcpy(Stream *stream, DeviceMemoryBase *dev_dst,
AsExecutorStream(stream)->EnqueueTask([this, dev_dst, host_src, size]() {
port::Status ok = SynchronousMemcpy(dev_dst, host_src, size);
});
+ AsExecutorStream(stream)->BlockUntilDone();
return true;
}
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index ff68d65fbc..7ab39e01f2 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -64,25 +64,25 @@ namespace {
// Records the arguments used to invoke a computation in an HloSnapshot proto.
Status RecordArguments(
const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
- se::StreamExecutor* executor, TransferManager* transfer_manager,
+ se::Stream* stream, TransferManager* transfer_manager,
HloSnapshot* module) {
module->clear_arguments();
for (const ShapedBuffer* argument : arguments) {
TF_ASSIGN_OR_RETURN(
std::unique_ptr<Literal> literal,
- transfer_manager->TransferLiteralFromDevice(executor, *argument));
+ transfer_manager->TransferLiteralFromDevice(stream, *argument));
*module->add_arguments() = literal->ToProto();
}
return Status::OK();
}
// Records the result of a computation in a HloSnapshot proto.
-Status RecordResult(const ShapedBuffer& result, se::StreamExecutor* executor,
+Status RecordResult(const ShapedBuffer& result, se::Stream* stream,
TransferManager* transfer_manager, HloSnapshot* module) {
module->clear_result();
TF_ASSIGN_OR_RETURN(
std::unique_ptr<Literal> literal,
- transfer_manager->TransferLiteralFromDevice(executor, result));
+ transfer_manager->TransferLiteralFromDevice(stream, result));
*module->mutable_result() = literal->ToProto();
return Status::OK();
}
@@ -496,7 +496,7 @@ Service::ExecuteParallelAndRegisterResult(
HloExecutionProfile hlo_profile(&executable->hlo_profile_printer_data(),
&executable->hlo_profile_index_map());
TF_RETURN_IF_ERROR(
- executable->PopulateExecutionProfile(&hlo_profile, stream->parent()));
+ executable->PopulateExecutionProfile(&hlo_profile, stream));
XLA_LOG_LINES(
tensorflow::INFO,
hlo_profile.ToString(streams[0]->parent()->GetDeviceDescription()));
@@ -721,8 +721,10 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg,
for (int i = 0; i < executable_ptrs.size(); i++) {
if (executable_ptrs[i]->dumping_snapshot()) {
- TF_RETURN_IF_ERROR(RecordArguments(all_arguments[i].front(),
- all_executors[i][0],
+ TF_ASSIGN_OR_RETURN(auto stream,
+ execute_backend_->BorrowStream(
+ all_executors[i][0]->device_ordinal()));
+ TF_RETURN_IF_ERROR(RecordArguments(all_arguments[i].front(), stream.get(),
execute_backend_->transfer_manager(),
executable_ptrs[i]->hlo_snapshot()));
}
@@ -747,7 +749,9 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg,
if (executable_ptrs[i]->dumping_snapshot()) {
TF_ASSIGN_OR_RETURN(const ShapedBuffer* result_buffer,
allocation_tracker_.ResolveForReplica(outputs[i], 0));
- TF_RETURN_IF_ERROR(RecordResult(*result_buffer, all_executors[i][0],
+ TF_ASSIGN_OR_RETURN(auto stream,
+ execute_backend_->BorrowStream(all_executors[i][0]));
+ TF_RETURN_IF_ERROR(RecordResult(*result_buffer, stream.get(),
execute_backend_->transfer_manager(),
executable_ptrs[i]->hlo_snapshot()));
// Dump out the ith snapshot.
@@ -895,12 +899,14 @@ Status Service::ExecuteGraph(const ExecuteGraphRequest* arg,
execute_backend_->default_stream_executor(),
/*device_allocator=*/nullptr));
+ TF_ASSIGN_OR_RETURN(auto stream,
+ execute_backend_->BorrowStream(
+ execute_backend_->default_stream_executor()));
if (executable->dumping_snapshot()) {
executable->hlo_snapshot()->set_execution_platform(
execute_backend_->platform()->Name());
TF_RETURN_IF_ERROR(RecordArguments(
- replicated_arguments.front(),
- execute_backend_->default_stream_executor(),
+ replicated_arguments.front(), stream.get(),
execute_backend_->transfer_manager(), executable->hlo_snapshot()));
}
@@ -914,9 +920,9 @@ Status Service::ExecuteGraph(const ExecuteGraphRequest* arg,
TF_ASSIGN_OR_RETURN(
const ShapedBuffer* result_buffer,
allocation_tracker_.ResolveForReplica(result->output(), 0));
- TF_RETURN_IF_ERROR(RecordResult(
- *result_buffer, execute_backend_->default_stream_executor(),
- execute_backend_->transfer_manager(), executable->hlo_snapshot()));
+ TF_RETURN_IF_ERROR(RecordResult(*result_buffer, stream.get(),
+ execute_backend_->transfer_manager(),
+ executable->hlo_snapshot()));
TF_RETURN_IF_ERROR(executable->DumpHloSnapshot());
}
@@ -954,14 +960,13 @@ Status Service::TransferToClient(const TransferToClientRequest* arg,
return_shape = &shaped_buffer->on_host_shape();
}
- TF_ASSIGN_OR_RETURN(
- se::StreamExecutor * executor,
- execute_backend_->stream_executor(shaped_buffer->device_ordinal()));
+ TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream(
+ shaped_buffer->device_ordinal()));
TF_ASSIGN_OR_RETURN(
std::unique_ptr<Literal> result_literal,
execute_backend_->transfer_manager()->TransferLiteralFromDevice(
- executor, *shaped_buffer));
+ stream.get(), *shaped_buffer));
if (LayoutUtil::LayoutsInShapesEqual(*return_shape,
result_literal->shape())) {
@@ -1011,9 +1016,10 @@ Status Service::TransferToServer(const TransferToServerRequest* arg,
execute_backend_->transfer_manager()->AllocateScopedShapedBuffer(
shape, execute_backend_->memory_allocator(),
executor->device_ordinal()));
+ TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream(executor));
TF_RETURN_IF_ERROR(
execute_backend_->transfer_manager()->TransferLiteralToDevice(
- executor, *literal, shaped_buffer));
+ stream.get(), *literal, shaped_buffer));
replicated_buffers.emplace_back(std::move(shaped_buffer));
}
TF_ASSIGN_OR_RETURN(*result->mutable_data(),
diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc
index c4d01562c4..4c5038a009 100644
--- a/tensorflow/compiler/xla/service/transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/transfer_manager.cc
@@ -22,8 +22,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/notification.h"
+
+using ::tensorflow::strings::StrCat;
namespace xla {
/* static */ tensorflow::mutex
@@ -36,8 +40,73 @@ TransferManager::GetPlatformTransferManagers() {
return r;
}
+StatusOr<std::unique_ptr<Literal>> TransferManager::TransferLiteralFromDevice(
+ se::Stream* stream, const ShapedBuffer& device_buffer) {
+ StatusOr<std::unique_ptr<Literal>> ret;
+ se::Stream* substream = stream->GetOrCreateSubStream();
+ auto cleanup = tensorflow::gtl::MakeCleanup(
+ [&]() { stream->ReturnSubStream(substream); });
+
+ tensorflow::Notification n;
+ TransferLiteralFromDevice(substream, device_buffer,
+ [&](StatusOr<std::unique_ptr<Literal>> arg) {
+ ret = std::move(arg);
+ n.Notify();
+ });
+ n.WaitForNotification();
+ return ret;
+}
+
+Status TransferManager::TransferLiteralToDevice(
+ se::Stream* stream, const LiteralSlice& literal,
+ const ShapedBuffer& device_buffer) {
+ // Implement the synchronous version by waiting on the asynchronous version.
+ // Use a substream so that if we are called from a HostCallback we don't
+ // deadlock.
+ se::Stream* substream = stream->GetOrCreateSubStream();
+ auto cleanup = tensorflow::gtl::MakeCleanup(
+ [&]() { stream->ReturnSubStream(substream); });
+ TF_RETURN_IF_ERROR(
+ TransferLiteralToDeviceAsync(substream, literal, device_buffer));
+ return substream->BlockHostUntilDone();
+}
+
+StatusOr<std::unique_ptr<Literal>> TransferManager::TransferArrayFromDevice(
+ se::Stream* stream, const Shape& shape,
+ const se::DeviceMemoryBase& source) {
+ // Implement the synchronous version by waiting on the asynchronous version.
+ // Use a substream so that if we are called from a HostCallback we don't
+ // deadlock.
+ StatusOr<std::unique_ptr<Literal>> ret;
+ se::Stream* substream = stream->GetOrCreateSubStream();
+ auto cleanup = tensorflow::gtl::MakeCleanup(
+ [&]() { stream->ReturnSubStream(substream); });
+
+ tensorflow::Notification n;
+ TransferArrayFromDevice(substream, shape, source,
+ [&](StatusOr<std::unique_ptr<Literal>> arg) {
+ ret = std::move(arg);
+ n.Notify();
+ });
+ n.WaitForNotification();
+ return ret;
+}
+
Status TransferManager::TransferArrayToDevice(
- se::StreamExecutor* executor, const LiteralSlice& literal,
+ se::Stream* stream, const LiteralSlice& literal,
+ const se::DeviceMemoryBase& dest) {
+ // Implement the synchronous version by waiting on the asynchronous version.
+ // Use a substream so that if we are called from a HostCallback we don't
+ // deadlock.
+ se::Stream* substream = stream->GetOrCreateSubStream();
+ auto cleanup = tensorflow::gtl::MakeCleanup(
+ [&]() { stream->ReturnSubStream(substream); });
+ TF_RETURN_IF_ERROR(TransferArrayToDeviceAsync(substream, literal, dest));
+ return substream->BlockHostUntilDone();
+}
+
+Status TransferManager::TransferArrayToDeviceAsync(
+ se::Stream* stream, const LiteralSlice& literal,
const se::DeviceMemoryBase& dest) {
const Shape on_device_shape = HostShapeToDeviceShape(literal.shape());
TF_RET_CHECK(ShapeUtil::IsArray(on_device_shape))
@@ -51,28 +120,32 @@ Status TransferManager::TransferArrayToDevice(
dest.size(), GetByteSizeRequirement(on_device_shape));
}
ShapedBuffer shaped_buffer(/*on_host_shape=*/literal.shape(), on_device_shape,
- executor->platform(), executor->device_ordinal());
+ stream->parent()->platform(),
+ stream->parent()->device_ordinal());
shaped_buffer.set_buffer(dest, /*index=*/{});
- return TransferLiteralToDevice(executor, literal, shaped_buffer);
+ return TransferLiteralToDevice(stream, literal, shaped_buffer);
}
-StatusOr<std::unique_ptr<Literal>> TransferManager::TransferArrayFromDevice(
- se::StreamExecutor* executor, const Shape& shape,
- const se::DeviceMemoryBase& source) {
- TF_RET_CHECK(ShapeUtil::Equal(HostShapeToDeviceShape(shape), shape))
- << "Shape " << ShapeUtil::HumanString(shape)
- << " has a differently shaped representation on-device: "
- << ShapeUtil::HumanString(HostShapeToDeviceShape(shape));
+void TransferManager::TransferArrayFromDevice(
+ se::Stream* stream, const Shape& shape, const se::DeviceMemoryBase& source,
+ std::function<void(StatusOr<std::unique_ptr<Literal>>)> done) {
+ if (!ShapeUtil::Equal(HostShapeToDeviceShape(shape), shape)) {
+ auto error = StrCat("Shape ", ShapeUtil::HumanString(shape),
+ " has a differently shaped representation on-device: ",
+ ShapeUtil::HumanString(HostShapeToDeviceShape(shape)));
+ return done(FailedPrecondition("%s", error.c_str()));
+ }
if (source.size() < GetByteSizeRequirement(shape)) {
- return FailedPrecondition(
- "Allocation on device not large enough for array: "
- "%lld < %lld",
- source.size(), GetByteSizeRequirement(shape));
+ return done(
+ FailedPrecondition("Allocation on device not large enough for array: "
+ "%lld < %lld",
+ source.size(), GetByteSizeRequirement(shape)));
}
ShapedBuffer shaped_buffer(/*on_host_shape=*/shape, shape,
- executor->platform(), executor->device_ordinal());
+ stream->parent()->platform(),
+ stream->parent()->device_ordinal());
shaped_buffer.set_buffer(source, /*index=*/{});
- return TransferLiteralFromDevice(executor, shaped_buffer);
+ return TransferLiteralFromDevice(stream, shaped_buffer, std::move(done));
}
/* static */ void TransferManager::RegisterTransferManager(
@@ -108,10 +181,14 @@ StatusOr<std::unique_ptr<Literal>> TransferManager::TransferArrayFromDevice(
}
Status TransferManager::WriteTupleIndexTables(
- se::StreamExecutor* executor, const ShapedBuffer& device_buffer) {
- VLOG(2) << "Writing tuple index tables for " << device_buffer;
+ se::Stream* stream, const ShapedBuffer& device_buffer) {
+ TF_RETURN_IF_ERROR(WriteTupleIndexTablesAsync(stream, device_buffer));
+ return stream->BlockHostUntilDone();
+}
- TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal());
+Status TransferManager::WriteTupleIndexTablesAsync(
+ se::Stream* stream, const ShapedBuffer& device_buffer) {
+ VLOG(2) << "Writing tuple index tables for " << device_buffer;
return ShapeUtil::ForEachSubshapeWithStatus(
device_buffer.on_device_shape(),
@@ -129,7 +206,7 @@ Status TransferManager::WriteTupleIndexTables(
elements.push_back(device_buffer.buffer(element_index));
element_index.pop_back();
}
- return WriteSingleTupleIndexTable(executor, elements, device_subshape,
+ return WriteSingleTupleIndexTable(stream, elements, device_subshape,
&device_memory);
}
@@ -138,26 +215,20 @@ Status TransferManager::WriteTupleIndexTables(
}
Status TransferManager::TransferBufferFromDevice(
- se::StreamExecutor* executor, const se::DeviceMemoryBase& source,
- int64 size, void* destination) {
+ se::Stream* stream, const se::DeviceMemoryBase& source, int64 size,
+ void* destination) {
if (source.size() < size) {
return FailedPrecondition(
"Source allocation on device not large enough for data tranfer: "
"%lld < %lld",
source.size(), size);
}
- auto copy_status = executor->SynchronousMemcpyD2H(source, size, destination);
- if (!copy_status.ok()) {
- return AddStatus(
- Status(static_cast<tensorflow::error::Code>(copy_status.code()),
- copy_status.error_message()),
- "failed transfer from device to buffer");
- }
+ stream->ThenMemcpy(destination, source, size);
return Status::OK();
}
Status TransferManager::TransferBufferToDevice(
- se::StreamExecutor* executor, int64 size, const void* source,
+ se::Stream* stream, int64 size, const void* source,
se::DeviceMemoryBase* destination) {
if (destination->size() < size) {
return FailedPrecondition(
@@ -165,13 +236,7 @@ Status TransferManager::TransferBufferToDevice(
"%lld < %lld",
destination->size(), size);
}
- auto copy_status = executor->SynchronousMemcpyH2D(source, size, destination);
- if (!copy_status.ok()) {
- return AddStatus(
- Status(static_cast<tensorflow::error::Code>(copy_status.code()),
- copy_status.error_message()),
- "failed transfer of buffer to device");
- }
+ stream->ThenMemcpy(destination, source, size);
return Status::OK();
}
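
The "substream so we don't deadlock" comments above are worth unpacking: a host callback executes as a node in the stream's own queue, so blocking that same stream from inside the callback would wait on work queued behind the callback itself. Blocking only a substream keeps the wait independent of the parent queue. An illustrative sketch; the names are hypothetical and `buffer` is assumed to outlive the callback:

    // A host callback on `stream` that takes a synchronous snapshot. The
    // synchronous TransferLiteralFromDevice above blocks only a substream,
    // so this cannot deadlock against the queue running the callback.
    void SnapshotWhenReached(TransferManager* tm, se::Stream* stream,
                             const ShapedBuffer& buffer) {
      stream->ThenDoHostCallback([tm, stream, &buffer] {
        auto literal = tm->TransferLiteralFromDevice(stream, buffer);
        if (literal.ok()) {
          VLOG(1) << "snapshot: " << literal.ValueOrDie()->ToString();
        }
      });
    }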
diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h
index 43a8092b06..e384359642 100644
--- a/tensorflow/compiler/xla/service/transfer_manager.h
+++ b/tensorflow/compiler/xla/service/transfer_manager.h
@@ -52,30 +52,65 @@ class TransferManager {
return host_shape;
}
- // Returns a literal containing the data held in the given ShapedBuffer.
- // using the provided executor. The optional literal_shape will be the shape
- // for the literal. The shape of the ShapedBuffer and
- // DeviceShape(literal_shape) must be compatible, but need not have the same
- // layout.
+ // Returns a literal containing the data held in the given ShapedBuffer
+  // using the provided stream. This operation is performed synchronously
+ // without waiting for any other operation on a stream to complete.
+ //
+ // This function should be avoided in favor of the asynchronous version below.
virtual StatusOr<std::unique_ptr<Literal>> TransferLiteralFromDevice(
- se::StreamExecutor* executor, const ShapedBuffer& device_buffer) = 0;
+ se::Stream* stream, const ShapedBuffer& device_buffer);
+
+ // Begins transferring a literal containing the data held in the given
+  // ShapedBuffer using the provided stream.
+ //
+ // This operation is performed asynchronously on the given stream. It returns
+ // once the transfer is enqueued. 'done' is invoked with the result when
+ // complete.
+ //
+ // device_buffer is copied by reference and must live at least until done() is
+ // invoked.
+ virtual void TransferLiteralFromDevice(
+ se::Stream* stream, const ShapedBuffer& device_buffer,
+ std::function<void(StatusOr<std::unique_ptr<Literal>>)> done) = 0;
// Transfers the given literal into the previously allocated device memory
// represented by the given ShapedBuffer using the given executor. The shape
// of the ShapedBuffer and DeviceShape(literal.shape()) must be compatible,
- // but need not have the same layout
- virtual Status TransferLiteralToDevice(se::StreamExecutor* executor,
+ // but need not have the same layout.
+ //
+ // This operation is performed synchronously without waiting for any other
+ // operation on a stream to complete. This function should be avoided in favor
+ // of the asynchronous version below.
+ virtual Status TransferLiteralToDevice(se::Stream* stream,
const LiteralSlice& literal,
- const ShapedBuffer& device_buffer) = 0;
+ const ShapedBuffer& device_buffer);
+
+ // Transfers the given literal into the previously allocated device memory
+  // represented by the given ShapedBuffer using the given stream. The shape
+ // of the ShapedBuffer and DeviceShape(literal.shape()) must be compatible,
+ // but need not have the same layout.
+ //
+ // This operation is performed asynchronously on the given stream. It returns
+ // once the transfer is enqueued.
+ virtual Status TransferLiteralToDeviceAsync(
+ se::Stream* stream, const LiteralSlice& literal,
+ const ShapedBuffer& device_buffer) = 0;
// Convenience methods for transferring an array to or from the device at a
// known address. This avoids having to construct a ShapedBuffer just to
// transfer an array at a known address.
- Status TransferArrayToDevice(se::StreamExecutor* executor,
- const LiteralSlice& literal,
+ Status TransferArrayToDevice(se::Stream* stream, const LiteralSlice& literal,
const se::DeviceMemoryBase& dest);
+ void TransferArrayFromDevice(
+ se::Stream* stream, const Shape& shape,
+ const se::DeviceMemoryBase& source,
+ std::function<void(StatusOr<std::unique_ptr<Literal>>)> done);
+
+ Status TransferArrayToDeviceAsync(se::Stream* stream,
+ const LiteralSlice& literal,
+ const se::DeviceMemoryBase& dest);
StatusOr<std::unique_ptr<Literal>> TransferArrayFromDevice(
- se::StreamExecutor* executor, const Shape& shape,
+ se::Stream* stream, const Shape& shape,
const se::DeviceMemoryBase& source);
// Transfers the given literal into the Infeed interface of the device,
@@ -96,8 +131,10 @@ class TransferManager {
// Given an allocated ShapedBuffer, constructs the tuple index table(s) in
// each buffer of the given ShapedBuffer corresponding to tuple shapes. If the
// ShapedBuffer is array-shaped this method does nothing.
- Status WriteTupleIndexTables(se::StreamExecutor* executor,
+ Status WriteTupleIndexTables(se::Stream* stream,
const ShapedBuffer& device_buffer);
+ Status WriteTupleIndexTablesAsync(se::Stream* stream,
+ const ShapedBuffer& device_buffer);
// Determines the byte size requirement for the given shape on the underlying
// architecture. This will be used to allocate an appropriately sized memory
@@ -144,7 +181,7 @@ class TransferManager {
// 'destination' buffer.
//
// size is the size to transfer to destination in bytes.
- virtual Status TransferBufferFromDevice(se::StreamExecutor* executor,
+ virtual Status TransferBufferFromDevice(se::Stream* stream,
const se::DeviceMemoryBase& source,
int64 size, void* destination);
@@ -152,15 +189,15 @@ class TransferManager {
// destination of the device.
//
// size is the size to transfer from source in bytes.
- virtual Status TransferBufferToDevice(se::StreamExecutor* executor,
- int64 size, const void* source,
+ virtual Status TransferBufferToDevice(se::Stream* stream, int64 size,
+ const void* source,
se::DeviceMemoryBase* destination);
// Writes the given device-memory pointers in 'elements' to the given region
// to construct a tuple index table in the platform-specific tuple
// representation.
virtual Status WriteSingleTupleIndexTable(
- se::StreamExecutor* executor,
+ se::Stream* stream,
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> elements,
const Shape& shape, se::DeviceMemoryBase* region) = 0;
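
A hedged usage sketch for the callback-based overload declared above; per the header comment, device_buffer must stay alive until done is invoked, and the callback may run on another thread:

    // Enqueue a device-to-host transfer and log the literal when it lands.
    void ReadResultAsync(TransferManager* tm, se::Stream* stream,
                         const ShapedBuffer& device_buffer) {
      tm->TransferLiteralFromDevice(
          stream, device_buffer,
          [](StatusOr<std::unique_ptr<Literal>> result) {
            if (!result.ok()) {
              LOG(ERROR) << "transfer failed: " << result.status();
              return;
            }
            VLOG(1) << "literal: " << result.ValueOrDie()->ToString();
          });
    }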
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index 51d45b2be6..e9d7178e3d 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -380,6 +380,13 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
return shape.tuple_shapes(index);
}
+/* static */ int64 ShapeUtil::SubshapeCount(const Shape& shape) {
+ int64 n = 0;
+ ForEachSubshape(shape, [&](const Shape& literal_subshape,
+ const ShapeIndex& index) { ++n; });
+ return n;
+}
+
/* static */ Shape ShapeUtil::SliceTuple(const Shape& tuple, int64 start,
int64 limit) {
TF_DCHECK_OK(ValidateShapeWithOptionalLayout(tuple));
@@ -422,7 +429,6 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
return shape.element_type() == F32 && Rank(shape) == 0;
}
-
namespace {
// Class to memoize the computation of
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index 25ed70316b..b7543c2026 100644
--- a/tensorflow/compiler/xla/shape_util.h
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -457,6 +457,9 @@ class ShapeUtil {
// Precondition: IsTuple(shape) && TupleElementCount(shape) > index
static const Shape& GetTupleElementShape(const Shape& shape, int64 index);
+  // Returns the number of subshapes, recursively, in the given shape.
+ static int64 SubshapeCount(const Shape& shape);
+
// Slices tuple elements in the range [start, limit) and returns a new tuple
// shape. E.g. a tuple like (f32, s32, u32) would slice via 1,3 to (s32, u32).
static Shape SliceTuple(const Shape& tuple, int64 start, int64 limit);
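
To pin down what SubshapeCount counts, a small sketch: the traversal visits the root shape and every nested tuple shape as well as the leaf arrays.

    // An f32[2,2] leaf plus a nested tuple holding s32[3]: four subshapes
    // in total (outer tuple, f32[2,2], inner tuple, s32[3]).
    Shape s = ShapeUtil::MakeTupleShape(
        {ShapeUtil::MakeShape(F32, {2, 2}),
         ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {3})})});
    CHECK_EQ(ShapeUtil::SubshapeCount(s), 4);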
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index e7e0a19db0..b76830f666 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -1986,6 +1986,7 @@ xla_test(
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "//tensorflow/core:test",
],
)
diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
index 49f3a10d22..a918c91f07 100644
--- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
@@ -716,8 +716,10 @@ void BM_DynamicSlice(int num_iters) {
.ConsumeValueOrDie();
auto start_indices_literal = Literal::CreateR1<int32>({0, 1, 2, 3});
+ auto stream =
+ client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie();
ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice(
- executors[device_ordinal], *start_indices_literal, buffer));
+ stream.get(), *start_indices_literal, buffer));
std::unique_ptr<LocalExecutable> executable =
client
diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc
index 96858c00d6..5a70c2a9ae 100644
--- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc
+++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc
@@ -209,13 +209,12 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResult) {
EXPECT_EQ(3, ShapeUtil::TupleElementCount(result.on_host_shape()));
std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
- LiteralTestUtil::ExpectR2Equal<float>(
- {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {0}));
- LiteralTestUtil::ExpectR2Equal<float>(
- {{10.0f, 20.0f}, {30.0f, 40.0f}},
- LiteralSlice(*result_literal, {1}));
- LiteralTestUtil::ExpectR2Equal<float>(
- {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {2}));
+ LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
+ LiteralSlice(*result_literal, {0}));
+ LiteralTestUtil::ExpectR2Equal<float>({{10.0f, 20.0f}, {30.0f, 40.0f}},
+ LiteralSlice(*result_literal, {1}));
+ LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
+ LiteralSlice(*result_literal, {2}));
}
XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) {
@@ -238,17 +237,14 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) {
EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.on_host_shape()));
std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
- LiteralTestUtil::ExpectR2Equal<float>(
- {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {1}));
- LiteralTestUtil::ExpectR2Equal<float>(
- {{1.0f, 2.0f}, {3.0f, 4.0f}},
- LiteralSlice(*result_literal, {0, 0}));
- LiteralTestUtil::ExpectR2Equal<float>(
- {{10.0f, 20.0f}, {30.0f, 40.0f}},
- LiteralSlice(*result_literal, {0, 1}));
- LiteralTestUtil::ExpectR2Equal<float>(
- {{1.0f, 2.0f}, {3.0f, 4.0f}},
- LiteralSlice(*result_literal, {0, 2}));
+ LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
+ LiteralSlice(*result_literal, {1}));
+ LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
+ LiteralSlice(*result_literal, {0, 0}));
+ LiteralTestUtil::ExpectR2Equal<float>({{10.0f, 20.0f}, {30.0f, 40.0f}},
+ LiteralSlice(*result_literal, {0, 1}));
+ LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
+ LiteralSlice(*result_literal, {0, 2}));
}
XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) {
@@ -273,10 +269,10 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) {
options, DefaultExecutableRunOptions());
std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
- LiteralTestUtil::ExpectR2Equal<float>(
- {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {0}));
- LiteralTestUtil::ExpectR2Equal<float>(
- {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {1}));
+ LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
+ LiteralSlice(*result_literal, {0}));
+ LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
+ LiteralSlice(*result_literal, {1}));
}
XLA_TEST_F(LocalClientExecuteTest, TupleArguments) {
@@ -319,11 +315,10 @@ XLA_TEST_F(LocalClientExecuteTest, TupleArguments) {
EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.on_host_shape()));
std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
- LiteralTestUtil::ExpectR2Equal<float>(
- {{56.0f, 46.0f}, {36.0f, 26.0f}},
- LiteralSlice(*result_literal, {0}));
- LiteralTestUtil::ExpectR1Equal<float>(
- {40.0f, 71.0f, 117.0f}, LiteralSlice(*result_literal, {1}));
+ LiteralTestUtil::ExpectR2Equal<float>({{56.0f, 46.0f}, {36.0f, 26.0f}},
+ LiteralSlice(*result_literal, {0}));
+ LiteralTestUtil::ExpectR1Equal<float>({40.0f, 71.0f, 117.0f},
+ LiteralSlice(*result_literal, {1}));
}
XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) {
@@ -360,10 +355,10 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) {
ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer});
std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
- LiteralTestUtil::ExpectR2Equal<float>(
- {{-1.0, -2.0}, {-3.0, -4}}, LiteralSlice(*result_literal, {0}));
- LiteralTestUtil::ExpectR1Equal<float>(
- {264.0, 73.0, 133.0}, LiteralSlice(*result_literal, {1}));
+ LiteralTestUtil::ExpectR2Equal<float>({{-1.0, -2.0}, {-3.0, -4}},
+ LiteralSlice(*result_literal, {0}));
+ LiteralTestUtil::ExpectR1Equal<float>({264.0, 73.0, 133.0},
+ LiteralSlice(*result_literal, {1}));
}
XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) {
@@ -389,18 +384,17 @@ XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) {
ScopedShapedBuffer result_0 = ExecuteLocallyOrDie(computation, {&arg_buffer});
std::unique_ptr<Literal> result_0_literal = ShapedBufferToLiteral(result_0);
- LiteralTestUtil::ExpectR2Equal<float>(
- {{-1.0, -2.0}, {-3.0, -4.0}},
- LiteralSlice(*result_0_literal, {0}));
- LiteralTestUtil::ExpectR2Equal<float>(
- {{22.0, 6.0}, {8.0, 10}}, LiteralSlice(*result_0_literal, {1}));
+ LiteralTestUtil::ExpectR2Equal<float>({{-1.0, -2.0}, {-3.0, -4.0}},
+ LiteralSlice(*result_0_literal, {0}));
+ LiteralTestUtil::ExpectR2Equal<float>({{22.0, 6.0}, {8.0, 10}},
+ LiteralSlice(*result_0_literal, {1}));
ScopedShapedBuffer result_1 = ExecuteLocallyOrDie(computation, {&result_0});
std::unique_ptr<Literal> result_1_literal = ShapedBufferToLiteral(result_1);
- LiteralTestUtil::ExpectR2Equal<float>(
- {{1.0, 2.0}, {3.0, 4.0}}, LiteralSlice(*result_1_literal, {0}));
- LiteralTestUtil::ExpectR2Equal<float>(
- {{44.0, 12.0}, {16.0, 20}}, LiteralSlice(*result_1_literal, {1}));
+ LiteralTestUtil::ExpectR2Equal<float>({{1.0, 2.0}, {3.0, 4.0}},
+ LiteralSlice(*result_1_literal, {0}));
+ LiteralTestUtil::ExpectR2Equal<float>({{44.0, 12.0}, {16.0, 20}},
+ LiteralSlice(*result_1_literal, {1}));
}
XLA_TEST_F(LocalClientExecuteTest, LargeTuple) {
@@ -447,8 +441,7 @@ XLA_TEST_F(LocalClientExecuteTest, LargeTuple) {
for (int i = 0; i < kElementCount; ++i) {
LiteralTestUtil::ExpectR1Near<float>(
- {2.0f * i, 0.0f}, LiteralSlice(*result_literal, {i}),
- error_spec_);
+ {2.0f * i, 0.0f}, LiteralSlice(*result_literal, {i}), error_spec_);
}
}
@@ -547,8 +540,8 @@ XLA_TEST_F(LocalClientExecuteTest, DeepTuple) {
for (int i = 0; i < kTupleDepth; ++i) {
index.push_back(0);
}
- LiteralTestUtil::ExpectR0Equal<float>(
- 165.0, LiteralSlice(*result_literal, index));
+ LiteralTestUtil::ExpectR0Equal<float>(165.0,
+ LiteralSlice(*result_literal, index));
}
XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) {
@@ -753,10 +746,10 @@ XLA_TEST_F(LocalClientExecuteTest, SelectBetweenTuples) {
ScopedShapedBuffer result =
ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {});
std::unique_ptr<Literal> tuple_literal = ShapedBufferToLiteral(result);
- LiteralTestUtil::ExpectR1Equal<float>(
- {2.0f, 4.0f, 6.0f}, LiteralSlice(*tuple_literal, {0}));
- LiteralTestUtil::ExpectR1Equal<float>(
- {1.0f, 2.0f, 3.0f}, LiteralSlice(*tuple_literal, {1}));
+ LiteralTestUtil::ExpectR1Equal<float>({2.0f, 4.0f, 6.0f},
+ LiteralSlice(*tuple_literal, {0}));
+ LiteralTestUtil::ExpectR1Equal<float>({1.0f, 2.0f, 3.0f},
+ LiteralSlice(*tuple_literal, {1}));
}
XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) {
@@ -900,8 +893,10 @@ void BM_LocalClientOverhead(int num_iters) {
->AllocateScopedShapedBuffer(shape, &allocator, /*device_ordinal=*/0)
.ConsumeValueOrDie();
auto literal = Literal::CreateR2<float>({{0, 0, 0}, {0, 0, 0}});
- ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice(
- executors[device_ordinal], *literal, buffer));
+ auto stream =
+ client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie();
+ ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice(stream.get(), *literal,
+ buffer));
const int kWarmups = 2;
@@ -911,11 +906,8 @@ void BM_LocalClientOverhead(int num_iters) {
std::unique_ptr<LocalExecutable> executable =
executable_status.ConsumeValueOrDie();
- se::Stream stream(executors[client->default_device_ordinal()]);
- stream.Init();
-
ExecutableRunOptions run_options;
- run_options.set_allocator(&allocator).set_stream(&stream);
+ run_options.set_allocator(&allocator).set_stream(stream.get());
for (int i = 0; i < kWarmups; ++i) {
auto result = executable->Run({&buffer}, run_options);
diff --git a/tensorflow/compiler/xla/tests/transfer_manager_test.cc b/tensorflow/compiler/xla/tests/transfer_manager_test.cc
index 0063e7ad41..85799d4cfb 100644
--- a/tensorflow/compiler/xla/tests/transfer_manager_test.cc
+++ b/tensorflow/compiler/xla/tests/transfer_manager_test.cc
@@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -41,7 +42,12 @@ class TransferManagerTest : public LocalClientTestBase {
TransferManagerTest()
: shape_size_fn_([this](const Shape& shape) {
return transfer_manager_->GetByteSizeRequirement(shape);
- }) {}
+ }) {
+ stream_ptr_ = local_client_->mutable_backend()
+ ->BorrowStream(stream_executor_)
+ .ValueOrDie();
+ stream_ = stream_ptr_.get();
+ }
~TransferManagerTest() override = default;
@@ -53,6 +59,10 @@ class TransferManagerTest : public LocalClientTestBase {
.ValueOrDie();
}
+ protected:
+ Backend::StreamPtr stream_ptr_;
+ se::Stream* stream_;
+
private:
std::function<int64(const Shape&)> shape_size_fn_;
};
@@ -63,11 +73,11 @@ XLA_TEST_F(TransferManagerTest, TransferR0U32) {
auto device_buffer = AllocateDeviceBuffer(shape);
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(
- stream_executor_, *literal, device_buffer));
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
- transfer_manager_->TransferLiteralFromDevice(
- stream_executor_, device_buffer));
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ device_buffer));
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<Literal> result,
+ transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
LiteralTestUtil::ExpectR0Equal<uint32>(42, *result);
}
@@ -79,11 +89,11 @@ XLA_TEST_F(TransferManagerTest, TransferR1F32) {
auto device_buffer = AllocateDeviceBuffer(shape);
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(
- stream_executor_, *literal, device_buffer));
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
- transfer_manager_->TransferLiteralFromDevice(
- stream_executor_, device_buffer));
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ device_buffer));
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<Literal> result,
+ transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
LiteralTestUtil::ExpectR1Equal<float>({1.25f, 2.5f, -17.0f, -20.125f},
*result);
@@ -97,11 +107,11 @@ XLA_TEST_F(TransferManagerTest, TransferR1LargeF32) {
auto device_buffer = AllocateDeviceBuffer(shape);
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(
- stream_executor_, *literal, device_buffer));
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
- transfer_manager_->TransferLiteralFromDevice(
- stream_executor_, device_buffer));
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ device_buffer));
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<Literal> result,
+ transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
LiteralTestUtil::ExpectR1Equal<float>(test_vector, *result);
}
@@ -113,11 +123,11 @@ XLA_TEST_F(TransferManagerTest, TransferR1U8) {
auto device_buffer = AllocateDeviceBuffer(shape);
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(
- stream_executor_, *literal, device_buffer));
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
- transfer_manager_->TransferLiteralFromDevice(
- stream_executor_, device_buffer));
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ device_buffer));
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<Literal> result,
+ transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
EXPECT_EQ(result->GetR1U8AsString(), test_string);
}
@@ -129,11 +139,11 @@ XLA_TEST_F(TransferManagerTest, TransferR2F32) {
auto device_buffer = AllocateDeviceBuffer(shape);
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(
- stream_executor_, *literal, device_buffer));
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
- transfer_manager_->TransferLiteralFromDevice(
- stream_executor_, device_buffer));
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ device_buffer));
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<Literal> result,
+ transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
LiteralTestUtil::ExpectR2Equal<float>(
{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, *result);
@@ -149,11 +159,11 @@ XLA_TEST_F(TransferManagerTest,
// Round trip literal through device. Set the on-device layout to something
// different than the literal layout.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(
- stream_executor_, *literal, device_buffer));
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
- transfer_manager_->TransferLiteralFromDevice(
- stream_executor_, device_buffer));
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ device_buffer));
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<Literal> result,
+ transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
EXPECT_FALSE(
LayoutUtil::Equal(result->shape().layout(), literal->shape().layout()));
@@ -169,11 +179,11 @@ XLA_TEST_F(TransferManagerTest, TransferTuple) {
auto device_buffer = AllocateDeviceBuffer(literal->shape());
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(
- stream_executor_, *literal, device_buffer));
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
- transfer_manager_->TransferLiteralFromDevice(
- stream_executor_, device_buffer));
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ device_buffer));
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<Literal> result,
+ transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
}
@@ -183,11 +193,11 @@ XLA_TEST_F(TransferManagerTest, TransferEmptyTuple) {
auto device_buffer = AllocateDeviceBuffer(literal->shape());
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(
- stream_executor_, *literal, device_buffer));
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
- transfer_manager_->TransferLiteralFromDevice(
- stream_executor_, device_buffer));
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ device_buffer));
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<Literal> result,
+ transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
}
@@ -203,11 +213,11 @@ XLA_TEST_F(TransferManagerTest, TransferNestedTuple) {
auto device_buffer = AllocateDeviceBuffer(literal->shape());
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(
- stream_executor_, *literal, device_buffer));
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
- transfer_manager_->TransferLiteralFromDevice(
- stream_executor_, device_buffer));
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ device_buffer));
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<Literal> result,
+ transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
}
@@ -218,11 +228,11 @@ XLA_TEST_F(TransferManagerTest, TransferComplexValue) {
auto device_buffer = AllocateDeviceBuffer(literal->shape());
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(
- stream_executor_, *literal, device_buffer));
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
- transfer_manager_->TransferLiteralFromDevice(
- stream_executor_, device_buffer));
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ device_buffer));
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<Literal> result,
+ transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
}
@@ -237,14 +247,150 @@ XLA_TEST_F(TransferManagerTest, TransferComplexValueInTuple) {
auto device_buffer = AllocateDeviceBuffer(literal->shape());
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(
- stream_executor_, *literal, device_buffer));
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
- transfer_manager_->TransferLiteralFromDevice(
- stream_executor_, device_buffer));
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ device_buffer));
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<Literal> result,
+ transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
}
+XLA_TEST_F(TransferManagerTest, MultiStreamRoundTripSoak) {
+ const int64 kIterationCount = 5000;
+ std::unique_ptr<Literal> literal1 = Literal::MakeTuple(
+ {Literal::CreateR0<float>(123.0f).get(),
+ Literal::MakeTuple(
+ {Literal::CreateR2<float>({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(),
+ Literal::CreateR1<float>({44.0f, -10.0f, 3333333.3f}).get()})
+ .get(),
+ Literal::CreateR1<float>({-10.0f, 123.0f}).get()});
+ std::unique_ptr<Literal> literal2 = Literal::MakeTuple(
+ {Literal::CreateR0<float>(456.0f).get(),
+ Literal::MakeTuple(
+ {Literal::CreateR2<float>({{5.0f, 7.0f}, {9.0f, 4.0f}}).get(),
+ Literal::CreateR1<float>({44.0f, -11.0f, 3333333.3f}).get()})
+ .get(),
+ Literal::CreateR1<float>({-98.0f, 153.0f}).get()});
+
+ auto device_buffer1 = AllocateDeviceBuffer(literal1->shape());
+ auto device_buffer2 = AllocateDeviceBuffer(literal2->shape());
+
+ auto stream1 = stream_;
+ auto stream2 = stream_->GetOrCreateSubStream();
+
+ std::unique_ptr<Literal> result1, result2;
+
+ // Round trip literals through device in multiple streams asynchronously.
+ for (int i = 0; i < kIterationCount; ++i) {
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream1, *literal1,
+ device_buffer1));
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream2, *literal2,
+ device_buffer2));
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<Literal> this_result1,
+ transfer_manager_->TransferLiteralFromDevice(stream1, device_buffer1));
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<Literal> this_result2,
+ transfer_manager_->TransferLiteralFromDevice(stream2, device_buffer2));
+ result1 = std::move(this_result1);
+ result2 = std::move(this_result2);
+ }
+
+ EXPECT_TRUE(LiteralTestUtil::Equal(*literal1, *result1));
+ EXPECT_TRUE(LiteralTestUtil::Equal(*literal2, *result2));
+}
+
+class TransferDeviceToHostBenchmark : public TransferManagerTest {
+ public:
+ using TransferManagerTest::TransferManagerTest;
+ ~TransferDeviceToHostBenchmark() override {}
+
+ void Run(int iters, int num_tuple_elements, int array_size) {
+ tensorflow::testing::StopTiming();
+ SetUp();
+
+ std::vector<std::unique_ptr<Literal>> tuple_elements;
+ for (int i = 0; i < num_tuple_elements; ++i) {
+ tuple_elements.push_back(
+ Literal::CreateR2F32Linspace(0.0f, 1.0f, array_size, array_size));
+ }
+ std::unique_ptr<Literal> literal =
+ Literal::MakeTupleOwned(std::move(tuple_elements));
+ auto device_buffer = AllocateDeviceBuffer(literal->shape());
+ TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ device_buffer));
+ tensorflow::testing::StartTiming();
+ for (int i = 0; i < iters; ++i) {
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<Literal> result,
+ transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
+ }
+ tensorflow::testing::StopTiming();
+ TearDown();
+ }
+
+ void TestBody() override {}
+};
+
+class TransferHostToDeviceBenchmark : public TransferManagerTest {
+ public:
+ using TransferManagerTest::TransferManagerTest;
+ ~TransferHostToDeviceBenchmark() override {}
+
+ void Run(int iters, int num_tuple_elements, int array_size) {
+ tensorflow::testing::StopTiming();
+ SetUp();
+
+ std::vector<std::unique_ptr<Literal>> tuple_elements;
+ for (int i = 0; i < num_tuple_elements; ++i) {
+ tuple_elements.push_back(
+ Literal::CreateR2F32Linspace(0.0f, 1.0f, array_size, array_size));
+ }
+ std::unique_ptr<Literal> literal =
+ Literal::MakeTupleOwned(std::move(tuple_elements));
+ auto device_buffer = AllocateDeviceBuffer(literal->shape());
+ tensorflow::testing::StartTiming();
+ for (int i = 0; i < iters; ++i) {
+ TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ device_buffer));
+ }
+ tensorflow::testing::StopTiming();
+ TearDown();
+ }
+
+ void TestBody() override {}
+};
+
+void BM_TransferDeviceToHost(int iters, int num_tuple_elements,
+ int array_size) {
+ TransferDeviceToHostBenchmark bm;
+ bm.Run(iters, num_tuple_elements, array_size);
+}
+
+void BM_TransferHostToDevice(int iters, int num_tuple_elements,
+ int array_size) {
+ TransferHostToDeviceBenchmark bm;
+ bm.Run(iters, num_tuple_elements, array_size);
+}
+
+BENCHMARK(BM_TransferHostToDevice)
+ ->ArgPair(1, 256)
+ ->ArgPair(1, 257)
+ ->ArgPair(100, 256)
+ ->ArgPair(100, 257);
+
+BENCHMARK(BM_TransferDeviceToHost)
+ ->ArgPair(1, 256)
+ ->ArgPair(1, 257)
+ ->ArgPair(100, 256)
+ ->ArgPair(100, 257);
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ tensorflow::testing::RunBenchmarks();
+ return RUN_ALL_TESTS();
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
index 3c9a01653c..0be950cacb 100644
--- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
+++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
@@ -128,20 +128,23 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client,
se::StreamExecutor* executor = backend->default_stream_executor();
DeviceMemoryAllocator* allocator = backend->memory_allocator();
auto* transfer_manager = backend->transfer_manager();
+ TF_ASSERT_OK_AND_ASSIGN(
+ Backend::StreamPtr stream_ptr,
+ backend->BorrowStream(backend->default_device_ordinal()));
TF_ASSERT_OK_AND_ASSIGN(
ScopedShapedBuffer lhs_arg,
transfer_manager->AllocateScopedShapedBuffer(
lhs_arg_shape, allocator, backend->default_device_ordinal()));
TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice(
- executor, *Literal::CreateFromShape(lhs_arg_shape), lhs_arg));
+ stream_ptr.get(), *Literal::CreateFromShape(lhs_arg_shape), lhs_arg));
TF_ASSERT_OK_AND_ASSIGN(
ScopedShapedBuffer rhs_arg,
transfer_manager->AllocateScopedShapedBuffer(
rhs_arg_shape, allocator, backend->default_device_ordinal()));
TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice(
- executor, *Literal::CreateFromShape(rhs_arg_shape), rhs_arg));
+ stream_ptr.get(), *Literal::CreateFromShape(rhs_arg_shape), rhs_arg));
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<LocalExecutable> local_executable,
@@ -153,9 +156,6 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client,
&executable->hlo_profile_printer_data(),
&executable->hlo_profile_index_map());
- TF_ASSERT_OK_AND_ASSIGN(
- Backend::StreamPtr stream_ptr,
- backend->BorrowStream(backend->default_device_ordinal()));
ExecutableRunOptions exec_run_options;
exec_run_options.set_stream(stream_ptr.get());
exec_run_options.set_allocator(backend->memory_allocator());
diff --git a/tensorflow/compiler/xla/tests/xla_internal_test_main.cc b/tensorflow/compiler/xla/tests/xla_internal_test_main.cc
index a9f2915b45..a075195618 100644
--- a/tensorflow/compiler/xla/tests/xla_internal_test_main.cc
+++ b/tensorflow/compiler/xla/tests/xla_internal_test_main.cc
@@ -49,6 +49,7 @@ GTEST_API_ int main(int argc, char** argv) {
}
// Unfortunately Google's internal benchmark infrastructure has a
// different API than Tensorflow's.
+ testing::InitGoogleTest(&argc, argv);
#if defined(PLATFORM_GOOGLE)
base::SetFlag(&FLAGS_benchmarks, pattern);
RunSpecifiedBenchmarks();