aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/kernels/scan_ops.cc98
-rw-r--r--tensorflow/python/kernel_tests/scan_ops_test.py18
2 files changed, 80 insertions, 36 deletions
diff --git a/tensorflow/core/kernels/scan_ops.cc b/tensorflow/core/kernels/scan_ops.cc
index cc434ab0ae..0a6848361a 100644
--- a/tensorflow/core/kernels/scan_ops.cc
+++ b/tensorflow/core/kernels/scan_ops.cc
@@ -35,7 +35,7 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-template <typename Device, class T, typename Reducer>
+template <typename Device, class T, typename Reducer, typename Tidx>
class ScanOp : public OpKernel {
public:
explicit ScanOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
@@ -51,8 +51,9 @@ class ScanOp : public OpKernel {
errors::InvalidArgument("ScanOp: axis must be a scalar, not ",
tensor_axis.shape().DebugString()));
- const int axis_arg = internal::SubtleMustCopy(tensor_axis.scalar<int>()());
- const int axis = (axis_arg < 0) ? input.dims() + axis_arg : axis_arg;
+ const Tidx axis_arg =
+ internal::SubtleMustCopy(tensor_axis.scalar<Tidx>()());
+ const Tidx axis = (axis_arg < 0) ? input.dims() + axis_arg : axis_arg;
OP_REQUIRES(ctx, FastBoundsCheck(axis, input.dims()),
errors::InvalidArgument(
"ScanOp: Expected scan axis in the range [", -input.dims(),
@@ -70,11 +71,11 @@ class ScanOp : public OpKernel {
// Dim reduction.
int64 reduced_shape[3] = {1, 1, 1};
- for (int i = 0; i < axis; ++i) {
+ for (Tidx i = 0; i < axis; ++i) {
reduced_shape[0] *= input.dim_size(i);
}
reduced_shape[1] = input.dim_size(axis);
- for (int i = axis + 1; i < input.dims(); ++i) {
+ for (Tidx i = axis + 1; i < input.dims(); ++i) {
reduced_shape[2] *= input.dim_size(i);
}
@@ -112,51 +113,76 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_FOR_ALL_REDUCERS);
} // namespace functor
#endif // GOOGLE_CUDA
-
// Register Cumsum kernels
-#define REGISTER_CPU_KERNELS(type) \
- REGISTER_KERNEL_BUILDER( \
- Name("Cumsum") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<type>("T") \
- .TypeConstraint<int32>("Tidx"), \
- ScanOp<CPUDevice, type, Eigen::internal::SumReducer<type>>)
+#define REGISTER_CPU_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Cumsum") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("Tidx"), \
+ ScanOp<CPUDevice, type, Eigen::internal::SumReducer<type>, int32>) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Cumsum") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("Tidx"), \
+ ScanOp<CPUDevice, type, Eigen::internal::SumReducer<type>, int64>)
TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS
#if GOOGLE_CUDA
-#define REGISTER_GPU_KERNELS(type) \
- REGISTER_KERNEL_BUILDER( \
- Name("Cumsum") \
- .Device(DEVICE_GPU) \
- .TypeConstraint<type>("T") \
- .TypeConstraint<int32>("Tidx") \
- .HostMemory("axis"), \
- ScanOp<GPUDevice, type, Eigen::internal::SumReducer<type>>)
+#define REGISTER_GPU_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Cumsum") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("Tidx") \
+ .HostMemory("axis"), \
+ ScanOp<GPUDevice, type, Eigen::internal::SumReducer<type>, int32>) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Cumsum") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("Tidx") \
+ .HostMemory("axis"), \
+ ScanOp<GPUDevice, type, Eigen::internal::SumReducer<type>, int64>)
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS)
#undef REGISTER_GPU_KERNELS
#endif // GOOGLE_CUDA
// Register Cumprod kernels
-#define REGISTER_CPU_KERNELS(type) \
- REGISTER_KERNEL_BUILDER( \
- Name("Cumprod") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<type>("T") \
- .TypeConstraint<int32>("Tidx"), \
- ScanOp<CPUDevice, type, Eigen::internal::ProdReducer<type>>)
+#define REGISTER_CPU_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Cumprod") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("Tidx"), \
+ ScanOp<CPUDevice, type, Eigen::internal::ProdReducer<type>, int32>) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Cumprod") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("Tidx"), \
+ ScanOp<CPUDevice, type, Eigen::internal::ProdReducer<type>, int64>)
TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS
#if GOOGLE_CUDA
-#define REGISTER_GPU_KERNELS(type) \
- REGISTER_KERNEL_BUILDER( \
- Name("Cumprod") \
- .Device(DEVICE_GPU) \
- .TypeConstraint<type>("T") \
- .TypeConstraint<int32>("Tidx") \
- .HostMemory("axis"), \
- ScanOp<GPUDevice, type, Eigen::internal::ProdReducer<type>>)
+#define REGISTER_GPU_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Cumprod") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("Tidx") \
+ .HostMemory("axis"), \
+ ScanOp<GPUDevice, type, Eigen::internal::ProdReducer<type>, int32>) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Cumprod") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("Tidx") \
+ .HostMemory("axis"), \
+ ScanOp<GPUDevice, type, Eigen::internal::ProdReducer<type>, int64>)
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS)
#undef REGISTER_GPU_KERNELS
#endif // GOOGLE_CUDA
diff --git a/tensorflow/python/kernel_tests/scan_ops_test.py b/tensorflow/python/kernel_tests/scan_ops_test.py
index 6b2b589a06..08b4a2aaae 100644
--- a/tensorflow/python/kernel_tests/scan_ops_test.py
+++ b/tensorflow/python/kernel_tests/scan_ops_test.py
@@ -20,6 +20,8 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.ops import gradient_checker
@@ -92,6 +94,14 @@ class CumsumTest(test.TestCase):
for axis in (-1, 0):
self._compareAll(x, axis)
+ def testAxisType(self):
+ for dtype in self.valid_dtypes:
+ x = np.arange(1, 6).reshape([5]).astype(dtype)
+ for axis_dtype in [dtypes.int64, dtypes.int32]:
+ with self.test_session(use_gpu=True):
+ axis = constant_op.constant(0, axis_dtype)
+ tf_out = math_ops.cumsum(x, axis).eval()
+
def test1D(self):
for dtype in self.valid_dtypes:
x = np.arange(1, 6).reshape([5]).astype(dtype)
@@ -190,6 +200,14 @@ class CumprodTest(test.TestCase):
for axis in (-1, 0):
self._compareAll(x, axis)
+ def testAxisType(self):
+ for dtype in self.valid_dtypes:
+ x = np.arange(1, 6).reshape([5]).astype(dtype)
+ for axis_dtype in [dtypes.int64, dtypes.int32]:
+ with self.test_session(use_gpu=True):
+ axis = constant_op.constant(0, axis_dtype)
+ tf_out = math_ops.cumprod(x, axis).eval()
+
def test1D(self):
for dtype in self.valid_dtypes:
x = np.arange(1, 6).reshape([5]).astype(dtype)