aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/session_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/session_ops.cc')
-rw-r--r--tensorflow/core/kernels/session_ops.cc31
1 files changed, 31 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/session_ops.cc b/tensorflow/core/kernels/session_ops.cc
index 3f1538164c..4550115c19 100644
--- a/tensorflow/core/kernels/session_ops.cc
+++ b/tensorflow/core/kernels/session_ops.cc
@@ -67,6 +67,19 @@ TF_CALL_NUMBER_TYPES(REGISTER_GPU_KERNEL);
REGISTER_GPU_KERNEL(bool);
#undef REGISTER_GPU_KERNEL
+#ifdef TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("GetSessionHandle") \
+ .Device(DEVICE_SYCL) \
+ .HostMemory("handle") \
+ .TypeConstraint<type>("T"), \
+ GetSessionHandleOp)
+
+TF_CALL_NUMBER_TYPES(REGISTER_SYCL_KERNEL);
+REGISTER_SYCL_KERNEL(bool);
+#undef REGISTER_SYCL_KERNEL
+#endif // TENSORFLOW_USE_SYCL
+
class GetSessionTensorOp : public OpKernel {
public:
explicit GetSessionTensorOp(OpKernelConstruction* context)
@@ -97,6 +110,19 @@ TF_CALL_NUMBER_TYPES(REGISTER_GPU_KERNEL);
REGISTER_GPU_KERNEL(bool);
#undef REGISTER_GPU_KERNEL
+#ifdef TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("GetSessionTensor") \
+ .Device(DEVICE_SYCL) \
+ .HostMemory("handle") \
+ .TypeConstraint<type>("dtype"), \
+ GetSessionTensorOp)
+
+TF_CALL_NUMBER_TYPES(REGISTER_SYCL_KERNEL);
+REGISTER_SYCL_KERNEL(bool);
+#undef REGISTER_SYCL_KERNEL
+#endif // TENSORFLOW_USE_SYCL
+
class DeleteSessionTensorOp : public OpKernel {
public:
explicit DeleteSessionTensorOp(OpKernelConstruction* context)
@@ -117,4 +143,9 @@ REGISTER_KERNEL_BUILDER(
Name("DeleteSessionTensor").Device(DEVICE_GPU).HostMemory("handle"),
DeleteSessionTensorOp);
+#ifdef TENSORFLOW_USE_SYCL
+REGISTER_KERNEL_BUILDER(
+ Name("DeleteSessionTensor").Device(DEVICE_SYCL).HostMemory("handle"),
+ DeleteSessionTensorOp);
+#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow