aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/strided_slice_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/strided_slice_op.cc')
-rw-r--r--tensorflow/core/kernels/strided_slice_op.cc41
1 files changed, 22 insertions, 19 deletions
diff --git a/tensorflow/core/kernels/strided_slice_op.cc b/tensorflow/core/kernels/strided_slice_op.cc
index 59fdc2262a..3e8a4c5b72 100644
--- a/tensorflow/core/kernels/strided_slice_op.cc
+++ b/tensorflow/core/kernels/strided_slice_op.cc
@@ -149,7 +149,7 @@ class StridedSliceOp : public OpKernel {
// NDIM and T
if (is_simple_slice && std::is_same<Device, CPUDevice>::value &&
input_dims == 2 && processing_shape.dims() == 2 &&
- final_shape.dims() == 2) {
+ final_shape.dims() == 2 && new_axis_mask == 0) {
MemCpyFunctor<T> functor;
if (functor.Copy(input, begin, end, result)) {
return;
@@ -300,37 +300,40 @@ class StridedSliceAssignOp : public OpKernel {
gtl::InlinedVector<int64, 4> end;
gtl::InlinedVector<int64, 4> strides;
- Tensor old_lhs;
+ Tensor* old_lhs = nullptr;
+ Tensor tmp;
if (context->input_dtype(0) == DT_RESOURCE) {
Var* v;
OP_REQUIRES_OK(context,
LookupResource(context, HandleFromInput(context, 0), &v));
+ core::ScopedUnref scoped_unref(v);
mutex_lock ml(*v->mu());
OP_REQUIRES_OK(context,
PrepareToUpdateVariable<Device, T>(context, v->tensor()));
- old_lhs = *v->tensor();
- OP_REQUIRES(context, old_lhs.dtype() == DataTypeToEnum<T>::value,
+ old_lhs = v->tensor();
+ OP_REQUIRES(context, old_lhs->dtype() == DataTypeToEnum<T>::value,
errors::InvalidArgument(
- "l-value dtype ", DataTypeString(old_lhs.dtype()),
+ "l-value dtype ", DataTypeString(old_lhs->dtype()),
" does not match r-value dtype ",
DataTypeString(DataTypeToEnum<T>::value)));
} else {
context->forward_ref_input_to_ref_output(0, 0);
- old_lhs = context->mutable_input(0, true);
+ tmp = context->mutable_input(0, true);
+ old_lhs = &tmp;
}
OP_REQUIRES_OK(
- context,
- ValidateStridedSliceOp(
- &context->input(1), &context->input(2), context->input(3),
- old_lhs.shape(), begin_mask, end_mask, ellipsis_mask, new_axis_mask,
- shrink_axis_mask, &processing_shape, &final_shape, &is_identity,
- &is_simple_slice, &slice_dim0, &begin, &end, &strides));
+ context, ValidateStridedSliceOp(
+ &context->input(1), &context->input(2), context->input(3),
+ old_lhs->shape(), begin_mask, end_mask, ellipsis_mask,
+ new_axis_mask, shrink_axis_mask, &processing_shape,
+ &final_shape, &is_identity, &is_simple_slice, &slice_dim0,
+ &begin, &end, &strides));
if (processing_shape.num_elements()) {
const Tensor& input = context->input(4);
TensorShape input_shape = input.shape();
- TensorShape original_shape = old_lhs.shape();
+ TensorShape original_shape = old_lhs->shape();
// TODO(aselle): This check is too strong, we only should need
// input_shape to be broadcastable to final_shape
OP_REQUIRES(
@@ -345,12 +348,12 @@ class StridedSliceAssignOp : public OpKernel {
// scalar shape
// Handle general dimensions
-#define HANDLE_DIM(NDIM) \
- if (processing_dims == NDIM) { \
- HandleStridedSliceAssignCase<Device, T, NDIM>()( \
- context, begin, end, strides, processing_shape, is_simple_slice, \
- &old_lhs); \
- return; \
+#define HANDLE_DIM(NDIM) \
+ if (processing_dims == NDIM) { \
+ HandleStridedSliceAssignCase<Device, T, NDIM>()(context, begin, end, \
+ strides, processing_shape, \
+ is_simple_slice, old_lhs); \
+ return; \
}
HANDLE_DIM(0);
HANDLE_DIM(1);