Diffstat (limited to 'tensorflow/core/kernels/strided_slice_op_impl.h')
-rw-r--r--  tensorflow/core/kernels/strided_slice_op_impl.h | 25 +++++++++++--------------
1 file changed, 11 insertions(+), 14 deletions(-)
diff --git a/tensorflow/core/kernels/strided_slice_op_impl.h b/tensorflow/core/kernels/strided_slice_op_impl.h
index afe3a051e6..7d42887426 100644
--- a/tensorflow/core/kernels/strided_slice_op_impl.h
+++ b/tensorflow/core/kernels/strided_slice_op_impl.h
@@ -84,16 +84,16 @@ void HandleStridedSliceCase(OpKernelContext* context,
   gtl::InlinedVector<int64, 4> processing_dims = processing_shape.dim_sizes();
 
   if (is_simple_slice) {
-    Eigen::DSizes<Eigen::DenseIndex, NDIM> begin_di;
-    Eigen::DSizes<Eigen::DenseIndex, NDIM> sizes_di;
+    gtl::InlinedVector<int64, 4> sizes(begin.size());
     for (int i = 0; i < NDIM; ++i) {
-      begin_di[i] = begin[i];
-      sizes_di[i] = end[i] - begin[i];
+      sizes[i] = end[i] - begin[i];
     }
-    functor::Slice<Device, Proxy, NDIM>()(
-        context->eigen_device<Device>(),
-        result->bit_casted_shaped<Proxy, NDIM>(processing_dims),
-        context->input(0).bit_casted_tensor<Proxy, NDIM>(), begin_di, sizes_di);
+    const TensorShape final_shape = result->shape();
+    CHECK(result->CopyFrom(*result, processing_shape));
+    const Tensor input = context->input(0);
+    functor::Slice<Device, T, NDIM>()(
+        context->eigen_device<Device>(), result, input, begin, sizes);
+    CHECK(result->CopyFrom(*result, final_shape));
   } else {
     Eigen::DSizes<Eigen::DenseIndex, NDIM> begin_di;
     Eigen::DSizes<Eigen::DenseIndex, NDIM> end_di;
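Note on the hunk above: the simple-slice fast path no longer builds Eigen::DSizes or bit-casts the buffers through a Proxy type. It derives sizes[i] = end[i] - begin[i], temporarily views `result` under `processing_shape`, runs the rank-templated functor::Slice on plain Tensor objects, and then restores the caller-visible shape. The round trip is cheap because Tensor::CopyFrom shares the source tensor's storage and only swaps the shape, failing if the element counts differ. Below is a minimal standalone sketch of that view round trip, using a hypothetical TensorView stand-in rather than TensorFlow's actual Tensor class:

    // reshape_roundtrip.cc -- TensorView is a hypothetical stand-in for
    // tensorflow::Tensor, illustrating why the two CHECK(CopyFrom(...))
    // calls above are cheap: CopyFrom shares storage and swaps the shape.
    #include <cassert>
    #include <cstdint>
    #include <iostream>
    #include <numeric>
    #include <vector>

    struct TensorView {
      std::vector<int64_t>* buffer;  // shared storage, never copied here
      std::vector<int64_t> shape;    // logical dimensions

      // Mirrors Tensor::CopyFrom: adopt other's buffer under new_shape,
      // succeeding only if the element counts match.
      bool CopyFrom(const TensorView& other, std::vector<int64_t> new_shape) {
        int64_t n = 1;
        for (int64_t d : new_shape) n *= d;
        if (n != static_cast<int64_t>(other.buffer->size())) return false;
        buffer = other.buffer;
        shape = std::move(new_shape);
        return true;
      }
    };

    int main() {
      std::vector<int64_t> storage(12);
      std::iota(storage.begin(), storage.end(), 0);
      TensorView result{&storage, {12}};             // caller-visible shape
      const std::vector<int64_t> final_shape = result.shape;
      assert(result.CopyFrom(result, {3, 4}));       // view as processing shape
      // ... a rank-2 slice kernel would write through result.buffer here ...
      assert(result.CopyFrom(result, final_shape));  // restore output shape
      std::cout << "rank restored to " << result.shape.size() << "\n";
    }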
@@ -196,10 +196,9 @@ class HandleStridedSliceAssignCase<Device, T, 0> {
   extern template struct StridedSlice<GPUDevice, T, NDIM>;         \
   template <>                                                      \
   void Slice<GPUDevice, T, NDIM>::operator()(                      \
-      const GPUDevice& d, typename TTypes<T, NDIM>::Tensor output, \
-      typename TTypes<T, NDIM>::ConstTensor input,                 \
-      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& indices,       \
-      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& sizes);        \
+      const GPUDevice& d, Tensor* output, const Tensor& input,     \
+      const gtl::ArraySlice<int64>& slice_indices,                 \
+      const gtl::ArraySlice<int64>& slice_sizes);                  \
   extern template struct Slice<GPUDevice, T, NDIM>;                \
   template <>                                                      \
   void StridedSliceGrad<GPUDevice, T, NDIM>::operator()(           \
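The forward declaration inside this macro has to track the new functor signature exactly, since the matching definitions are explicitly instantiated elsewhere. For a concrete picture, here is one expansion of the changed declaration block, written out by hand for T = float and NDIM = 2 (a sketch of the preprocessor output; the surrounding macro plumbing is unchanged):

    // Hand-expanded sketch of the Slice declaration for T = float, NDIM = 2.
    template <>
    void Slice<GPUDevice, float, 2>::operator()(
        const GPUDevice& d, Tensor* output, const Tensor& input,
        const gtl::ArraySlice<int64>& slice_indices,
        const gtl::ArraySlice<int64>& slice_sizes);
    extern template struct Slice<GPUDevice, float, 2>;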
@@ -284,7 +283,6 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_FOR_N_GPU);
 TF_CALL_complex64(DECLARE_FOR_N_GPU);
 TF_CALL_complex128(DECLARE_FOR_N_GPU);
 DECLARE_FOR_N_GPU(int32);
-DECLARE_FOR_N_GPU(int64);
 #endif  // END GOOGLE_CUDA
 
 TF_CALL_ALL_TYPES(DECLARE_FOR_N_CPU);
@@ -300,7 +298,6 @@ DECLARE_FOR_N_CPU(bfloat16);
 TF_CALL_SYCL_PROXY_TYPES(PREVENT_FOR_N_SYCL);
 TF_CALL_GPU_NUMBER_TYPES_NO_HALF(DECLARE_FOR_N_SYCL);
 DECLARE_FOR_N_SYCL(int32);
-DECLARE_FOR_N_SYCL(int64);
 
 #undef DECLARE_FOR_N_SYCL
 #endif  // TENSORFLOW_USE_SYCL
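Finally, the arithmetic the functor performs on the fast path is unchanged by this refactor: copy the block starting at `begin` with extent `sizes[i] = end[i] - begin[i]` along each dimension. A self-contained rank-2 illustration of that computation (the names and flat-index math here are illustrative, not TensorFlow's kernel code):

    // slice_sketch.cc -- illustrative only, not TensorFlow's implementation.
    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Copies the block starting at `begin` with extent `sizes` out of a
    // row-major rank-2 buffer with `cols` columns.
    std::vector<int64_t> Slice2D(const std::vector<int64_t>& input,
                                 int64_t cols,
                                 const std::vector<int64_t>& begin,
                                 const std::vector<int64_t>& sizes) {
      std::vector<int64_t> output(sizes[0] * sizes[1]);
      for (int64_t r = 0; r < sizes[0]; ++r)
        for (int64_t c = 0; c < sizes[1]; ++c)
          output[r * sizes[1] + c] =
              input[(begin[0] + r) * cols + (begin[1] + c)];
      return output;
    }

    int main() {
      // A 3x4 input; slice begin = {1, 1}, end = {3, 3}.
      std::vector<int64_t> input = {0, 1, 2,  3,
                                    4, 5, 6,  7,
                                    8, 9, 10, 11};
      std::vector<int64_t> begin = {1, 1}, end = {3, 3}, sizes(2);
      for (int i = 0; i < 2; ++i)
        sizes[i] = end[i] - begin[i];  // as in the first hunk above
      for (int64_t v : Slice2D(input, 4, begin, sizes)) std::cout << v << ' ';
      std::cout << '\n';  // prints: 5 6 9 10
    }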