aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/gpu_device_context.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/common_runtime/gpu_device_context.h')
-rw-r--r--tensorflow/core/common_runtime/gpu_device_context.h45
1 files changed, 45 insertions, 0 deletions
diff --git a/tensorflow/core/common_runtime/gpu_device_context.h b/tensorflow/core/common_runtime/gpu_device_context.h
new file mode 100644
index 0000000000..03fd9a97c3
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu_device_context.h
@@ -0,0 +1,45 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_DEVICE_CONTEXT_H_
+#define TENSORFLOW_COMMON_RUNTIME_GPU_DEVICE_CONTEXT_H_
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/framework/device_base.h"
+
+namespace perftools {
+namespace gputools {
+class Stream;
+} // namespace gputools
+} // namespace perftools
+
+namespace tensorflow {
+
+namespace gpu = ::perftools::gputools;
+
+class GPUDeviceContext : public DeviceContext {
+ public:
+ GPUDeviceContext(int stream_id, gpu::Stream* stream)
+ : stream_id_(stream_id), stream_(stream) {}
+
+ ~GPUDeviceContext() override {}
+
+ gpu::Stream* stream() const override { return stream_; }
+ int stream_id() const { return stream_id_; }
+
+ void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
+ Tensor* device_tensor,
+ StatusCallback done) const override;
+
+ void CopyDeviceTensorToCPU(const Tensor* device_tensor,
+ const string& edge_name, Device* device,
+ Tensor* cpu_tensor, StatusCallback done) override;
+
+ void MaintainLifetimeOnStream(
+ const Tensor* t, perftools::gputools::Stream* stream) const override {}
+
+ private:
+ int stream_id_;
+ gpu::Stream* stream_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMMON_RUNTIME_GPU_DEVICE_CONTEXT_H_