diff options
author | 2018-09-12 15:28:47 -0700 | |
---|---|---|
committer | 2018-09-12 15:32:20 -0700 | |
commit | 32a3642ef448d93706ab22e894637b2dd0c197c7 (patch) | |
tree | c873d9ec097cc297e7032667d7d803715e3bfa52 | |
parent | 90876942a3f4403ebae7d1c9223c241e006eeaaa (diff) |
Export the XLA dynamic-slice HLO as a TF op
I need this in a subsequent CL where I'll rewrite the Slice TF op to DynamicSlice in some cases.
PiperOrigin-RevId: 212715067
-rw-r--r-- | tensorflow/compiler/tests/xla_ops_test.py | 41 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc | 60 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/ops/xla_ops.cc | 29 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/python/xla.py | 8 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_op_kernel.cc | 4 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_op_kernel.h | 5 |
6 files changed, 128 insertions, 19 deletions
diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py index 0f3843dc1e..1e600c44e9 100644 --- a/tensorflow/compiler/tests/xla_ops_test.py +++ b/tensorflow/compiler/tests/xla_ops_test.py @@ -25,6 +25,7 @@ from tensorflow.compiler.tests import xla_test from tensorflow.compiler.tf2xla.python import xla from tensorflow.compiler.xla import xla_data_pb2 from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import function from tensorflow.python.ops import array_ops from tensorflow.python.platform import googletest @@ -34,7 +35,7 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase): def _assertOpOutputMatchesExpected(self, op, args, expected, equality_fn=None): - with self.cached_session() as session: + with self.test_session() as session: with self.test_scope(): placeholders = [ array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape) @@ -296,6 +297,44 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase): self._assertOpOutputMatchesExpected( lambda x: xla.transpose(x, [1, 0]), args=(v,), expected=v.T) + def testDynamicSlice(self): + for dtype in self.numeric_types: + self._assertOpOutputMatchesExpected( + xla.dynamic_slice, + args=(np.arange(1000, + dtype=np.int32).astype(dtype).reshape([10, 10, 10]), + np.array([5, 7, 3]), np.array([2, 3, 2])), + expected=np.array( + np.array([[[573, 574], [583, 584], [593, 594]], + [[673, 674], [683, 684], [693, 694]]]), + dtype=dtype)) + + def testDynamicSliceWithIncorrectStartIndicesShape(self): + with self.test_session() as session: + with self.test_scope(): + output = xla.dynamic_slice( + np.arange(1000, dtype=np.int32).reshape([10, 10, 10]), + np.array([5, 7]), np.array([2, 3, 4])) + with self.assertRaises(errors.InvalidArgumentError) as invalid_arg_error: + session.run(output) + self.assertRegexpMatches( + invalid_arg_error.exception.message, + (r'^start_indices must be a vector with length equal to input rank, ' + r'but input rank is 3 and start_indices has shape \[2\].*')) + + def testDynamicSliceWithIncorrectSizeIndicesShape(self): + with self.test_session() as session: + with self.test_scope(): + output = xla.dynamic_slice( + np.arange(1000, dtype=np.int32).reshape([10, 10, 10]), + np.array([5, 7, 3]), np.array([2, 3])) + with self.assertRaises(errors.InvalidArgumentError) as invalid_arg_error: + session.run(output) + self.assertRegexpMatches( + invalid_arg_error.exception.message, + (r'^size_indices must be a vector with length equal to input rank, ' + r'but input rank is 3 and size_indices has shape \[2\].*')) + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc index a3389d5b90..4af1e8b44c 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc @@ -34,15 +34,12 @@ class DynamicUpdateSliceOp : public XlaOpKernel { : XlaOpKernel(context) {} void Compile(XlaOpKernelContext* ctx) override { - VLOG(3) << "DynamicUpdateSliceOp::Compile"; + DataType index_type = ctx->InputType("indices"); + CHECK(index_type == DT_INT32 || index_type == DT_INT64); - DataType index_type = input_type(2); - OP_REQUIRES(ctx, index_type == DT_INT32 || index_type == DT_INT64, - errors::InvalidArgument("index must be int32 or int64")); - - const TensorShape input_shape = ctx->InputShape(0); - const TensorShape update_shape = ctx->InputShape(1); - const TensorShape index_shape = ctx->InputShape(2); + const TensorShape input_shape = ctx->InputShape("input"); + const TensorShape update_shape = ctx->InputShape("update"); + const TensorShape index_shape = ctx->InputShape("indices"); OP_REQUIRES( ctx, @@ -57,13 +54,56 @@ class DynamicUpdateSliceOp : public XlaOpKernel { input_shape.DebugString(), "; update shape is ", update_shape.DebugString())); - xla::XlaOp result = - xla::DynamicUpdateSlice(ctx->Input(0), ctx->Input(1), ctx->Input(2)); + xla::XlaOp result = xla::DynamicUpdateSlice( + ctx->Input("input"), ctx->Input("update"), ctx->Input("indices")); ctx->SetOutput(0, result); } }; REGISTER_XLA_OP(Name("XlaDynamicUpdateSlice"), DynamicUpdateSliceOp); +class DynamicSliceOp : public XlaOpKernel { + public: + explicit DynamicSliceOp(OpKernelConstruction* context) + : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* ctx) override { + DataType index_type = ctx->InputType("start_indices"); + CHECK(index_type == DT_INT32 || index_type == DT_INT64); + CHECK(index_type == ctx->InputType("size_indices")); + + const TensorShape input_shape = ctx->InputShape("input"); + const TensorShape start_indices_shape = ctx->InputShape("start_indices"); + const TensorShape size_indices_shape = ctx->InputShape("size_indices"); + + OP_REQUIRES(ctx, + TensorShapeUtils::IsVector(start_indices_shape) && + start_indices_shape.num_elements() == input_shape.dims(), + errors::InvalidArgument( + "start_indices must be a vector with length equal to " + "input rank, but input rank is ", + input_shape.dims(), " and start_indices has shape ", + start_indices_shape.DebugString())); + OP_REQUIRES(ctx, + TensorShapeUtils::IsVector(size_indices_shape) && + size_indices_shape.num_elements() == input_shape.dims(), + errors::InvalidArgument( + "size_indices must be a vector with length equal to " + "input rank, but input rank is ", + input_shape.dims(), " and size_indices has shape ", + size_indices_shape.DebugString())); + + std::vector<int64> size_indices; + OP_REQUIRES_OK( + ctx, ctx->ConstantInputAsIntVector("size_indices", &size_indices)); + xla::XlaOp result = xla::DynamicSlice( + ctx->Input("input"), ctx->Input("start_indices"), size_indices); + ctx->SetOutput(0, result); + } +}; + +REGISTER_XLA_OP(Name("XlaDynamicSlice").CompileTimeConstInput("size_indices"), + DynamicSliceOp); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index 68cfdc1785..02363500ef 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -105,6 +105,35 @@ dimension_numbers: a serialized xla::DotDimensionNumbers proto. precision_config: a serialized xla::PrecisionConfig proto. )doc"); +REGISTER_OP("XlaDynamicSlice") + .Input("input: T") + .Input("start_indices: Tindices") + .Input("size_indices: Tindices") + .Output("output: T") + .Attr("T: type") + .Attr("Tindices: {int32, int64}") + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +Wraps the XLA DynamicSlice operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#dynamicslice +. + +DynamicSlice extracts a sub-array from the input array at dynamic +start_indices. The size of the slice in each dimension is passed in +size_indices, which specify the end point of exclusive slice intervals in each +dimension -- [start, start + size). The shape of start_indices must be rank == +1, with dimension size equal to the rank of operand. + +input: A `Tensor` of type T. + +start_indices: Rank 1 tensor of N integers containing the starting indices of + the slice for each dimension. Value must be greater than or equal to zero. + +start_indices: List of N integers containing the slice size for each + dimension. Each value must be strictly greater than zero, and start + size + must be less +)doc"); + REGISTER_OP("XlaDynamicUpdateSlice") .Input("input: T") .Input("update: T") diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py index 3626de375e..27dd18a9bb 100644 --- a/tensorflow/compiler/tf2xla/python/xla.py +++ b/tensorflow/compiler/tf2xla/python/xla.py @@ -291,13 +291,7 @@ def dot_general(lhs, rhs, dimension_numbers, precision_config=None, name=None): name=name) -def dynamic_slice(x, starts, sizes, name=None): - # TODO(phawkins): the Slice operator lowers to DynamicSlice if `starts` is not - # a compile-time constant. This doesn't exactly mimic the semantics of dynamic - # slice if the slice is out of bounds. - return array_ops.slice(x, starts, sizes, name=name) - - +dynamic_slice = gen_xla_ops.xla_dynamic_slice dynamic_update_slice = gen_xla_ops.xla_dynamic_update_slice # TODO(phawkins): generalize tf.pad to support interior padding, and then remove diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index d10a504da0..2a9eaeee14 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -83,6 +83,10 @@ DataType XlaOpKernelContext::input_type(int index) const { return context_->input(index).dtype(); } +DataType XlaOpKernelContext::InputType(absl::string_view name) { + return GetInputTensorByName(name).dtype(); +} + xla::PrimitiveType XlaOpKernelContext::input_xla_type(int index) { xla::PrimitiveType type; Status status = DataTypeToPrimitiveType(input_type(index), &type); diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index 962c86d3a5..a3a0d10cc0 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -71,6 +71,9 @@ class XlaOpKernelContext { // Returns the type of input `index`. DataType input_type(int index) const; + // Returns the type of input `name`. + DataType InputType(absl::string_view name); + // Returns the type of input `index` as an xla::PrimitiveType. If the type // is not representable as an XLA type, sets an error status and returns // xla::PRIMITIVE_TYPE_INVALID. @@ -79,7 +82,7 @@ class XlaOpKernelContext { // Returns the shape of input `index`. TensorShape InputShape(int index); - // Returns the shape of input `name`. + // Returns the shape of input with name `name`. TensorShape InputShape(absl::string_view name); // Returns input `index` as a XlaOp. Unlike |