Diffstat (limited to 'tensorflow/core/kernels/slice_op.h')
-rw-r--r--  tensorflow/core/kernels/slice_op.h | 109
1 file changed, 18 insertions(+), 91 deletions(-)
diff --git a/tensorflow/core/kernels/slice_op.h b/tensorflow/core/kernels/slice_op.h
index 55a4be985b..db7eded745 100644
--- a/tensorflow/core/kernels/slice_op.h
+++ b/tensorflow/core/kernels/slice_op.h
@@ -19,104 +19,31 @@ limitations under the License.
// Functor definition for SliceOp, must be compilable by nvcc.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
-#include "tensorflow/core/kernels/ops_util.h"
namespace tensorflow {
-
-namespace internal {
-
-template <typename Device, typename T>
-void SliceSimple(const Device& d, Tensor* out, const Tensor& in,
- const gtl::ArraySlice<int64>& slice_indices);
-template <typename Device, typename T>
-void SliceSimpleGpu(const Device& d, Tensor* out, const Tensor& in,
- const gtl::ArraySlice<int64>& slice_indices);
-
-template <typename Device, typename T>
-void SliceSimple(const Device& d, Tensor* out, const Tensor& in,
- const gtl::ArraySlice<int64>& slice_indices) {
- const int ndims = in.dims();
- const int64 nelem = out->NumElements();
- const gtl::InlinedVector<int64, 8> in_strides = ComputeStride<int64>(in.shape());
- const gtl::InlinedVector<int64, 8> out_strides = ComputeStride<int64>(out->shape());
- const T* p = in.flat<T>().data();
- T* q = out->flat<T>().data();
-
- std::vector<int64> i_idx(nelem, 0);
- std::vector<int64> t(nelem, 0);
-
- for (int64 o_idx = 0; o_idx < nelem; ++o_idx) {
- t[o_idx] = o_idx;
- }
- for (int i = 0; i < ndims; ++i) {
- int64 n = (nelem + 7) / 8;
- int64 o_idx = 0;
- switch (nelem % 8) {
-#define CALC_INPUT_IDX \
- i_idx[o_idx] += (t[o_idx] / out_strides[i] + slice_indices[i]) * in_strides[i]; \
- t[o_idx] %= out_strides[i]; \
- ++o_idx;
- case 0: do { CALC_INPUT_IDX;
- case 7: CALC_INPUT_IDX;
- case 6: CALC_INPUT_IDX;
- case 5: CALC_INPUT_IDX;
- case 4: CALC_INPUT_IDX;
- case 3: CALC_INPUT_IDX;
- case 2: CALC_INPUT_IDX;
- case 1: CALC_INPUT_IDX;
-#undef CALC_INPUT_IDX
- } while (--n > 0);
- }
- }
- for (int64 o_idx = 0; o_idx < nelem; ++o_idx) {
- q[o_idx] = p[i_idx[o_idx]];
- }
-}
-
-template <typename Device, typename T, int NDIMS>
-void SliceUsingEigen(const Device& d, Tensor* out, const Tensor& in,
- const gtl::ArraySlice<int64>& slice_indices,
- const gtl::ArraySlice<int64>& slice_sizes) {
- auto input = in.tensor<T, NDIMS>();
- auto output = out->tensor<T, NDIMS>();
- Eigen::DSizes<int, NDIMS> indices;
- for (int i = 0; i < NDIMS; ++i) {
- indices[i] = slice_indices[i];
- }
- Eigen::DSizes<int, NDIMS> sizes;
- for (int i = 0; i < NDIMS; ++i) {
- sizes[i] = slice_sizes[i];
- }
- const bool use_64bit = input.size() > Eigen::NumTraits<int>::highest();
- if (!use_64bit &&
- Eigen::internal::is_same<Device, Eigen::GpuDevice>::value) {
- To32Bit(output).device(d) = To32Bit(input).slice(indices, sizes);
- } else {
- output.device(d) = input.slice(indices, sizes);
- }
-}
-
-} // namespace internal
-
namespace functor {
-// Template parameter NDIM is not neccesary here. The aim of keeping it
-// is to compile struct slice seperately which minimizes the compiling time.
-template <typename Device, typename T, int NDIM>
+template <typename Device, typename T, int NDIMS>
struct Slice {
- void operator()(const Device& d, Tensor* out, const Tensor& in,
- const gtl::ArraySlice<int64>& slice_indices,
- const gtl::ArraySlice<int64>& slice_sizes) {
- if (in.dims() == NDIM) {
- internal::SliceUsingEigen<Device, T, NDIM>(d, out, in, slice_indices, slice_sizes);
+ void operator()(const Device& d, typename TTypes<T, NDIMS>::Tensor output,
+ typename TTypes<T, NDIMS>::ConstTensor input,
+ const Eigen::DSizes<Eigen::DenseIndex, NDIMS>& slice_indices,
+ const Eigen::DSizes<Eigen::DenseIndex, NDIMS>& slice_sizes) {
+ bool use_64bit = (input.size() > Eigen::NumTraits<int>::highest());
+ if (!use_64bit &&
+ Eigen::internal::is_same<Device, Eigen::GpuDevice>::value) {
+ Eigen::DSizes<int, NDIMS> indices;
+ for (int i = 0; i < NDIMS; ++i) {
+ indices[i] = slice_indices[i];
+ }
+ Eigen::DSizes<int, NDIMS> sizes;
+ for (int i = 0; i < NDIMS; ++i) {
+ sizes[i] = slice_sizes[i];
+ }
+ To32Bit(output).device(d) = To32Bit(input).slice(indices, sizes);
} else {
- if (Eigen::internal::is_same<Device, Eigen::GpuDevice>::value) {
- internal::SliceSimpleGpu<Device, T>(d, out, in, slice_indices);
- } else {
- internal::SliceSimple<Device, T>(d, out, in, slice_indices);
- }
+ output.device(d) = input.slice(slice_indices, slice_sizes);
}
}
};
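
The rewritten functor takes already-typed Eigen tensor views instead of Tensor objects, so the caller now builds the fixed-rank Eigen::DSizes index/size arrays and picks NDIMS before dispatching. Below is a minimal sketch of such a caller, assuming the usual kernel-side plumbing in slice_op.cc (OpKernelContext, eigen_device, begin/size vectors, and an allocated result tensor); the HandleCase name and surrounding details are illustrative assumptions, only the functor's operator() signature comes from the header above.

// Sketch only: shows how a kernel might invoke the new functor::Slice
// signature. HandleCase and the argument plumbing are assumptions, not
// part of this diff.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/slice_op.h"

namespace tensorflow {

template <typename Device, typename T, int NDIMS>
void HandleCase(OpKernelContext* context, const gtl::ArraySlice<int64>& begin,
                const gtl::ArraySlice<int64>& size, Tensor* result) {
  // Convert the begin/size vectors into the fixed-rank Eigen arrays the
  // functor now expects.
  Eigen::DSizes<Eigen::DenseIndex, NDIMS> indices;
  Eigen::DSizes<Eigen::DenseIndex, NDIMS> sizes;
  for (int i = 0; i < NDIMS; ++i) {
    indices[i] = begin[i];
    sizes[i] = size[i];
  }

  // The functor decides internally whether 32-bit indexing can be used on
  // GpuDevice; the caller just hands over typed tensor views.
  functor::Slice<Device, T, NDIMS>()(
      context->eigen_device<Device>(), result->tensor<T, NDIMS>(),
      context->input(0).tensor<T, NDIMS>(), indices, sizes);
}

}  // namespace tensorflow

Because the rank is a template parameter, the caller typically switches on in.dims() to pick the NDIMS instantiation, which keeps each specialization small and compiles them separately.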