Diffstat (limited to 'tensorflow/core/kernels/slice_op.h')
-rw-r--r--  tensorflow/core/kernels/slice_op.h  109
1 file changed, 91 insertions(+), 18 deletions(-)
diff --git a/tensorflow/core/kernels/slice_op.h b/tensorflow/core/kernels/slice_op.h
index db7eded745..55a4be985b 100644
--- a/tensorflow/core/kernels/slice_op.h
+++ b/tensorflow/core/kernels/slice_op.h
@@ -19,31 +19,104 @@ limitations under the License.
// Functor definition for SliceOp, must be compilable by nvcc.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/ops_util.h"
namespace tensorflow {
-namespace functor {
+
+namespace internal {
+
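+// Rank-agnostic fallbacks for inputs whose rank has no compiled Eigen
+// specialization.  SliceSimple is defined below; SliceSimpleGpu is only
+// declared here and is expected to be defined in the CUDA-compiled
+// source.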
+template <typename Device, typename T>
+void SliceSimple(const Device& d, Tensor* out, const Tensor& in,
+ const gtl::ArraySlice<int64>& slice_indices);
+template <typename Device, typename T>
+void SliceSimpleGpu(const Device& d, Tensor* out, const Tensor& in,
+ const gtl::ArraySlice<int64>& slice_indices);
+
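+// Copies the slice starting at `slice_indices` from `in` into `out`.
+// For every flat output index, the matching flat input index is
+// recovered by decomposing the output index into per-dimension
+// coordinates (via out_strides), shifting each coordinate by the slice
+// offset, and recombining with in_strides.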
+template <typename Device, typename T>
+void SliceSimple(const Device& d, Tensor* out, const Tensor& in,
+ const gtl::ArraySlice<int64>& slice_indices) {
+ const int ndims = in.dims();
+  const int64 nelem = out->NumElements();
+  // Empty slice: nothing to copy, and the unrolled loop below assumes
+  // nelem > 0.
+  if (nelem == 0) return;
+ const gtl::InlinedVector<int64, 8> in_strides = ComputeStride<int64>(in.shape());
+ const gtl::InlinedVector<int64, 8> out_strides = ComputeStride<int64>(out->shape());
+ const T* p = in.flat<T>().data();
+ T* q = out->flat<T>().data();
+
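+  // i_idx accumulates each output element's flat input index; t holds
+  // the part of the output index not yet decomposed into coordinates.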
+ std::vector<int64> i_idx(nelem, 0);
+ std::vector<int64> t(nelem, 0);
+
+ for (int64 o_idx = 0; o_idx < nelem; ++o_idx) {
+ t[o_idx] = o_idx;
+ }
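+  // Duff's device: the switch jumps into the unrolled do-while so the
+  // first pass handles the nelem % 8 leftover elements and every later
+  // pass handles eight.  Requires nelem > 0, guaranteed by the early
+  // return above.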
+ for (int i = 0; i < ndims; ++i) {
+ int64 n = (nelem + 7) / 8;
+ int64 o_idx = 0;
+ switch (nelem % 8) {
+#define CALC_INPUT_IDX \
+ i_idx[o_idx] += (t[o_idx] / out_strides[i] + slice_indices[i]) * in_strides[i]; \
+ t[o_idx] %= out_strides[i]; \
+ ++o_idx;
+ case 0: do { CALC_INPUT_IDX;
+ case 7: CALC_INPUT_IDX;
+ case 6: CALC_INPUT_IDX;
+ case 5: CALC_INPUT_IDX;
+ case 4: CALC_INPUT_IDX;
+ case 3: CALC_INPUT_IDX;
+ case 2: CALC_INPUT_IDX;
+ case 1: CALC_INPUT_IDX;
+#undef CALC_INPUT_IDX
+ } while (--n > 0);
+ }
+ }
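+  // All source offsets are known; perform the gather into the output.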
+ for (int64 o_idx = 0; o_idx < nelem; ++o_idx) {
+ q[o_idx] = p[i_idx[o_idx]];
+ }
+}
template <typename Device, typename T, int NDIMS>
+void SliceUsingEigen(const Device& d, Tensor* out, const Tensor& in,
+ const gtl::ArraySlice<int64>& slice_indices,
+ const gtl::ArraySlice<int64>& slice_sizes) {
+ auto input = in.tensor<T, NDIMS>();
+ auto output = out->tensor<T, NDIMS>();
+ Eigen::DSizes<int, NDIMS> indices;
+ for (int i = 0; i < NDIMS; ++i) {
+ indices[i] = slice_indices[i];
+ }
+ Eigen::DSizes<int, NDIMS> sizes;
+ for (int i = 0; i < NDIMS; ++i) {
+ sizes[i] = slice_sizes[i];
+ }
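+  // Eigen kernels index faster with 32-bit offsets; use them on GPU
+  // whenever every element offset fits in an int.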
+ const bool use_64bit = input.size() > Eigen::NumTraits<int>::highest();
+ if (!use_64bit &&
+ Eigen::internal::is_same<Device, Eigen::GpuDevice>::value) {
+ To32Bit(output).device(d) = To32Bit(input).slice(indices, sizes);
+ } else {
+ output.device(d) = input.slice(indices, sizes);
+ }
+}
+
+} // namespace internal
+
+namespace functor {
+
+// Template parameter NDIM is not strictly necessary here; it is kept so
+// that each instantiation of struct Slice compiles separately, which
+// reduces compilation time.
+template <typename Device, typename T, int NDIM>
struct Slice {
- void operator()(const Device& d, typename TTypes<T, NDIMS>::Tensor output,
- typename TTypes<T, NDIMS>::ConstTensor input,
- const Eigen::DSizes<Eigen::DenseIndex, NDIMS>& slice_indices,
- const Eigen::DSizes<Eigen::DenseIndex, NDIMS>& slice_sizes) {
- bool use_64bit = (input.size() > Eigen::NumTraits<int>::highest());
- if (!use_64bit &&
- Eigen::internal::is_same<Device, Eigen::GpuDevice>::value) {
- Eigen::DSizes<int, NDIMS> indices;
- for (int i = 0; i < NDIMS; ++i) {
- indices[i] = slice_indices[i];
- }
- Eigen::DSizes<int, NDIMS> sizes;
- for (int i = 0; i < NDIMS; ++i) {
- sizes[i] = slice_sizes[i];
- }
- To32Bit(output).device(d) = To32Bit(input).slice(indices, sizes);
+ void operator()(const Device& d, Tensor* out, const Tensor& in,
+ const gtl::ArraySlice<int64>& slice_indices,
+ const gtl::ArraySlice<int64>& slice_sizes) {
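+    // Fast path: the input rank matches this instantiation's NDIM, so
+    // the Eigen slice kernel applies.  Any other rank falls back to the
+    // rank-agnostic copy.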
+ if (in.dims() == NDIM) {
+ internal::SliceUsingEigen<Device, T, NDIM>(d, out, in, slice_indices, slice_sizes);
} else {
- output.device(d) = input.slice(slice_indices, slice_sizes);
+ if (Eigen::internal::is_same<Device, Eigen::GpuDevice>::value) {
+ internal::SliceSimpleGpu<Device, T>(d, out, in, slice_indices);
+ } else {
+ internal::SliceSimple<Device, T>(d, out, in, slice_indices);
+ }
}
}
};
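For illustration, a minimal sketch of how a caller might use the reworked
functor interface; `RunSlice`, its arguments, and the choice of NDIM = 4
are hypothetical and not part of this change:

    // Hypothetical helper: dispatches a slice through one of the
    // separately compiled functor instantiations.  When in.dims() != 4,
    // the functor itself falls back to internal::SliceSimple.
    template <typename Device, typename T>
    void RunSlice(const Device& d, const Tensor& in,
                  const gtl::ArraySlice<int64>& begin,
                  const gtl::ArraySlice<int64>& sizes, Tensor* out) {
      functor::Slice<Device, T, 4>()(d, out, in, begin, sizes);
    }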