diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/cpu/cpu_executable.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/cpu/cpu_executable.cc | 38 |
1 files changed, 18 insertions, 20 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index 946f5124b8..c376864c3e 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -249,24 +249,11 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments, HloExecutionProfile* hlo_execution_profile) { - if (GetRootPointsToSet().IsAmbiguous()) { - return Unimplemented("Points-to set of root instruction is ambiguous"); - } - - se::Stream* stream = run_options->stream(); - DeviceMemoryAllocator* memory_allocator = run_options->allocator(); - - std::vector<OwningDeviceMemory> owning_buffers; - std::vector<se::DeviceMemoryBase> unowning_buffers; TF_ASSIGN_OR_RETURN( - std::tie(unowning_buffers, owning_buffers), - CreateTempArray(memory_allocator, stream->parent()->device_ordinal(), - arguments)); - - TF_RETURN_IF_ERROR(ExecuteComputeFunction( - &run_options->run_options(), unowning_buffers, hlo_execution_profile)); - - return CreateResultShapedBuffer(run_options, &owning_buffers); + auto result, + ExecuteAsyncOnStreamImpl(run_options, arguments, hlo_execution_profile)); + TF_RETURN_IF_ERROR(run_options->stream()->BlockHostUntilDone()); + return std::move(result); } StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream( @@ -277,6 +264,16 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream( "Asynchronous execution on stream with hlo profiling is not yet " "supported on CPU."); } + return ExecuteAsyncOnStreamImpl(run_options, arguments, nullptr); +} + +StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStreamImpl( + const ServiceExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments, + HloExecutionProfile* hlo_execution_profile) { + if (GetRootPointsToSet().IsAmbiguous()) { + return Unimplemented("Points-to set of root instruction is ambiguous"); + } auto* host_stream = dynamic_cast<se::host::HostStream*>( run_options->stream()->implementation()); @@ -310,19 +307,20 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream( ServiceExecutableRunOptions run_options; std::vector<se::DeviceMemoryBase> unowning_buffers; std::shared_ptr<std::vector<OwningDeviceMemory>> buffers; + HloExecutionProfile* hlo_execution_profile; void operator()() { // Failing a CHECK here is not great, but I don't see an obvious way to // return a failed Status asynchronously. TF_CHECK_OK(executable->ExecuteComputeFunction( - &run_options.run_options(), unowning_buffers, - /*hlo_execution_profile=*/nullptr)); + &run_options.run_options(), unowning_buffers, hlo_execution_profile)); } }; host_stream->EnqueueTask( AsyncRunTask{this, *run_options, std::move(unowning_buffers), std::make_shared<std::vector<OwningDeviceMemory>>( - std::move(owning_buffers))}); + std::move(owning_buffers)), + hlo_execution_profile}); return std::move(result); } |