diff options
author | 2016-07-14 11:58:50 -0800 | |
---|---|---|
committer | 2016-07-14 13:03:06 -0700 | |
commit | 4f63d69d562053a39b4ea25253a2e571dcec2554 (patch) | |
tree | 9f60c8949868ad985de6a6a49d7bf7681cfeee56 /tensorflow/core/kernels/strided_slice_op.cc | |
parent | 333e0cf1f306652032abf66982b53fada43b368a (diff) |
Enable strided slice op as default slicing.
This allows richer slicing of tensors, e.g.,
tf.assign(var, foo[-5:-10:-1, ..., 3:5, tf.newaxis])
This does not include lvalue support for assigning to a slice.
It also does not include advanced indexing foo[bar] where bar is
a tensor.
Fixed a bug in the implementation where num_elements() was used
instead of dims() in an optimization code path.
Also make the supervisor produce an easier-to-read error message
for dimension mismatches.
Change: 127463353
Diffstat (limited to 'tensorflow/core/kernels/strided_slice_op.cc')
-rw-r--r-- | tensorflow/core/kernels/strided_slice_op.cc | 17 |
1 file changed, 9 insertions, 8 deletions
diff --git a/tensorflow/core/kernels/strided_slice_op.cc b/tensorflow/core/kernels/strided_slice_op.cc index 9f0f6c0ef4..aa4c73b490 100644 --- a/tensorflow/core/kernels/strided_slice_op.cc +++ b/tensorflow/core/kernels/strided_slice_op.cc @@ -325,23 +325,24 @@ class StridedSliceOp : public OpKernel { if (processing_shape.num_elements() > 0) { // Optimization #3, slice has stride 1 in all dimensions // Optimization #3A, slice has only two dimensions - // TODO(aselle): Here we are restricting to processing_shape being - // 2D. this isn't strictly necessary, but I don't want to blow up - // the code gen size, because to shape<> you need static NDIM and T + // TODO(aselle): Here we are restricting to processing_shape and + // final_shape being 2D. This isn't strictly necessary, but I don't + // want to blow up code gen size, because to shape<> you need static + // NDIM and T if (is_simple_slice && std::is_same<Device, CPUDevice>::value && - input_dims == 2 && processing_shape.num_elements() == 2 && + input_dims == 2 && processing_shape.dims() == 2 && + final_shape.dims() == 2 && DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) { - auto input = context->input(0).tensor<T, 2>(); + auto in = input.tensor<T, 2>(); auto output = result->tensor<T, 2>(); // TODO(agarwal): Consider multi-threading if size[0] is large for (int row_in = begin[0], row_out = 0; row_in < end[0]; ++row_in, ++row_out) { if (row_in + 1 < end[0]) { port::prefetch<port::PREFETCH_HINT_T0>(&output(row_in + 1, 0)); - port::prefetch<port::PREFETCH_HINT_T0>( - &input(row_in + 1, begin[1])); + port::prefetch<port::PREFETCH_HINT_T0>(&in(row_in + 1, begin[1])); } - memcpy(&output(row_out, 0), &input(row_in, begin[1]), + memcpy(&output(row_out, 0), &in(row_in, begin[1]), (end[1] - begin[1]) * sizeof(T)); } return; |