diff options
Diffstat (limited to 'tensorflow/core/kernels/sequence_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/sequence_ops.cc | 48 |
1 files changed, 30 insertions, 18 deletions
diff --git a/tensorflow/core/kernels/sequence_ops.cc b/tensorflow/core/kernels/sequence_ops.cc index c8ea923020..e2e3758d87 100644 --- a/tensorflow/core/kernels/sequence_ops.cc +++ b/tensorflow/core/kernels/sequence_ops.cc @@ -96,7 +96,7 @@ TF_CALL_double(REGISTER_SYCL_KERNEL); TF_CALL_int32(REGISTER_SYCL_KERNEL); TF_CALL_int64(REGISTER_SYCL_KERNEL); #undef REGISTER_SYCL_KERNEL -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL TF_CALL_float(REGISTER_CPU_KERNEL); TF_CALL_double(REGISTER_CPU_KERNEL); @@ -116,7 +116,7 @@ TF_CALL_int64(REGISTER_GPU_KERNEL); #undef REGISTER_CPU_KERNEL #undef REGISTER_GPU_KERNEL -template <typename T> +template <typename T, typename Tnum> class LinSpaceOp : public OpKernel { public: explicit LinSpaceOp(OpKernelConstruction* context) : OpKernel(context) {} @@ -136,7 +136,7 @@ class LinSpaceOp : public OpKernel { num_in.shape().DebugString())); const T start = start_in.scalar<T>()(); const T stop = stop_in.scalar<T>()(); - const int32 num = num_in.scalar<int32>()(); + const Tnum num = num_in.scalar<Tnum>()(); OP_REQUIRES(context, num > 0, errors::InvalidArgument("Requires num > 0: ", num)); Tensor* out = nullptr; @@ -147,34 +147,46 @@ class LinSpaceOp : public OpKernel { flat(0) = start; } else { const T step = (stop - start) / (num - 1); - for (int32 i = 0; i < num; ++i) flat(i) = start + step * i; + for (Tnum i = 0; i < num; ++i) flat(i) = start + step * i; } } }; -#define REGISTER_KERNEL(DEV, T) \ - REGISTER_KERNEL_BUILDER(Name("LinSpace") \ - .Device(DEV) \ - .TypeConstraint<T>("T") \ - .TypeConstraint<int32>("Tidx") \ - .HostMemory("start") \ - .HostMemory("stop") \ - .HostMemory("num") \ - .HostMemory("output"), \ - LinSpaceOp<T>); -#define REGISTER_CPU_KERNEL(T) REGISTER_KERNEL(DEVICE_CPU, T) +#define REGISTER_KERNEL(DEV, T, Tidx) \ + REGISTER_KERNEL_BUILDER(Name("LinSpace") \ + .Device(DEV) \ + .TypeConstraint<T>("T") \ + .TypeConstraint<Tidx>("Tidx") \ + .HostMemory("start") \ + .HostMemory("stop") \ + .HostMemory("num") \ + .HostMemory("output"), \ + LinSpaceOp<T, Tidx>); + +#define REGISTER_KERNEL_ALL_NUMS(dev, T) \ + REGISTER_KERNEL(dev, T, int32); \ + REGISTER_KERNEL(dev, T, int64) + +#define REGISTER_CPU_KERNEL(T) REGISTER_KERNEL_ALL_NUMS(DEVICE_CPU, T) TF_CALL_float(REGISTER_CPU_KERNEL); TF_CALL_double(REGISTER_CPU_KERNEL); // NOTE(touts): We register the op on GPU but it still runs on CPU // because its inputs and outputs are tagged as HostMemory. -#define REGISTER_GPU_KERNEL(T) REGISTER_KERNEL(DEVICE_GPU, T) +#define REGISTER_GPU_KERNEL(T) REGISTER_KERNEL_ALL_NUMS(DEVICE_GPU, T) TF_CALL_float(REGISTER_GPU_KERNEL); TF_CALL_double(REGISTER_GPU_KERNEL); +#undef REGISTER_GPU_KERNEL #ifdef TENSORFLOW_USE_SYCL -#define REGISTER_SYCL_KERNEL(T) REGISTER_KERNEL(DEVICE_SYCL, T) +#define REGISTER_SYCL_KERNEL(T) REGISTER_KERNEL_ALL_NUMS(DEVICE_SYCL, T) TF_CALL_float(REGISTER_SYCL_KERNEL); TF_CALL_double(REGISTER_SYCL_KERNEL); -#endif // TENSORFLOW_USE_SYCL +#undef REGISTER_SYCL_KERNEL +#endif // TENSORFLOW_USE_SYCL + +#undef REGISTER_CPU_KERNEL +#undef REGISTER_KERNEL_ALL_NUMS +#undef REGISTER_KERNEL + } // namespace tensorflow |