diff options
Diffstat (limited to 'tensorflow/core/kernels/strided_slice_op_impl.h')
-rw-r--r-- | tensorflow/core/kernels/strided_slice_op_impl.h | 25 |
1 files changed, 14 insertions, 11 deletions
diff --git a/tensorflow/core/kernels/strided_slice_op_impl.h b/tensorflow/core/kernels/strided_slice_op_impl.h index 7d42887426..afe3a051e6 100644 --- a/tensorflow/core/kernels/strided_slice_op_impl.h +++ b/tensorflow/core/kernels/strided_slice_op_impl.h @@ -84,16 +84,16 @@ void HandleStridedSliceCase(OpKernelContext* context, gtl::InlinedVector<int64, 4> processing_dims = processing_shape.dim_sizes(); if (is_simple_slice) { - gtl::InlinedVector<int64, 4> sizes(begin.size()); + Eigen::DSizes<Eigen::DenseIndex, NDIM> begin_di; + Eigen::DSizes<Eigen::DenseIndex, NDIM> sizes_di; for (int i = 0; i < NDIM; ++i) { - sizes[i] = end[i] - begin[i]; + begin_di[i] = begin[i]; + sizes_di[i] = end[i] - begin[i]; } - const TensorShape final_shape = result->shape(); - CHECK(result->CopyFrom(*result, processing_shape)); - const Tensor input = context->input(0); - functor::Slice<Device, T, NDIM>()( - context->eigen_device<Device>(), result, input, begin, sizes); - CHECK(result->CopyFrom(*result, final_shape)); + functor::Slice<Device, Proxy, NDIM>()( + context->eigen_device<Device>(), + result->bit_casted_shaped<Proxy, NDIM>(processing_dims), + context->input(0).bit_casted_tensor<Proxy, NDIM>(), begin_di, sizes_di); } else { Eigen::DSizes<Eigen::DenseIndex, NDIM> begin_di; Eigen::DSizes<Eigen::DenseIndex, NDIM> end_di; @@ -196,9 +196,10 @@ class HandleStridedSliceAssignCase<Device, T, 0> { extern template struct StridedSlice<GPUDevice, T, NDIM>; \ template <> \ void Slice<GPUDevice, T, NDIM>::operator()( \ - const GPUDevice& d, Tensor* output, const Tensor& input, \ - const gtl::ArraySlice<int64>& slice_indices, \ - const gtl::ArraySlice<int64>& slice_sizes); \ + const GPUDevice& d, typename TTypes<T, NDIM>::Tensor output, \ + typename TTypes<T, NDIM>::ConstTensor input, \ + const Eigen::DSizes<Eigen::DenseIndex, NDIM>& indices, \ + const Eigen::DSizes<Eigen::DenseIndex, NDIM>& sizes); \ extern template struct Slice<GPUDevice, T, NDIM>; \ template <> \ void StridedSliceGrad<GPUDevice, T, NDIM>::operator()( \ @@ -283,6 +284,7 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_FOR_N_GPU); TF_CALL_complex64(DECLARE_FOR_N_GPU); TF_CALL_complex128(DECLARE_FOR_N_GPU); DECLARE_FOR_N_GPU(int32); +DECLARE_FOR_N_GPU(int64); #endif // END GOOGLE_CUDA TF_CALL_ALL_TYPES(DECLARE_FOR_N_CPU); @@ -298,6 +300,7 @@ DECLARE_FOR_N_CPU(bfloat16); TF_CALL_SYCL_PROXY_TYPES(PREVENT_FOR_N_SYCL); TF_CALL_GPU_NUMBER_TYPES_NO_HALF(DECLARE_FOR_N_SYCL); DECLARE_FOR_N_SYCL(int32); +DECLARE_FOR_N_SYCL(int64); #undef DECLARE_FOR_N_SYCL #endif // TENSORFLOW_USE_SYCL |