aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/variable_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/variable_ops.cc')
-rw-r--r--tensorflow/core/kernels/variable_ops.cc46
1 files changed, 21 insertions, 25 deletions
diff --git a/tensorflow/core/kernels/variable_ops.cc b/tensorflow/core/kernels/variable_ops.cc
index 7a4d9dc650..36b8ff09d7 100644
--- a/tensorflow/core/kernels/variable_ops.cc
+++ b/tensorflow/core/kernels/variable_ops.cc
@@ -32,33 +32,29 @@ REGISTER_KERNEL_BUILDER(Name("DestroyTemporaryVariable").Device(DEVICE_CPU),
REGISTER_KERNEL_BUILDER(Name("IsVariableInitialized").Device(DEVICE_CPU),
IsVariableInitializedOp);
-#if TENSORFLOW_USE_SYCL
-#define REGISTER_SYCL_KERNEL(TYPE) \
- REGISTER_KERNEL_BUILDER( \
- Name("Variable") \
- .Device(DEVICE_SYCL) \
- .TypeConstraint<TYPE>("dtype"), \
- VariableOp); \
- REGISTER_KERNEL_BUILDER(Name("VariableV2") \
- .Device(DEVICE_SYCL) \
- .TypeConstraint<TYPE>("dtype"), \
- VariableOp); \
- REGISTER_KERNEL_BUILDER(Name("TemporaryVariable") \
- .Device(DEVICE_SYCL) \
- .TypeConstraint<TYPE>("dtype"), \
- TemporaryVariableOp); \
- REGISTER_KERNEL_BUILDER(Name("DestroyTemporaryVariable") \
- .Device(DEVICE_SYCL) \
- .TypeConstraint<TYPE>("T"), \
- DestroyTemporaryVariableOp); \
- REGISTER_KERNEL_BUILDER(Name("IsVariableInitialized") \
- .Device(DEVICE_SYCL) \
- .TypeConstraint<TYPE>("dtype") \
- .HostMemory("is_initialized"), \
+#ifdef TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Variable").Device(DEVICE_SYCL).TypeConstraint<type>("dtype"), \
+ VariableOp); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("VariableV2").Device(DEVICE_SYCL).TypeConstraint<type>("dtype"),\
+ VariableOp); \
+ REGISTER_KERNEL_BUILDER(Name("TemporaryVariable") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<type>("dtype"), \
+ TemporaryVariableOp); \
+ REGISTER_KERNEL_BUILDER(Name("DestroyTemporaryVariable") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<type>("T"), \
+ DestroyTemporaryVariableOp); \
+ REGISTER_KERNEL_BUILDER(Name("IsVariableInitialized") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<type>("dtype") \
+ .HostMemory("is_initialized"), \
IsVariableInitializedOp);
-REGISTER_SYCL_KERNEL(float);
-REGISTER_SYCL_KERNEL(double);
+TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNEL);
#undef REGISTER_SYCL_KERNEL
#endif // TENSORFLOW_USE_SYCL