diff options
Diffstat (limited to 'tensorflow/core/common_runtime/sycl/sycl_allocator.h')
-rw-r--r-- | tensorflow/core/common_runtime/sycl/sycl_allocator.h | 10 |
1 files changed, 6 insertions, 4 deletions
diff --git a/tensorflow/core/common_runtime/sycl/sycl_allocator.h b/tensorflow/core/common_runtime/sycl/sycl_allocator.h index 15d9ab41a4..8668cba06a 100644 --- a/tensorflow/core/common_runtime/sycl/sycl_allocator.h +++ b/tensorflow/core/common_runtime/sycl/sycl_allocator.h @@ -28,17 +28,19 @@ namespace tensorflow { class SYCLAllocator : public Allocator { public: - SYCLAllocator(Eigen::QueueInterface *device) : device_(device) {} + SYCLAllocator(Eigen::QueueInterface *queue) : sycl_device_(new Eigen::SyclDevice(queue)) {} virtual ~SYCLAllocator() override; string Name() override; void *AllocateRaw(size_t alignment, size_t num_bytes) override; void DeallocateRaw(void *ptr) override; - void EnterLameDuckMode(); virtual bool ShouldAllocateEmptyTensors() override final { return true; } - + void Synchronize() { sycl_device_->synchronize(); } + bool Ok() { return sycl_device_->ok(); } + Eigen::SyclDevice* getSyclDevice() { return sycl_device_; } private: - Eigen::QueueInterface *device_; // not owned + Eigen::SyclDevice *sycl_device_; // owned + TF_DISALLOW_COPY_AND_ASSIGN(SYCLAllocator); }; |