diff options
author | 2017-09-26 16:34:24 -0700 | |
---|---|---|
committer | 2017-09-26 16:38:52 -0700 | |
commit | 1ccc394c1010a7d84b71cc193b23578d378c078b (patch) | |
tree | a3544c00cc36756cf847591ec68fdbf64e226681 | |
parent | 079061306d4f58295e48b452818875c6a9bdbfaa (diff) |
[TF:XLA] Extend implementation of "Slice" operator to support "begin" values that are not known statically at compile time.
Cleanup implementation of Slice.
PiperOrigin-RevId: 170128580
-rw-r--r-- | tensorflow/compiler/tests/slice_ops_test.py | 28 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/const_analysis.cc | 1 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/slice_op.cc | 148 |
3 files changed, 95 insertions, 82 deletions
diff --git a/tensorflow/compiler/tests/slice_ops_test.py b/tensorflow/compiler/tests/slice_ops_test.py index 4ddf2ee0dc..3bf514ca91 100644 --- a/tensorflow/compiler/tests/slice_ops_test.py +++ b/tensorflow/compiler/tests/slice_ops_test.py @@ -18,15 +18,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np - from tensorflow.compiler.tests.xla_test import XLATestCase from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.platform import googletest - class SliceTest(XLATestCase): def test1D(self): @@ -63,6 +60,29 @@ class SliceTest(XLATestCase): self.assertAllEqual([[[6, 5, 4, 3]]], result) + def test3DWithDynamicBegin(self): + """Tests a slice where the start offset is not known at compile time.""" + for dtype in self.numeric_types: + with self.test_session(): + i = array_ops.placeholder(dtype, shape=[3, 3, 10]) + begin = array_ops.placeholder(dtypes.int32, shape=[3]) + with self.test_scope(): + o = array_ops.slice(i, begin, [1, 1, 4]) + params = { + i: [[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + [9, 8, 7, 6, 5, 4, 3, 2, 1, 0], + [5, 3, 1, 7, 9, 2, 4, 6, 8, 0]], + [[5, 5, 5, 5, 5, 5, 5, 5, 5, 5], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [8, 7, 6, 5, 4, 3, 2, 1, 8, 7]], + [[7, 5, 7, 5, 7, 5, 7, 5, 7, 5], + [1, 2, 1, 2, 1, 2, 1, 2, 1, 2], + [9, 8, 7, 9, 8, 7, 9, 8, 7, 9]]], + begin: [1, 2, 2] + } + result = o.eval(feed_dict=params) + + self.assertAllEqual([[[6, 5, 4, 3]]], result) class StridedSliceTest(XLATestCase): @@ -80,7 +100,7 @@ class StridedSliceTest(XLATestCase): self.assertAllEqual([2, 4], result) - def test1DNegtiveStride(self): + def test1DNegativeStride(self): for dtype in self.numeric_types: with self.test_session(): i = array_ops.placeholder(dtype, shape=[10]) diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc index 170a33e003..ad0397a3d9 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.cc +++ b/tensorflow/compiler/tf2xla/const_analysis.cc @@ -78,7 +78,6 @@ Status BackwardsConstAnalysis(const Graph& g, {"ResourceStridedSliceAssign", "strides"}, {"Reverse", "dims"}, {"ReverseV2", "axis"}, - {"Slice", "begin"}, {"Slice", "size"}, {"SpaceToBatch", "paddings"}, {"SpaceToBatchND", "block_shape"}, diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc index 482c54a40c..fbe8c78d8f 100644 --- a/tensorflow/compiler/tf2xla/kernels/slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc @@ -35,88 +35,82 @@ class SliceOp : public XlaOpKernel { explicit SliceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - bool is_identity = true; + const TensorShape input_shape = ctx->InputShape(0); + const TensorShape begin_tensor_shape = ctx->InputShape(1); + const TensorShape size_tensor_shape = ctx->InputShape(2); + + OP_REQUIRES( + ctx, + IsLegacyVector(begin_tensor_shape) && + IsLegacyVector(size_tensor_shape) && + begin_tensor_shape.num_elements() == input_shape.dims() && + size_tensor_shape.num_elements() == input_shape.dims(), + errors::InvalidArgument( + "Expected begin and size arguments to be 1-D tensors of size ", + input_shape.dims(), ", but got shapes ", + begin_tensor_shape.DebugString(), " and ", + size_tensor_shape.DebugString(), " instead.")); + + const int input_dims = input_shape.dims(); + std::vector<int64> begin; std::vector<int64> size; - SharedValidation(ctx, &is_identity, &begin, &size); - if (!ctx->status().ok()) return; - - if (is_identity) { - VLOG(1) << "Slice identity"; - ctx->SetOutput(0, ctx->Input(0)); - return; - } - - // slice will be an empty handle if the output has no elements. - CHECK_EQ(begin.size(), size.size()); - std::vector<int64> limits; - limits.reserve(begin.size()); - for (int i = 0; i < begin.size(); ++i) { - limits.push_back(begin[i] + size[i]); - } - std::vector<int64> strides(begin.size(), 1); - ctx->SetOutput(0, ctx->builder()->Slice(ctx->Input(0), begin, limits, - strides)); - } - - private: - void SharedValidation(XlaOpKernelContext* ctx, bool* is_identity, - std::vector<int64>* begin, std::vector<int64>* size); -}; - -void SliceOp::SharedValidation(XlaOpKernelContext* ctx, bool* is_identity, - std::vector<int64>* begin, - std::vector<int64>* size) { - const TensorShape input_shape = ctx->InputShape(0); - const TensorShape begin_tensor_shape = ctx->InputShape(1); - const TensorShape size_tensor_shape = ctx->InputShape(2); - - OP_REQUIRES( - ctx, - IsLegacyVector(begin_tensor_shape) && IsLegacyVector(size_tensor_shape) && - begin_tensor_shape.num_elements() == input_shape.dims() && - size_tensor_shape.num_elements() == input_shape.dims(), - errors::InvalidArgument( - "Expected begin and size arguments to be 1-D tensors of size ", - input_shape.dims(), ", but got shapes ", - begin_tensor_shape.DebugString(), " and ", - size_tensor_shape.DebugString(), " instead.")); - - const int input_dims = input_shape.dims(); - - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, begin)); - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, size)); - for (int i = 0; i < input_dims; ++i) { - if ((*size)[i] == -1) { - // A size[i] of -1 means "all elements from begin[i] to dim_size(i)". - (*size)[i] = input_shape.dim_size(i) - (*begin)[i]; - } - } - - *is_identity = true; - for (int i = 0; i < input_dims; ++i) { - int64 b = (*begin)[i]; - int64 s = (*size)[i]; - if (input_shape.dim_size(i) == 0) { - OP_REQUIRES(ctx, b == 0 && s == 0, - errors::InvalidArgument( - "Expected begin[", i, "] == 0 (got ", b, ") and size[", i, - "] == 0 ", "(got ", s, ") when ", "input_shape.dim_size(", - i, ") == 0")); + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &size)); + if (ctx->ConstantInputAsIntVector(1, &begin).ok()) { + // `begin` is a compile-time constant. + for (int i = 0; i < input_dims; ++i) { + if (size[i] == -1) { + // A size[i] of -1 means "all elements from begin[i] to dim_size(i)". + size[i] = input_shape.dim_size(i) - begin[i]; + } + } + + for (int i = 0; i < input_dims; ++i) { + int64 b = begin[i]; + int64 s = size[i]; + if (input_shape.dim_size(i) == 0) { + OP_REQUIRES(ctx, b == 0 && s == 0, + errors::InvalidArgument( + "Expected begin[", i, "] == 0 (got ", b, + ") and size[", i, "] == 0 ", "(got ", s, ") when ", + "input_shape.dim_size(", i, ") == 0")); + } else { + OP_REQUIRES(ctx, 0 <= b && b <= input_shape.dim_size(i), + errors::InvalidArgument("Expected begin[", i, "] in [0, ", + input_shape.dim_size(i), + "], but got ", b)); + OP_REQUIRES(ctx, 0 <= s && b + s <= input_shape.dim_size(i), + errors::InvalidArgument("Expected size[", i, "] in [0, ", + input_shape.dim_size(i) - b, + "], but ", "got ", s)); + } + } + + std::vector<int64> limits; + limits.reserve(begin.size()); + for (int i = 0; i < begin.size(); ++i) { + limits.push_back(begin[i] + size[i]); + } + std::vector<int64> strides(begin.size(), 1); + ctx->SetOutput( + 0, ctx->builder()->Slice(ctx->Input(0), begin, limits, strides)); } else { - OP_REQUIRES( - ctx, 0 <= b && b <= input_shape.dim_size(i), - errors::InvalidArgument("Expected begin[", i, "] in [0, ", - input_shape.dim_size(i), "], but got ", b)); - OP_REQUIRES(ctx, 0 <= s && b + s <= input_shape.dim_size(i), - errors::InvalidArgument("Expected size[", i, "] in [0, ", - input_shape.dim_size(i) - b, - "], but ", "got ", s)); + // `begin` is not a compile-time constant. + for (int i = 0; i < input_dims; ++i) { + OP_REQUIRES(ctx, 0 <= size[i], + errors::InvalidArgument( + "XLA compilation of Slice operator with negative sizes " + "requires that 'begin' is a compile-time constant.")); + OP_REQUIRES(ctx, size[i] <= input_shape.dim_size(i), + errors::InvalidArgument("Expected size[", i, "] in [0, ", + input_shape.dim_size(i), "], but ", + "got ", size[i])); + } + ctx->SetOutput( + 0, ctx->builder()->DynamicSlice(ctx->Input(0), ctx->Input(1), size)); } - const bool take_all = (b == 0) && (s == input_shape.dim_size(i)); - (*is_identity) &= take_all; } -} +}; REGISTER_XLA_OP(Name("Slice"), SliceOp); |