aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Sanjoy Das <sanjoy@google.com>2018-09-12 15:28:47 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-12 15:32:20 -0700
commit32a3642ef448d93706ab22e894637b2dd0c197c7 (patch)
treec873d9ec097cc297e7032667d7d803715e3bfa52
parent90876942a3f4403ebae7d1c9223c241e006eeaaa (diff)
Export the XLA dynamic-slice HLO as a TF op
I need this in a subsequent CL where I'll rewrite the Slice TF op to DynamicSlice in some cases. PiperOrigin-RevId: 212715067
-rw-r--r--tensorflow/compiler/tests/xla_ops_test.py41
-rw-r--r--tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc60
-rw-r--r--tensorflow/compiler/tf2xla/ops/xla_ops.cc29
-rw-r--r--tensorflow/compiler/tf2xla/python/xla.py8
-rw-r--r--tensorflow/compiler/tf2xla/xla_op_kernel.cc4
-rw-r--r--tensorflow/compiler/tf2xla/xla_op_kernel.h5
6 files changed, 128 insertions, 19 deletions
diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py
index 0f3843dc1e..1e600c44e9 100644
--- a/tensorflow/compiler/tests/xla_ops_test.py
+++ b/tensorflow/compiler/tests/xla_ops_test.py
@@ -25,6 +25,7 @@ from tensorflow.compiler.tests import xla_test
from tensorflow.compiler.tf2xla.python import xla
from tensorflow.compiler.xla import xla_data_pb2
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
from tensorflow.python.framework import function
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import googletest
@@ -34,7 +35,7 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase):
def _assertOpOutputMatchesExpected(self, op, args, expected,
equality_fn=None):
- with self.cached_session() as session:
+ with self.test_session() as session:
with self.test_scope():
placeholders = [
array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape)
@@ -296,6 +297,44 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase):
self._assertOpOutputMatchesExpected(
lambda x: xla.transpose(x, [1, 0]), args=(v,), expected=v.T)
+ def testDynamicSlice(self):
+ for dtype in self.numeric_types:
+ self._assertOpOutputMatchesExpected(
+ xla.dynamic_slice,
+ args=(np.arange(1000,
+ dtype=np.int32).astype(dtype).reshape([10, 10, 10]),
+ np.array([5, 7, 3]), np.array([2, 3, 2])),
+ expected=np.array(
+ np.array([[[573, 574], [583, 584], [593, 594]],
+ [[673, 674], [683, 684], [693, 694]]]),
+ dtype=dtype))
+
+ def testDynamicSliceWithIncorrectStartIndicesShape(self):
+ with self.test_session() as session:
+ with self.test_scope():
+ output = xla.dynamic_slice(
+ np.arange(1000, dtype=np.int32).reshape([10, 10, 10]),
+ np.array([5, 7]), np.array([2, 3, 4]))
+ with self.assertRaises(errors.InvalidArgumentError) as invalid_arg_error:
+ session.run(output)
+ self.assertRegexpMatches(
+ invalid_arg_error.exception.message,
+ (r'^start_indices must be a vector with length equal to input rank, '
+ r'but input rank is 3 and start_indices has shape \[2\].*'))
+
+ def testDynamicSliceWithIncorrectSizeIndicesShape(self):
+ with self.test_session() as session:
+ with self.test_scope():
+ output = xla.dynamic_slice(
+ np.arange(1000, dtype=np.int32).reshape([10, 10, 10]),
+ np.array([5, 7, 3]), np.array([2, 3]))
+ with self.assertRaises(errors.InvalidArgumentError) as invalid_arg_error:
+ session.run(output)
+ self.assertRegexpMatches(
+ invalid_arg_error.exception.message,
+ (r'^size_indices must be a vector with length equal to input rank, '
+ r'but input rank is 3 and size_indices has shape \[2\].*'))
+
if __name__ == '__main__':
googletest.main()
diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc
index a3389d5b90..4af1e8b44c 100644
--- a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc
@@ -34,15 +34,12 @@ class DynamicUpdateSliceOp : public XlaOpKernel {
: XlaOpKernel(context) {}
void Compile(XlaOpKernelContext* ctx) override {
- VLOG(3) << "DynamicUpdateSliceOp::Compile";
+ DataType index_type = ctx->InputType("indices");
+ CHECK(index_type == DT_INT32 || index_type == DT_INT64);
- DataType index_type = input_type(2);
- OP_REQUIRES(ctx, index_type == DT_INT32 || index_type == DT_INT64,
- errors::InvalidArgument("index must be int32 or int64"));
-
- const TensorShape input_shape = ctx->InputShape(0);
- const TensorShape update_shape = ctx->InputShape(1);
- const TensorShape index_shape = ctx->InputShape(2);
+ const TensorShape input_shape = ctx->InputShape("input");
+ const TensorShape update_shape = ctx->InputShape("update");
+ const TensorShape index_shape = ctx->InputShape("indices");
OP_REQUIRES(
ctx,
@@ -57,13 +54,56 @@ class DynamicUpdateSliceOp : public XlaOpKernel {
input_shape.DebugString(), "; update shape is ",
update_shape.DebugString()));
- xla::XlaOp result =
- xla::DynamicUpdateSlice(ctx->Input(0), ctx->Input(1), ctx->Input(2));
+ xla::XlaOp result = xla::DynamicUpdateSlice(
+ ctx->Input("input"), ctx->Input("update"), ctx->Input("indices"));
ctx->SetOutput(0, result);
}
};
REGISTER_XLA_OP(Name("XlaDynamicUpdateSlice"), DynamicUpdateSliceOp);
+class DynamicSliceOp : public XlaOpKernel {
+ public:
+ explicit DynamicSliceOp(OpKernelConstruction* context)
+ : XlaOpKernel(context) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ DataType index_type = ctx->InputType("start_indices");
+ CHECK(index_type == DT_INT32 || index_type == DT_INT64);
+ CHECK(index_type == ctx->InputType("size_indices"));
+
+ const TensorShape input_shape = ctx->InputShape("input");
+ const TensorShape start_indices_shape = ctx->InputShape("start_indices");
+ const TensorShape size_indices_shape = ctx->InputShape("size_indices");
+
+ OP_REQUIRES(ctx,
+ TensorShapeUtils::IsVector(start_indices_shape) &&
+ start_indices_shape.num_elements() == input_shape.dims(),
+ errors::InvalidArgument(
+ "start_indices must be a vector with length equal to "
+ "input rank, but input rank is ",
+ input_shape.dims(), " and start_indices has shape ",
+ start_indices_shape.DebugString()));
+ OP_REQUIRES(ctx,
+ TensorShapeUtils::IsVector(size_indices_shape) &&
+ size_indices_shape.num_elements() == input_shape.dims(),
+ errors::InvalidArgument(
+ "size_indices must be a vector with length equal to "
+ "input rank, but input rank is ",
+ input_shape.dims(), " and size_indices has shape ",
+ size_indices_shape.DebugString()));
+
+ std::vector<int64> size_indices;
+ OP_REQUIRES_OK(
+ ctx, ctx->ConstantInputAsIntVector("size_indices", &size_indices));
+ xla::XlaOp result = xla::DynamicSlice(
+ ctx->Input("input"), ctx->Input("start_indices"), size_indices);
+ ctx->SetOutput(0, result);
+ }
+};
+
+REGISTER_XLA_OP(Name("XlaDynamicSlice").CompileTimeConstInput("size_indices"),
+ DynamicSliceOp);
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
index 68cfdc1785..02363500ef 100644
--- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc
+++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
@@ -105,6 +105,35 @@ dimension_numbers: a serialized xla::DotDimensionNumbers proto.
precision_config: a serialized xla::PrecisionConfig proto.
)doc");
+REGISTER_OP("XlaDynamicSlice")
+ .Input("input: T")
+ .Input("start_indices: Tindices")
+ .Input("size_indices: Tindices")
+ .Output("output: T")
+ .Attr("T: type")
+ .Attr("Tindices: {int32, int64}")
+ .SetShapeFn(shape_inference::UnknownShape)
+ .Doc(R"doc(
+Wraps the XLA DynamicSlice operator, documented at
+ https://www.tensorflow.org/performance/xla/operation_semantics#dynamicslice
+.
+
+DynamicSlice extracts a sub-array from the input array at dynamic
+start_indices. The size of the slice in each dimension is passed in
+size_indices, which specify the end point of exclusive slice intervals in each
+dimension -- [start, start + size). The shape of start_indices must be rank ==
+1, with dimension size equal to the rank of operand.
+
+input: A `Tensor` of type T.
+
+start_indices: Rank 1 tensor of N integers containing the starting indices of
+ the slice for each dimension. Value must be greater than or equal to zero.
+
+start_indices: List of N integers containing the slice size for each
+ dimension. Each value must be strictly greater than zero, and start + size
+ must be less
+)doc");
+
REGISTER_OP("XlaDynamicUpdateSlice")
.Input("input: T")
.Input("update: T")
diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py
index 3626de375e..27dd18a9bb 100644
--- a/tensorflow/compiler/tf2xla/python/xla.py
+++ b/tensorflow/compiler/tf2xla/python/xla.py
@@ -291,13 +291,7 @@ def dot_general(lhs, rhs, dimension_numbers, precision_config=None, name=None):
name=name)
-def dynamic_slice(x, starts, sizes, name=None):
- # TODO(phawkins): the Slice operator lowers to DynamicSlice if `starts` is not
- # a compile-time constant. This doesn't exactly mimic the semantics of dynamic
- # slice if the slice is out of bounds.
- return array_ops.slice(x, starts, sizes, name=name)
-
-
+dynamic_slice = gen_xla_ops.xla_dynamic_slice
dynamic_update_slice = gen_xla_ops.xla_dynamic_update_slice
# TODO(phawkins): generalize tf.pad to support interior padding, and then remove
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index d10a504da0..2a9eaeee14 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -83,6 +83,10 @@ DataType XlaOpKernelContext::input_type(int index) const {
return context_->input(index).dtype();
}
+DataType XlaOpKernelContext::InputType(absl::string_view name) {
+ return GetInputTensorByName(name).dtype();
+}
+
xla::PrimitiveType XlaOpKernelContext::input_xla_type(int index) {
xla::PrimitiveType type;
Status status = DataTypeToPrimitiveType(input_type(index), &type);
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h
index 962c86d3a5..a3a0d10cc0 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.h
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h
@@ -71,6 +71,9 @@ class XlaOpKernelContext {
// Returns the type of input `index`.
DataType input_type(int index) const;
+ // Returns the type of input `name`.
+ DataType InputType(absl::string_view name);
+
// Returns the type of input `index` as an xla::PrimitiveType. If the type
// is not representable as an XLA type, sets an error status and returns
// xla::PRIMITIVE_TYPE_INVALID.
@@ -79,7 +82,7 @@ class XlaOpKernelContext {
// Returns the shape of input `index`.
TensorShape InputShape(int index);
- // Returns the shape of input `name`.
+ // Returns the shape of input with name `name`.
TensorShape InputShape(absl::string_view name);
// Returns input `index` as a XlaOp. Unlike