diff options
Diffstat (limited to 'tensorflow/core/kernels/strided_slice_op.cc')
-rw-r--r-- | tensorflow/core/kernels/strided_slice_op.cc | 41 |
1 files changed, 22 insertions, 19 deletions
diff --git a/tensorflow/core/kernels/strided_slice_op.cc b/tensorflow/core/kernels/strided_slice_op.cc index 59fdc2262a..3e8a4c5b72 100644 --- a/tensorflow/core/kernels/strided_slice_op.cc +++ b/tensorflow/core/kernels/strided_slice_op.cc @@ -149,7 +149,7 @@ class StridedSliceOp : public OpKernel { // NDIM and T if (is_simple_slice && std::is_same<Device, CPUDevice>::value && input_dims == 2 && processing_shape.dims() == 2 && - final_shape.dims() == 2) { + final_shape.dims() == 2 && new_axis_mask == 0) { MemCpyFunctor<T> functor; if (functor.Copy(input, begin, end, result)) { return; @@ -300,37 +300,40 @@ class StridedSliceAssignOp : public OpKernel { gtl::InlinedVector<int64, 4> end; gtl::InlinedVector<int64, 4> strides; - Tensor old_lhs; + Tensor* old_lhs = nullptr; + Tensor tmp; if (context->input_dtype(0) == DT_RESOURCE) { Var* v; OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), &v)); + core::ScopedUnref scoped_unref(v); mutex_lock ml(*v->mu()); OP_REQUIRES_OK(context, PrepareToUpdateVariable<Device, T>(context, v->tensor())); - old_lhs = *v->tensor(); - OP_REQUIRES(context, old_lhs.dtype() == DataTypeToEnum<T>::value, + old_lhs = v->tensor(); + OP_REQUIRES(context, old_lhs->dtype() == DataTypeToEnum<T>::value, errors::InvalidArgument( - "l-value dtype ", DataTypeString(old_lhs.dtype()), + "l-value dtype ", DataTypeString(old_lhs->dtype()), " does not match r-value dtype ", DataTypeString(DataTypeToEnum<T>::value))); } else { context->forward_ref_input_to_ref_output(0, 0); - old_lhs = context->mutable_input(0, true); + tmp = context->mutable_input(0, true); + old_lhs = &tmp; } OP_REQUIRES_OK( - context, - ValidateStridedSliceOp( - &context->input(1), &context->input(2), context->input(3), - old_lhs.shape(), begin_mask, end_mask, ellipsis_mask, new_axis_mask, - shrink_axis_mask, &processing_shape, &final_shape, &is_identity, - &is_simple_slice, &slice_dim0, &begin, &end, &strides)); + context, ValidateStridedSliceOp( + &context->input(1), &context->input(2), context->input(3), + old_lhs->shape(), begin_mask, end_mask, ellipsis_mask, + new_axis_mask, shrink_axis_mask, &processing_shape, + &final_shape, &is_identity, &is_simple_slice, &slice_dim0, + &begin, &end, &strides)); if (processing_shape.num_elements()) { const Tensor& input = context->input(4); TensorShape input_shape = input.shape(); - TensorShape original_shape = old_lhs.shape(); + TensorShape original_shape = old_lhs->shape(); // TODO(aselle): This check is too strong, we only should need // input_shape to be broadcastable to final_shape OP_REQUIRES( @@ -345,12 +348,12 @@ class StridedSliceAssignOp : public OpKernel { // scalar shape // Handle general dimensions -#define HANDLE_DIM(NDIM) \ - if (processing_dims == NDIM) { \ - HandleStridedSliceAssignCase<Device, T, NDIM>()( \ - context, begin, end, strides, processing_shape, is_simple_slice, \ - &old_lhs); \ - return; \ +#define HANDLE_DIM(NDIM) \ + if (processing_dims == NDIM) { \ + HandleStridedSliceAssignCase<Device, T, NDIM>()(context, begin, end, \ + strides, processing_shape, \ + is_simple_slice, old_lhs); \ + return; \ } HANDLE_DIM(0); HANDLE_DIM(1); |