Diffstat (limited to 'tensorflow/core/kernels/sequence_ops.cc')
-rw-r--r--  tensorflow/core/kernels/sequence_ops.cc  48
1 file changed, 30 insertions, 18 deletions
diff --git a/tensorflow/core/kernels/sequence_ops.cc b/tensorflow/core/kernels/sequence_ops.cc
index c8ea923020..e2e3758d87 100644
--- a/tensorflow/core/kernels/sequence_ops.cc
+++ b/tensorflow/core/kernels/sequence_ops.cc
@@ -96,7 +96,7 @@ TF_CALL_double(REGISTER_SYCL_KERNEL);
TF_CALL_int32(REGISTER_SYCL_KERNEL);
TF_CALL_int64(REGISTER_SYCL_KERNEL);
#undef REGISTER_SYCL_KERNEL
-#endif // TENSORFLOW_USE_SYCL
+#endif // TENSORFLOW_USE_SYCL
TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);
@@ -116,7 +116,7 @@ TF_CALL_int64(REGISTER_GPU_KERNEL);
#undef REGISTER_CPU_KERNEL
#undef REGISTER_GPU_KERNEL
-template <typename T>
+template <typename T, typename Tnum>
class LinSpaceOp : public OpKernel {
public:
explicit LinSpaceOp(OpKernelConstruction* context) : OpKernel(context) {}
@@ -136,7 +136,7 @@ class LinSpaceOp : public OpKernel {
num_in.shape().DebugString()));
const T start = start_in.scalar<T>()();
const T stop = stop_in.scalar<T>()();
- const int32 num = num_in.scalar<int32>()();
+ const Tnum num = num_in.scalar<Tnum>()();
OP_REQUIRES(context, num > 0,
errors::InvalidArgument("Requires num > 0: ", num));
Tensor* out = nullptr;
@@ -147,34 +147,46 @@ class LinSpaceOp : public OpKernel {
flat(0) = start;
} else {
const T step = (stop - start) / (num - 1);
- for (int32 i = 0; i < num; ++i) flat(i) = start + step * i;
+ for (Tnum i = 0; i < num; ++i) flat(i) = start + step * i;
}
}
};
-#define REGISTER_KERNEL(DEV, T) \
- REGISTER_KERNEL_BUILDER(Name("LinSpace") \
- .Device(DEV) \
- .TypeConstraint<T>("T") \
- .TypeConstraint<int32>("Tidx") \
- .HostMemory("start") \
- .HostMemory("stop") \
- .HostMemory("num") \
- .HostMemory("output"), \
- LinSpaceOp<T>);
-#define REGISTER_CPU_KERNEL(T) REGISTER_KERNEL(DEVICE_CPU, T)
+#define REGISTER_KERNEL(DEV, T, Tidx) \
+ REGISTER_KERNEL_BUILDER(Name("LinSpace") \
+ .Device(DEV) \
+ .TypeConstraint<T>("T") \
+ .TypeConstraint<Tidx>("Tidx") \
+ .HostMemory("start") \
+ .HostMemory("stop") \
+ .HostMemory("num") \
+ .HostMemory("output"), \
+ LinSpaceOp<T, Tidx>);
+
+#define REGISTER_KERNEL_ALL_NUMS(dev, T) \
+ REGISTER_KERNEL(dev, T, int32); \
+ REGISTER_KERNEL(dev, T, int64)
+
+#define REGISTER_CPU_KERNEL(T) REGISTER_KERNEL_ALL_NUMS(DEVICE_CPU, T)
TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);
// NOTE(touts): We register the op on GPU but it still runs on CPU
// because its inputs and outputs are tagged as HostMemory.
-#define REGISTER_GPU_KERNEL(T) REGISTER_KERNEL(DEVICE_GPU, T)
+#define REGISTER_GPU_KERNEL(T) REGISTER_KERNEL_ALL_NUMS(DEVICE_GPU, T)
TF_CALL_float(REGISTER_GPU_KERNEL);
TF_CALL_double(REGISTER_GPU_KERNEL);
+#undef REGISTER_GPU_KERNEL
#ifdef TENSORFLOW_USE_SYCL
-#define REGISTER_SYCL_KERNEL(T) REGISTER_KERNEL(DEVICE_SYCL, T)
+#define REGISTER_SYCL_KERNEL(T) REGISTER_KERNEL_ALL_NUMS(DEVICE_SYCL, T)
TF_CALL_float(REGISTER_SYCL_KERNEL);
TF_CALL_double(REGISTER_SYCL_KERNEL);
-#endif // TENSORFLOW_USE_SYCL
+#undef REGISTER_SYCL_KERNEL
+#endif // TENSORFLOW_USE_SYCL
+
+#undef REGISTER_CPU_KERNEL
+#undef REGISTER_KERNEL_ALL_NUMS
+#undef REGISTER_KERNEL
+
} // namespace tensorflow
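
For context on what the templating buys: LinSpaceOp is now parameterized on the integer type of the `num` input ("Tidx") as well as the element type "T", and the new REGISTER_KERNEL_ALL_NUMS macro registers both int32 and int64 instantiations per device. Below is a minimal standalone sketch of the templated computation, outside the TensorFlow kernel and registration machinery; the free function name `LinSpace` and the use of std::vector are illustrative only, not part of the commit.

#include <cstdint>
#include <iostream>
#include <vector>

// Sketch of the loop in LinSpaceOp<T, Tnum>::Compute. T plays the role of the
// "T" attribute, Tnum the role of "Tidx" (the type of the `num` input).
// num > 0 is assumed here; the real kernel rejects num <= 0 with InvalidArgument.
template <typename T, typename Tnum>
std::vector<T> LinSpace(T start, T stop, Tnum num) {
  std::vector<T> out(static_cast<std::size_t>(num));
  if (num == 1) {
    out[0] = start;
  } else {
    const T step = (stop - start) / (num - 1);
    for (Tnum i = 0; i < num; ++i) out[i] = start + step * i;
  }
  return out;
}

int main() {
  // Either index type works, mirroring the int32/int64 "Tidx" registrations
  // added by this change.
  for (float v : LinSpace<float, std::int32_t>(0.0f, 1.0f, 5)) std::cout << v << ' ';
  std::cout << '\n';
  for (double v : LinSpace<double, std::int64_t>(0.0, 10.0, 3)) std::cout << v << ' ';
  std::cout << '\n';
  return 0;
}

With the added registrations, a `num` tensor of either int32 or int64 selects a matching kernel instantiation; the GPU registration still executes on the host because all inputs and the output are tagged HostMemory, as the existing NOTE in the file points out.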