diff options
Diffstat (limited to 'tensorflow/core/kernels/slice_op.h')
-rw-r--r-- | tensorflow/core/kernels/slice_op.h | 109 |
1 files changed, 18 insertions, 91 deletions
diff --git a/tensorflow/core/kernels/slice_op.h b/tensorflow/core/kernels/slice_op.h index 55a4be985b..db7eded745 100644 --- a/tensorflow/core/kernels/slice_op.h +++ b/tensorflow/core/kernels/slice_op.h @@ -19,104 +19,31 @@ limitations under the License. // Functor definition for SliceOp, must be compilable by nvcc. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/kernels/ops_util.h" namespace tensorflow { - -namespace internal { - -template <typename Device, typename T> -void SliceSimple(const Device& d, Tensor* out, const Tensor& in, - const gtl::ArraySlice<int64>& slice_indices); -template <typename Device, typename T> -void SliceSimpleGpu(const Device& d, Tensor* out, const Tensor& in, - const gtl::ArraySlice<int64>& slice_indices); - -template <typename Device, typename T> -void SliceSimple(const Device& d, Tensor* out, const Tensor& in, - const gtl::ArraySlice<int64>& slice_indices) { - const int ndims = in.dims(); - const int64 nelem = out->NumElements(); - const gtl::InlinedVector<int64, 8> in_strides = ComputeStride<int64>(in.shape()); - const gtl::InlinedVector<int64, 8> out_strides = ComputeStride<int64>(out->shape()); - const T* p = in.flat<T>().data(); - T* q = out->flat<T>().data(); - - std::vector<int64> i_idx(nelem, 0); - std::vector<int64> t(nelem, 0); - - for (int64 o_idx = 0; o_idx < nelem; ++o_idx) { - t[o_idx] = o_idx; - } - for (int i = 0; i < ndims; ++i) { - int64 n = (nelem + 7) / 8; - int64 o_idx = 0; - switch (nelem % 8) { -#define CALC_INPUT_IDX \ - i_idx[o_idx] += (t[o_idx] / out_strides[i] + slice_indices[i]) * in_strides[i]; \ - t[o_idx] %= out_strides[i]; \ - ++o_idx; - case 0: do { CALC_INPUT_IDX; - case 7: CALC_INPUT_IDX; - case 6: CALC_INPUT_IDX; - case 5: CALC_INPUT_IDX; - case 4: CALC_INPUT_IDX; - case 3: CALC_INPUT_IDX; - case 2: CALC_INPUT_IDX; - case 1: CALC_INPUT_IDX; -#undef CALC_INPUT_IDX - } while (--n > 0); - } - } - for (int64 o_idx = 0; o_idx < nelem; ++o_idx) { - q[o_idx] = p[i_idx[o_idx]]; - } -} - -template <typename Device, typename T, int NDIMS> -void SliceUsingEigen(const Device& d, Tensor* out, const Tensor& in, - const gtl::ArraySlice<int64>& slice_indices, - const gtl::ArraySlice<int64>& slice_sizes) { - auto input = in.tensor<T, NDIMS>(); - auto output = out->tensor<T, NDIMS>(); - Eigen::DSizes<int, NDIMS> indices; - for (int i = 0; i < NDIMS; ++i) { - indices[i] = slice_indices[i]; - } - Eigen::DSizes<int, NDIMS> sizes; - for (int i = 0; i < NDIMS; ++i) { - sizes[i] = slice_sizes[i]; - } - const bool use_64bit = input.size() > Eigen::NumTraits<int>::highest(); - if (!use_64bit && - Eigen::internal::is_same<Device, Eigen::GpuDevice>::value) { - To32Bit(output).device(d) = To32Bit(input).slice(indices, sizes); - } else { - output.device(d) = input.slice(indices, sizes); - } -} - -} // namespace internal - namespace functor { -// Template parameter NDIM is not neccesary here. The aim of keeping it -// is to compile struct slice seperately which minimizes the compiling time. -template <typename Device, typename T, int NDIM> +template <typename Device, typename T, int NDIMS> struct Slice { - void operator()(const Device& d, Tensor* out, const Tensor& in, - const gtl::ArraySlice<int64>& slice_indices, - const gtl::ArraySlice<int64>& slice_sizes) { - if (in.dims() == NDIM) { - internal::SliceUsingEigen<Device, T, NDIM>(d, out, in, slice_indices, slice_sizes); + void operator()(const Device& d, typename TTypes<T, NDIMS>::Tensor output, + typename TTypes<T, NDIMS>::ConstTensor input, + const Eigen::DSizes<Eigen::DenseIndex, NDIMS>& slice_indices, + const Eigen::DSizes<Eigen::DenseIndex, NDIMS>& slice_sizes) { + bool use_64bit = (input.size() > Eigen::NumTraits<int>::highest()); + if (!use_64bit && + Eigen::internal::is_same<Device, Eigen::GpuDevice>::value) { + Eigen::DSizes<int, NDIMS> indices; + for (int i = 0; i < NDIMS; ++i) { + indices[i] = slice_indices[i]; + } + Eigen::DSizes<int, NDIMS> sizes; + for (int i = 0; i < NDIMS; ++i) { + sizes[i] = slice_sizes[i]; + } + To32Bit(output).device(d) = To32Bit(input).slice(indices, sizes); } else { - if (Eigen::internal::is_same<Device, Eigen::GpuDevice>::value) { - internal::SliceSimpleGpu<Device, T>(d, out, in, slice_indices); - } else { - internal::SliceSimple<Device, T>(d, out, in, slice_indices); - } + output.device(d) = input.slice(slice_indices, slice_sizes); } } }; |