aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-23 15:21:42 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-23 15:26:06 -0700
commite9ef2338007bac406f4c8d00ae86fa653092d1fb (patch)
tree728114239cd02e1a59941416bc1b658b951b3ee0 /tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc
parentfb7b8458751cc0ef38e4c98430517ab134f85aca (diff)
Add ThenExecute() method to DeviceContext interface.
This basically exposes the GPU EventMgr::ThenExecute function to code that can be compiled without explicitly depending on StreamExecutor and similar. It allows some corner cases in GPU memory use to be handled in code that is not just GPU-specific. PiperOrigin-RevId: 210010480
Diffstat (limited to 'tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc')
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc9
1 files changed, 9 insertions, 0 deletions
diff --git a/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc b/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc
index ea1b04feeb..4bc88ffc8c 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
#include "tensorflow/core/common_runtime/gpu_device_context.h"
#include "tensorflow/core/framework/tensor.h"
@@ -36,4 +37,12 @@ void GPUDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor,
GPUUtil::CopyGPUTensorToCPU(device, this, device_tensor, cpu_tensor, done);
}
+Status GPUDeviceContext::ThenExecute(Device* device, se::Stream* stream,
+ std::function<void()> func) {
+ const DeviceBase::GpuDeviceInfo* gpu_info =
+ device->tensorflow_gpu_device_info();
+ gpu_info->event_mgr->ThenExecute(stream, func);
+ return Status::OK();
+}
+
} // namespace tensorflow