diff options
author    | 2018-08-23 15:21:42 -0700
committer | 2018-08-23 15:26:06 -0700
commit    | e9ef2338007bac406f4c8d00ae86fa653092d1fb (patch)
tree      | 728114239cd02e1a59941416bc1b658b951b3ee0 /tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc
parent    | fb7b8458751cc0ef38e4c98430517ab134f85aca (diff)
Add ThenExecute() method to DeviceContext interface.
This basically exposes the GPU EventMgr::ThenExecute function to code
that can be compiled without explicitly depending on StreamExecutor and
similar. It allows some corner cases in GPU memory use to be handled in
code that is not just GPU-specific.
PiperOrigin-RevId: 210010480
Diffstat (limited to 'tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc')
-rw-r--r-- | tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc | 9
1 file changed, 9 insertions(+), 0 deletions(-)
diff --git a/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc b/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc
index ea1b04feeb..4bc88ffc8c 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc
@@ -14,6 +14,7 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
 #include "tensorflow/core/common_runtime/gpu/gpu_util.h"
 #include "tensorflow/core/common_runtime/gpu_device_context.h"
 #include "tensorflow/core/framework/tensor.h"
@@ -36,4 +37,12 @@ void GPUDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor,
   GPUUtil::CopyGPUTensorToCPU(device, this, device_tensor, cpu_tensor, done);
 }
 
+Status GPUDeviceContext::ThenExecute(Device* device, se::Stream* stream,
+                                     std::function<void()> func) {
+  const DeviceBase::GpuDeviceInfo* gpu_info =
+      device->tensorflow_gpu_device_info();
+  gpu_info->event_mgr->ThenExecute(stream, func);
+  return Status::OK();
+}
+
 }  // namespace tensorflow