diff options
Diffstat (limited to 'tensorflow/core/common_runtime/gpu_device_context.h')
-rw-r--r-- | tensorflow/core/common_runtime/gpu_device_context.h | 45 |
1 files changed, 45 insertions, 0 deletions
diff --git a/tensorflow/core/common_runtime/gpu_device_context.h b/tensorflow/core/common_runtime/gpu_device_context.h new file mode 100644 index 0000000000..03fd9a97c3 --- /dev/null +++ b/tensorflow/core/common_runtime/gpu_device_context.h @@ -0,0 +1,45 @@ +#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_DEVICE_CONTEXT_H_ +#define TENSORFLOW_COMMON_RUNTIME_GPU_DEVICE_CONTEXT_H_ + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/framework/device_base.h" + +namespace perftools { +namespace gputools { +class Stream; +} // namespace gputools +} // namespace perftools + +namespace tensorflow { + +namespace gpu = ::perftools::gputools; + +class GPUDeviceContext : public DeviceContext { + public: + GPUDeviceContext(int stream_id, gpu::Stream* stream) + : stream_id_(stream_id), stream_(stream) {} + + ~GPUDeviceContext() override {} + + gpu::Stream* stream() const override { return stream_; } + int stream_id() const { return stream_id_; } + + void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, + Tensor* device_tensor, + StatusCallback done) const override; + + void CopyDeviceTensorToCPU(const Tensor* device_tensor, + const string& edge_name, Device* device, + Tensor* cpu_tensor, StatusCallback done) override; + + void MaintainLifetimeOnStream( + const Tensor* t, perftools::gputools::Stream* stream) const override {} + + private: + int stream_id_; + gpu::Stream* stream_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMMON_RUNTIME_GPU_DEVICE_CONTEXT_H_ |