diff options
author | 2016-09-20 14:46:57 -0800 | |
---|---|---|
committer | 2016-09-20 16:02:10 -0700 | |
commit | aa038879b2376239bd47a0375089007efe795fd6 (patch) | |
tree | 609a6ab52a99eeeb88ddeec1915aec6099220c8b /tensorflow/core/kernels/scan_ops.cc | |
parent | 63263dff15cf4e18469bc1ffcada7f3c58b53f21 (diff) |
Simplify scan_ops (cumsum, etc.) and support tensors with any number of dimensions.
Change: 133764629
Diffstat (limited to 'tensorflow/core/kernels/scan_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/scan_ops.cc | 69 |
1 files changed, 23 insertions, 46 deletions
diff --git a/tensorflow/core/kernels/scan_ops.cc b/tensorflow/core/kernels/scan_ops.cc index a51da2ddf5..c1ab9ea6b3 100644 --- a/tensorflow/core/kernels/scan_ops.cc +++ b/tensorflow/core/kernels/scan_ops.cc @@ -68,31 +68,19 @@ class ScanOp : public OpKernel { const Device& d = ctx->eigen_device<Device>(); Reducer reducer; -#define HANDLE_SCAN(NDIMS) \ - case NDIMS: \ - functor::Scan<Device, Reducer, T, NDIMS>()( \ - d, input.tensor<T, NDIMS>(), output->tensor<T, NDIMS>(), reducer, \ - axis, reverse_, exclusive_); \ - return; - - switch (input.dims()) { - // input.dims() == 0 can't occur as there - // is no valid axis parameter in this case - HANDLE_SCAN(1); - HANDLE_SCAN(2); - HANDLE_SCAN(3); - HANDLE_SCAN(4); - HANDLE_SCAN(5); - HANDLE_SCAN(6); - HANDLE_SCAN(7); - HANDLE_SCAN(8); - default: - OP_REQUIRES(ctx, false, errors::InvalidArgument( - "Scan does not support tensors with " - "more than 8 dimensions", - input.dims())); + // Dim reduction. + int64 reduced_shape[3] = {1, 1, 1}; + for (int i = 0; i < axis; ++i) { + reduced_shape[0] *= input.dim_size(i); } -#undef HANDLE_SCAN + reduced_shape[1] = input.dim_size(axis); + for (int i = axis + 1; i < input.dims(); ++i) { + reduced_shape[2] *= input.dim_size(i); + } + + functor::Scan<Device, Reducer, T>()(d, input.shaped<T, 3>(reduced_shape), + output->shaped<T, 3>(reduced_shape), + reducer, reverse_, exclusive_); } private: @@ -104,32 +92,21 @@ class ScanOp : public OpKernel { namespace functor { // Forward declarations of GPU functors -#define DECLARE(REDUCER, T, D) \ - template <> \ - void Scan<GPUDevice, REDUCER, T, D>::operator()( \ - const GPUDevice& d, TTypes<T, D>::ConstTensor in, \ - TTypes<T, D>::Tensor out, const REDUCER& reducer, \ - const Eigen::Index& axis, const bool reverse, const bool exclusive); \ - extern template struct Scan<GPUDevice, REDUCER, T, D>; - -#define DECLARE_FOR_ALL_DIMS(REDUCER, T) \ - DECLARE(REDUCER, T, 1); \ - DECLARE(REDUCER, T, 2); \ - DECLARE(REDUCER, T, 3); \ - DECLARE(REDUCER, 
T, 4); \ - DECLARE(REDUCER, T, 5); \ - DECLARE(REDUCER, T, 6); \ - DECLARE(REDUCER, T, 7); \ - DECLARE(REDUCER, T, 8); - -#define DECLARE_FOR_ALL_REDUCERS(T) \ - DECLARE_FOR_ALL_DIMS(Eigen::internal::SumReducer<T>, T); \ - DECLARE_FOR_ALL_DIMS(Eigen::internal::ProdReducer<T>, T); +#define DECLARE(REDUCER, T) \ + template <> \ + void Scan<GPUDevice, REDUCER, T>::operator()( \ + const GPUDevice& d, TTypes<T, 3>::ConstTensor in, \ + TTypes<T, 3>::Tensor out, const REDUCER& reducer, const bool reverse, \ + const bool exclusive); \ + extern template struct Scan<GPUDevice, REDUCER, T>; + +#define DECLARE_FOR_ALL_REDUCERS(T) \ + DECLARE(Eigen::internal::SumReducer<T>, T); \ + DECLARE(Eigen::internal::ProdReducer<T>, T); TF_CALL_GPU_NUMBER_TYPES(DECLARE_FOR_ALL_REDUCERS); #undef DECLARE_FOR_ALL_REDUCERS -#undef DECLARE_FOR_ALL_DIMS #undef DECLARE } // namespace functor |