about summary refs log tree commit diff homepage
path: root/tensorflow/core/kernels/scan_ops.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-09-20 14:46:57 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-09-20 16:02:10 -0700
commitaa038879b2376239bd47a0375089007efe795fd6 (patch)
tree609a6ab52a99eeeb88ddeec1915aec6099220c8b /tensorflow/core/kernels/scan_ops.cc
parent63263dff15cf4e18469bc1ffcada7f3c58b53f21 (diff)
Simplify scan_ops (cumsum, etc) and support any dimensions.
Change: 133764629
Diffstat (limited to 'tensorflow/core/kernels/scan_ops.cc')
-rw-r--r--tensorflow/core/kernels/scan_ops.cc69
1 file changed, 23 insertions, 46 deletions
diff --git a/tensorflow/core/kernels/scan_ops.cc b/tensorflow/core/kernels/scan_ops.cc
index a51da2ddf5..c1ab9ea6b3 100644
--- a/tensorflow/core/kernels/scan_ops.cc
+++ b/tensorflow/core/kernels/scan_ops.cc
@@ -68,31 +68,19 @@ class ScanOp : public OpKernel {
const Device& d = ctx->eigen_device<Device>();
Reducer reducer;
-#define HANDLE_SCAN(NDIMS) \
- case NDIMS: \
- functor::Scan<Device, Reducer, T, NDIMS>()( \
- d, input.tensor<T, NDIMS>(), output->tensor<T, NDIMS>(), reducer, \
- axis, reverse_, exclusive_); \
- return;
-
- switch (input.dims()) {
- // input.dims() == 0 can't occur as there
- // is no valid axis parameter in this case
- HANDLE_SCAN(1);
- HANDLE_SCAN(2);
- HANDLE_SCAN(3);
- HANDLE_SCAN(4);
- HANDLE_SCAN(5);
- HANDLE_SCAN(6);
- HANDLE_SCAN(7);
- HANDLE_SCAN(8);
- default:
- OP_REQUIRES(ctx, false, errors::InvalidArgument(
- "Scan does not support tensors with "
- "more than 8 dimensions",
- input.dims()));
+ // Dim reduction.
+ int64 reduced_shape[3] = {1, 1, 1};
+ for (int i = 0; i < axis; ++i) {
+ reduced_shape[0] *= input.dim_size(i);
}
-#undef HANDLE_SCAN
+ reduced_shape[1] = input.dim_size(axis);
+ for (int i = axis + 1; i < input.dims(); ++i) {
+ reduced_shape[2] *= input.dim_size(i);
+ }
+
+ functor::Scan<Device, Reducer, T>()(d, input.shaped<T, 3>(reduced_shape),
+ output->shaped<T, 3>(reduced_shape),
+ reducer, reverse_, exclusive_);
}
private:
@@ -104,32 +92,21 @@ class ScanOp : public OpKernel {
namespace functor {
// Forward declarations of GPU functors
-#define DECLARE(REDUCER, T, D) \
- template <> \
- void Scan<GPUDevice, REDUCER, T, D>::operator()( \
- const GPUDevice& d, TTypes<T, D>::ConstTensor in, \
- TTypes<T, D>::Tensor out, const REDUCER& reducer, \
- const Eigen::Index& axis, const bool reverse, const bool exclusive); \
- extern template struct Scan<GPUDevice, REDUCER, T, D>;
-
-#define DECLARE_FOR_ALL_DIMS(REDUCER, T) \
- DECLARE(REDUCER, T, 1); \
- DECLARE(REDUCER, T, 2); \
- DECLARE(REDUCER, T, 3); \
- DECLARE(REDUCER, T, 4); \
- DECLARE(REDUCER, T, 5); \
- DECLARE(REDUCER, T, 6); \
- DECLARE(REDUCER, T, 7); \
- DECLARE(REDUCER, T, 8);
-
-#define DECLARE_FOR_ALL_REDUCERS(T) \
- DECLARE_FOR_ALL_DIMS(Eigen::internal::SumReducer<T>, T); \
- DECLARE_FOR_ALL_DIMS(Eigen::internal::ProdReducer<T>, T);
+#define DECLARE(REDUCER, T) \
+ template <> \
+ void Scan<GPUDevice, REDUCER, T>::operator()( \
+ const GPUDevice& d, TTypes<T, 3>::ConstTensor in, \
+ TTypes<T, 3>::Tensor out, const REDUCER& reducer, const bool reverse, \
+ const bool exclusive); \
+ extern template struct Scan<GPUDevice, REDUCER, T>;
+
+#define DECLARE_FOR_ALL_REDUCERS(T) \
+ DECLARE(Eigen::internal::SumReducer<T>, T); \
+ DECLARE(Eigen::internal::ProdReducer<T>, T);
TF_CALL_GPU_NUMBER_TYPES(DECLARE_FOR_ALL_REDUCERS);
#undef DECLARE_FOR_ALL_REDUCERS
-#undef DECLARE_FOR_ALL_DIMS
#undef DECLARE
} // namespace functor