about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
authorGravatar Peter Hawkins <phawkins@google.com>2017-09-26 16:34:24 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-26 16:38:52 -0700
commit1ccc394c1010a7d84b71cc193b23578d378c078b (patch)
treea3544c00cc36756cf847591ec68fdbf64e226681
parent079061306d4f58295e48b452818875c6a9bdbfaa (diff)
[TF:XLA] Extend implementation of "Slice" operator to support "begin" values that are not known statically at compile time.
Cleanup implementation of Slice. PiperOrigin-RevId: 170128580
-rw-r--r--tensorflow/compiler/tests/slice_ops_test.py28
-rw-r--r--tensorflow/compiler/tf2xla/const_analysis.cc1
-rw-r--r--tensorflow/compiler/tf2xla/kernels/slice_op.cc148
3 files changed, 95 insertions, 82 deletions
diff --git a/tensorflow/compiler/tests/slice_ops_test.py b/tensorflow/compiler/tests/slice_ops_test.py
index 4ddf2ee0dc..3bf514ca91 100644
--- a/tensorflow/compiler/tests/slice_ops_test.py
+++ b/tensorflow/compiler/tests/slice_ops_test.py
@@ -18,15 +18,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import numpy as np
-
from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import googletest
-
class SliceTest(XLATestCase):
def test1D(self):
@@ -63,6 +60,29 @@ class SliceTest(XLATestCase):
self.assertAllEqual([[[6, 5, 4, 3]]], result)
+ def test3DWithDynamicBegin(self):
+ """Tests a slice where the start offset is not known at compile time."""
+ for dtype in self.numeric_types:
+ with self.test_session():
+ i = array_ops.placeholder(dtype, shape=[3, 3, 10])
+ begin = array_ops.placeholder(dtypes.int32, shape=[3])
+ with self.test_scope():
+ o = array_ops.slice(i, begin, [1, 1, 4])
+ params = {
+ i: [[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
+ [9, 8, 7, 6, 5, 4, 3, 2, 1, 0],
+ [5, 3, 1, 7, 9, 2, 4, 6, 8, 0]],
+ [[5, 5, 5, 5, 5, 5, 5, 5, 5, 5],
+ [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
+ [8, 7, 6, 5, 4, 3, 2, 1, 8, 7]],
+ [[7, 5, 7, 5, 7, 5, 7, 5, 7, 5],
+ [1, 2, 1, 2, 1, 2, 1, 2, 1, 2],
+ [9, 8, 7, 9, 8, 7, 9, 8, 7, 9]]],
+ begin: [1, 2, 2]
+ }
+ result = o.eval(feed_dict=params)
+
+ self.assertAllEqual([[[6, 5, 4, 3]]], result)
class StridedSliceTest(XLATestCase):
@@ -80,7 +100,7 @@ class StridedSliceTest(XLATestCase):
self.assertAllEqual([2, 4], result)
- def test1DNegtiveStride(self):
+ def test1DNegativeStride(self):
for dtype in self.numeric_types:
with self.test_session():
i = array_ops.placeholder(dtype, shape=[10])
diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc
index 170a33e003..ad0397a3d9 100644
--- a/tensorflow/compiler/tf2xla/const_analysis.cc
+++ b/tensorflow/compiler/tf2xla/const_analysis.cc
@@ -78,7 +78,6 @@ Status BackwardsConstAnalysis(const Graph& g,
{"ResourceStridedSliceAssign", "strides"},
{"Reverse", "dims"},
{"ReverseV2", "axis"},
- {"Slice", "begin"},
{"Slice", "size"},
{"SpaceToBatch", "paddings"},
{"SpaceToBatchND", "block_shape"},
diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc
index 482c54a40c..fbe8c78d8f 100644
--- a/tensorflow/compiler/tf2xla/kernels/slice_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc
@@ -35,88 +35,82 @@ class SliceOp : public XlaOpKernel {
explicit SliceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
- bool is_identity = true;
+ const TensorShape input_shape = ctx->InputShape(0);
+ const TensorShape begin_tensor_shape = ctx->InputShape(1);
+ const TensorShape size_tensor_shape = ctx->InputShape(2);
+
+ OP_REQUIRES(
+ ctx,
+ IsLegacyVector(begin_tensor_shape) &&
+ IsLegacyVector(size_tensor_shape) &&
+ begin_tensor_shape.num_elements() == input_shape.dims() &&
+ size_tensor_shape.num_elements() == input_shape.dims(),
+ errors::InvalidArgument(
+ "Expected begin and size arguments to be 1-D tensors of size ",
+ input_shape.dims(), ", but got shapes ",
+ begin_tensor_shape.DebugString(), " and ",
+ size_tensor_shape.DebugString(), " instead."));
+
+ const int input_dims = input_shape.dims();
+
std::vector<int64> begin;
std::vector<int64> size;
- SharedValidation(ctx, &is_identity, &begin, &size);
- if (!ctx->status().ok()) return;
-
- if (is_identity) {
- VLOG(1) << "Slice identity";
- ctx->SetOutput(0, ctx->Input(0));
- return;
- }
-
- // slice will be an empty handle if the output has no elements.
- CHECK_EQ(begin.size(), size.size());
- std::vector<int64> limits;
- limits.reserve(begin.size());
- for (int i = 0; i < begin.size(); ++i) {
- limits.push_back(begin[i] + size[i]);
- }
- std::vector<int64> strides(begin.size(), 1);
- ctx->SetOutput(0, ctx->builder()->Slice(ctx->Input(0), begin, limits,
- strides));
- }
-
- private:
- void SharedValidation(XlaOpKernelContext* ctx, bool* is_identity,
- std::vector<int64>* begin, std::vector<int64>* size);
-};
-
-void SliceOp::SharedValidation(XlaOpKernelContext* ctx, bool* is_identity,
- std::vector<int64>* begin,
- std::vector<int64>* size) {
- const TensorShape input_shape = ctx->InputShape(0);
- const TensorShape begin_tensor_shape = ctx->InputShape(1);
- const TensorShape size_tensor_shape = ctx->InputShape(2);
-
- OP_REQUIRES(
- ctx,
- IsLegacyVector(begin_tensor_shape) && IsLegacyVector(size_tensor_shape) &&
- begin_tensor_shape.num_elements() == input_shape.dims() &&
- size_tensor_shape.num_elements() == input_shape.dims(),
- errors::InvalidArgument(
- "Expected begin and size arguments to be 1-D tensors of size ",
- input_shape.dims(), ", but got shapes ",
- begin_tensor_shape.DebugString(), " and ",
- size_tensor_shape.DebugString(), " instead."));
-
- const int input_dims = input_shape.dims();
-
- OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, begin));
- OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, size));
- for (int i = 0; i < input_dims; ++i) {
- if ((*size)[i] == -1) {
- // A size[i] of -1 means "all elements from begin[i] to dim_size(i)".
- (*size)[i] = input_shape.dim_size(i) - (*begin)[i];
- }
- }
-
- *is_identity = true;
- for (int i = 0; i < input_dims; ++i) {
- int64 b = (*begin)[i];
- int64 s = (*size)[i];
- if (input_shape.dim_size(i) == 0) {
- OP_REQUIRES(ctx, b == 0 && s == 0,
- errors::InvalidArgument(
- "Expected begin[", i, "] == 0 (got ", b, ") and size[", i,
- "] == 0 ", "(got ", s, ") when ", "input_shape.dim_size(",
- i, ") == 0"));
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &size));
+ if (ctx->ConstantInputAsIntVector(1, &begin).ok()) {
+ // `begin` is a compile-time constant.
+ for (int i = 0; i < input_dims; ++i) {
+ if (size[i] == -1) {
+ // A size[i] of -1 means "all elements from begin[i] to dim_size(i)".
+ size[i] = input_shape.dim_size(i) - begin[i];
+ }
+ }
+
+ for (int i = 0; i < input_dims; ++i) {
+ int64 b = begin[i];
+ int64 s = size[i];
+ if (input_shape.dim_size(i) == 0) {
+ OP_REQUIRES(ctx, b == 0 && s == 0,
+ errors::InvalidArgument(
+ "Expected begin[", i, "] == 0 (got ", b,
+ ") and size[", i, "] == 0 ", "(got ", s, ") when ",
+ "input_shape.dim_size(", i, ") == 0"));
+ } else {
+ OP_REQUIRES(ctx, 0 <= b && b <= input_shape.dim_size(i),
+ errors::InvalidArgument("Expected begin[", i, "] in [0, ",
+ input_shape.dim_size(i),
+ "], but got ", b));
+ OP_REQUIRES(ctx, 0 <= s && b + s <= input_shape.dim_size(i),
+ errors::InvalidArgument("Expected size[", i, "] in [0, ",
+ input_shape.dim_size(i) - b,
+ "], but ", "got ", s));
+ }
+ }
+
+ std::vector<int64> limits;
+ limits.reserve(begin.size());
+ for (int i = 0; i < begin.size(); ++i) {
+ limits.push_back(begin[i] + size[i]);
+ }
+ std::vector<int64> strides(begin.size(), 1);
+ ctx->SetOutput(
+ 0, ctx->builder()->Slice(ctx->Input(0), begin, limits, strides));
} else {
- OP_REQUIRES(
- ctx, 0 <= b && b <= input_shape.dim_size(i),
- errors::InvalidArgument("Expected begin[", i, "] in [0, ",
- input_shape.dim_size(i), "], but got ", b));
- OP_REQUIRES(ctx, 0 <= s && b + s <= input_shape.dim_size(i),
- errors::InvalidArgument("Expected size[", i, "] in [0, ",
- input_shape.dim_size(i) - b,
- "], but ", "got ", s));
+ // `begin` is not a compile-time constant.
+ for (int i = 0; i < input_dims; ++i) {
+ OP_REQUIRES(ctx, 0 <= size[i],
+ errors::InvalidArgument(
+ "XLA compilation of Slice operator with negative sizes "
+ "requires that 'begin' is a compile-time constant."));
+ OP_REQUIRES(ctx, size[i] <= input_shape.dim_size(i),
+ errors::InvalidArgument("Expected size[", i, "] in [0, ",
+ input_shape.dim_size(i), "], but ",
+ "got ", size[i]));
+ }
+ ctx->SetOutput(
+ 0, ctx->builder()->DynamicSlice(ctx->Input(0), ctx->Input(1), size));
}
- const bool take_all = (b == 0) && (s == input_shape.dim_size(i));
- (*is_identity) &= take_all;
}
-}
+};
REGISTER_XLA_OP(Name("Slice"), SliceOp);