diff options
Diffstat (limited to 'tensorflow/core/kernels/strided_slice_op_impl.h')
-rw-r--r-- | tensorflow/core/kernels/strided_slice_op_impl.h | 25 |
1 files changed, 11 insertions, 14 deletions
diff --git a/tensorflow/core/kernels/strided_slice_op_impl.h b/tensorflow/core/kernels/strided_slice_op_impl.h index afe3a051e6..7d42887426 100644 --- a/tensorflow/core/kernels/strided_slice_op_impl.h +++ b/tensorflow/core/kernels/strided_slice_op_impl.h @@ -84,16 +84,16 @@ void HandleStridedSliceCase(OpKernelContext* context, gtl::InlinedVector<int64, 4> processing_dims = processing_shape.dim_sizes(); if (is_simple_slice) { - Eigen::DSizes<Eigen::DenseIndex, NDIM> begin_di; - Eigen::DSizes<Eigen::DenseIndex, NDIM> sizes_di; + gtl::InlinedVector<int64, 4> sizes(begin.size()); for (int i = 0; i < NDIM; ++i) { - begin_di[i] = begin[i]; - sizes_di[i] = end[i] - begin[i]; + sizes[i] = end[i] - begin[i]; } - functor::Slice<Device, Proxy, NDIM>()( - context->eigen_device<Device>(), - result->bit_casted_shaped<Proxy, NDIM>(processing_dims), - context->input(0).bit_casted_tensor<Proxy, NDIM>(), begin_di, sizes_di); + const TensorShape final_shape = result->shape(); + CHECK(result->CopyFrom(*result, processing_shape)); + const Tensor input = context->input(0); + functor::Slice<Device, T, NDIM>()( + context->eigen_device<Device>(), result, input, begin, sizes); + CHECK(result->CopyFrom(*result, final_shape)); } else { Eigen::DSizes<Eigen::DenseIndex, NDIM> begin_di; Eigen::DSizes<Eigen::DenseIndex, NDIM> end_di; @@ -196,10 +196,9 @@ class HandleStridedSliceAssignCase<Device, T, 0> { extern template struct StridedSlice<GPUDevice, T, NDIM>; \ template <> \ void Slice<GPUDevice, T, NDIM>::operator()( \ - const GPUDevice& d, typename TTypes<T, NDIM>::Tensor output, \ - typename TTypes<T, NDIM>::ConstTensor input, \ - const Eigen::DSizes<Eigen::DenseIndex, NDIM>& indices, \ - const Eigen::DSizes<Eigen::DenseIndex, NDIM>& sizes); \ + const GPUDevice& d, Tensor* output, const Tensor& input, \ + const gtl::ArraySlice<int64>& slice_indices, \ + const gtl::ArraySlice<int64>& slice_sizes); \ extern template struct Slice<GPUDevice, T, NDIM>; \ template <> \ void StridedSliceGrad<GPUDevice, T, NDIM>::operator()( \ @@ -284,7 +283,6 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_FOR_N_GPU); TF_CALL_complex64(DECLARE_FOR_N_GPU); TF_CALL_complex128(DECLARE_FOR_N_GPU); DECLARE_FOR_N_GPU(int32); -DECLARE_FOR_N_GPU(int64); #endif // END GOOGLE_CUDA TF_CALL_ALL_TYPES(DECLARE_FOR_N_CPU); @@ -300,7 +298,6 @@ DECLARE_FOR_N_CPU(bfloat16); TF_CALL_SYCL_PROXY_TYPES(PREVENT_FOR_N_SYCL); TF_CALL_GPU_NUMBER_TYPES_NO_HALF(DECLARE_FOR_N_SYCL); DECLARE_FOR_N_SYCL(int32); -DECLARE_FOR_N_SYCL(int64); #undef DECLARE_FOR_N_SYCL #endif // TENSORFLOW_USE_SYCL |