diff options
Diffstat (limited to 'tensorflow/core/kernels/shape_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/shape_ops.cc | 43 |
1 files changed, 38 insertions, 5 deletions
diff --git a/tensorflow/core/kernels/shape_ops.cc b/tensorflow/core/kernels/shape_ops.cc index 721f9b949b..28a39bae3f 100644 --- a/tensorflow/core/kernels/shape_ops.cc +++ b/tensorflow/core/kernels/shape_ops.cc @@ -341,7 +341,12 @@ REGISTER_KERNEL_BUILDER(Name("ExpandDims") .Device(DEVICE_CPU) .HostMemory("dim") .TypeConstraint<int32>("Tdim"), - ExpandDimsOp); + ExpandDimsOp<int32>); +REGISTER_KERNEL_BUILDER(Name("ExpandDims") + .Device(DEVICE_CPU) + .HostMemory("dim") + .TypeConstraint<int64>("Tdim"), + ExpandDimsOp<int64>); #if GOOGLE_CUDA #define REGISTER_GPU_KERNEL(type) \ @@ -350,7 +355,13 @@ REGISTER_KERNEL_BUILDER(Name("ExpandDims") .TypeConstraint<type>("T") \ .TypeConstraint<int32>("Tdim") \ .HostMemory("dim"), \ - ExpandDimsOp); + ExpandDimsOp<int32>); \ + REGISTER_KERNEL_BUILDER(Name("ExpandDims") \ + .Device(DEVICE_GPU) \ + .TypeConstraint<type>("T") \ + .TypeConstraint<int64>("Tdim") \ + .HostMemory("dim"), \ + ExpandDimsOp<int64>); TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL); TF_CALL_bool(REGISTER_GPU_KERNEL); #undef REGISTER_GPU_KERNEL @@ -362,7 +373,15 @@ REGISTER_KERNEL_BUILDER(Name("ExpandDims") .HostMemory("input") .HostMemory("dim") .HostMemory("output"), - ExpandDimsOp); + ExpandDimsOp<int32>); +REGISTER_KERNEL_BUILDER(Name("ExpandDims") + .Device(DEVICE_GPU) + .TypeConstraint<int32>("T") + .TypeConstraint<int64>("Tdim") + .HostMemory("input") + .HostMemory("dim") + .HostMemory("output"), + ExpandDimsOp<int64>); #endif // GOOGLE_CUDA #ifdef TENSORFLOW_USE_SYCL @@ -372,7 +391,13 @@ REGISTER_KERNEL_BUILDER(Name("ExpandDims") .TypeConstraint<type>("T") \ .TypeConstraint<int32>("Tdim") \ .HostMemory("dim"), \ - ExpandDimsOp); + ExpandDimsOp<int32>); \ + REGISTER_KERNEL_BUILDER(Name("ExpandDims") \ + .Device(DEVICE_SYCL) \ + .TypeConstraint<type>("T") \ + .TypeConstraint<int64>("Tdim") \ + .HostMemory("dim"), \ + ExpandDimsOp<int64>); TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL); TF_CALL_bool(REGISTER_SYCL_KERNEL); #undef REGISTER_SYCL_KERNEL @@ -384,7 +409,15 @@ REGISTER_KERNEL_BUILDER(Name("ExpandDims") .HostMemory("input") .HostMemory("dim") .HostMemory("output"), - ExpandDimsOp); + ExpandDimsOp<int32>); +REGISTER_KERNEL_BUILDER(Name("ExpandDims") + .Device(DEVICE_SYCL) + .TypeConstraint<int32>("T") + .TypeConstraint<int64>("Tdim") + .HostMemory("input") + .HostMemory("dim") + .HostMemory("output"), + ExpandDimsOp<int64>); #endif // TENSORFLOW_USE_SYCL // Squeeze --------------------------------------- |