aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/strided_slice_op.cc
diff options
context:
space:
mode:
authorGravatar Andrew Selle <aselle@google.com>2016-07-14 11:58:50 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-07-14 13:03:06 -0700
commit4f63d69d562053a39b4ea25253a2e571dcec2554 (patch)
tree9f60c8949868ad985de6a6a49d7bf7681cfeee56 /tensorflow/core/kernels/strided_slice_op.cc
parent333e0cf1f306652032abf66982b53fada43b368a (diff)
Enable strided slice op as default slicing.
This allows richer slicing of tensors, e.g. tf.assign(var, foo[-5:-10:-1, ..., 3:5, tf.newaxis]). This does not include lvalue support for assigning to a slice. It also does not include advanced indexing foo[bar] where bar is a tensor. Fixed a bug in the implementation where num_elements was used instead of dims for an optimization code path. Also make the supervisor produce an easier-to-read error message for dim mismatches. Change: 127463353
Diffstat (limited to 'tensorflow/core/kernels/strided_slice_op.cc')
-rw-r--r--tensorflow/core/kernels/strided_slice_op.cc17
1 file changed, 9 insertions, 8 deletions
diff --git a/tensorflow/core/kernels/strided_slice_op.cc b/tensorflow/core/kernels/strided_slice_op.cc
index 9f0f6c0ef4..aa4c73b490 100644
--- a/tensorflow/core/kernels/strided_slice_op.cc
+++ b/tensorflow/core/kernels/strided_slice_op.cc
@@ -325,23 +325,24 @@ class StridedSliceOp : public OpKernel {
if (processing_shape.num_elements() > 0) {
// Optimization #3, slice has stride 1 in all dimensions
// Optimization #3A, slice has only two dimensions
- // TODO(aselle): Here we are restricting to processing_shape being
- // 2D. this isn't strictly necessary, but I don't want to blow up
- // the code gen size, because to shape<> you need static NDIM and T
+ // TODO(aselle): Here we are restricting to processing_shape and
+ // final_shape being 2D. This isn't strictly necessary, but I don't
+ // want to blow up code gen size, because to shape<> you need static
+ // NDIM and T
if (is_simple_slice && std::is_same<Device, CPUDevice>::value &&
- input_dims == 2 && processing_shape.num_elements() == 2 &&
+ input_dims == 2 && processing_shape.dims() == 2 &&
+ final_shape.dims() == 2 &&
DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
- auto input = context->input(0).tensor<T, 2>();
+ auto in = input.tensor<T, 2>();
auto output = result->tensor<T, 2>();
// TODO(agarwal): Consider multi-threading if size[0] is large
for (int row_in = begin[0], row_out = 0; row_in < end[0];
++row_in, ++row_out) {
if (row_in + 1 < end[0]) {
port::prefetch<port::PREFETCH_HINT_T0>(&output(row_in + 1, 0));
- port::prefetch<port::PREFETCH_HINT_T0>(
- &input(row_in + 1, begin[1]));
+ port::prefetch<port::PREFETCH_HINT_T0>(&in(row_in + 1, begin[1]));
}
- memcpy(&output(row_out, 0), &input(row_in, begin[1]),
+ memcpy(&output(row_out, 0), &in(row_in, begin[1]),
(end[1] - begin[1]) * sizeof(T));
}
return;