From 5a43e01ef0f8cb86d836a4d1c08a246630e26f8c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 5 Oct 2018 09:29:00 -0700 Subject: Update XlaSort to match the underlying HLO. PiperOrigin-RevId: 215917470 --- tensorflow/compiler/tests/sort_ops_test.py | 18 +++++++++++++++++- tensorflow/compiler/tf2xla/kernels/sort_ops.cc | 17 ++++++++++++++++- tensorflow/compiler/tf2xla/ops/xla_ops.cc | 23 ++++++++++++++++++++++- tensorflow/compiler/tf2xla/python/xla.py | 12 +++++++----- tensorflow/compiler/xla/service/hlo_verifier.cc | 2 +- 5 files changed, 63 insertions(+), 9 deletions(-) (limited to 'tensorflow/compiler') diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py index dbf4beb693..57f0ab7a9e 100644 --- a/tensorflow/compiler/tests/sort_ops_test.py +++ b/tensorflow/compiler/tests/sort_ops_test.py @@ -48,13 +48,29 @@ class XlaSortOpTest(xla_test.XLATestCase): self.assertAllClose(v, result, rtol=1e-3) def testSort(self): - supported_types = set([dtypes.bfloat16.as_numpy_dtype, np.float32]) + supported_types = set( + [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32]) for dtype in supported_types.intersection(self.numeric_types): x = np.arange(101, dtype=dtype) np.random.shuffle(x) self._assertOpOutputMatchesExpected( xla.sort, [x], expected=[np.arange(101, dtype=dtype)]) + def testKeyValueSort(self): + supported_types = set( + [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32]) + for key_type in supported_types.intersection(self.numeric_types): + for value_type in supported_types.intersection(self.numeric_types): + x = np.arange(101, dtype=key_type) + np.random.shuffle(x) + y = (-x).astype(value_type) + self._assertOpOutputMatchesExpected( + xla.key_value_sort, [x, y], + expected=[ + np.arange(101, dtype=key_type), + -np.arange(101, dtype=value_type) + ]) + def testTopK(self): supported_types = set( [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32]) diff --git a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc index aaeeae01cc..45f03d8c21 100644 --- a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc @@ -25,11 +25,26 @@ class XlaSortOp : public XlaOpKernel { explicit XlaSortOp(OpKernelConstruction* context) : XlaOpKernel(context) {} void Compile(XlaOpKernelContext* context) override { - context->SetOutput(0, xla::Sort(context->Input(0))); + context->SetOutput(0, xla::Sort(context->Input("input"))); } }; REGISTER_XLA_OP(Name("XlaSort"), XlaSortOp); +class XlaKeyValueSortOp : public XlaOpKernel { + public: + explicit XlaKeyValueSortOp(OpKernelConstruction* context) + : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + xla::XlaOp result = + xla::Sort(context->Input("keys"), context->Input("values")); + context->SetOutput(0, xla::GetTupleElement(result, 0)); + context->SetOutput(1, xla::GetTupleElement(result, 1)); + } +}; + +REGISTER_XLA_OP(Name("XlaKeyValueSort"), XlaKeyValueSortOp); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index 733eeed3c6..557911553d 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -354,12 +354,33 @@ Wraps the XLA Sort operator, documented at https://www.tensorflow.org/performance/xla/operation_semantics#sort . -Sorts a tensor. Currently only rank 1 sorts in ascending order are supported. +Sorts a tensor. Currently only sorts in ascending order are supported. input: A `Tensor` of type T. output: A `Tensor` of type T. )doc"); +REGISTER_OP("XlaKeyValueSort") + .Input("keys: K") + .Input("values: V") + .Output("sorted_keys: K") + .Output("sorted_values: V") + .Attr("K: realnumbertype") + .Attr("V: type") + .SetShapeFn(shape_inference::UnchangedShape) + .Doc(R"doc( +Wraps the XLA Sort operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#sort +. + +Sorts a tensor. Currently only sorts in ascending order are supported. + +keys: A `Tensor` of type K. +values: A `Tensor` of type V. +sorted_keys: A `Tensor` of type K. +sorted_values: A `Tensor` of type V. +)doc"); + // TODO(b/37549631) setting the While Op to always be stateful is too // conservative. REGISTER_OP("XlaWhile") diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py index 27dd18a9bb..bc7924c371 100644 --- a/tensorflow/compiler/tf2xla/python/xla.py +++ b/tensorflow/compiler/tf2xla/python/xla.py @@ -212,9 +212,9 @@ bitcast_convert_type = array_ops.bitcast def broadcast(x, dims, name=None): x = ops.convert_to_tensor(x) - shape = array_ops.concat( - [constant_op.constant(dims), - array_ops.shape(x)], axis=0) + shape = array_ops.concat([constant_op.constant(dims), + array_ops.shape(x)], + axis=0) return array_ops.broadcast_to(x, shape, name=name) @@ -332,12 +332,13 @@ def reduce_window(operand, init: a scalar tensor representing the initial value for the reduction reducer: a reduction function that combines a pair of scalars. window_dimensions: shape of the window, as a list of integers - window_strides: inter-window strides, as a list of integers. Optional; - if omitted, defaults to strides of 1. + window_strides: inter-window strides, as a list of integers. Optional; if + omitted, defaults to strides of 1. padding: padding to apply to 'operand'. List of (low, high) pairs of integers that specify the padding to apply before and after each dimension. Optional; if omitted, defaults to no padding. name: the operator name, or None. + Returns: A tensor that represents the output of the reduce_window operator. """ @@ -377,4 +378,5 @@ def slice(x, start_dims, limit_dims, strides): sort = gen_xla_ops.xla_sort +key_value_sort = gen_xla_ops.xla_key_value_sort while_loop = gen_xla_ops.xla_while diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index b5498bb936..c22ee03388 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -548,6 +548,7 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { case HloOpcode::kTupleSelect: case HloOpcode::kSend: case HloOpcode::kSendDone: + case HloOpcode::kSort: case HloOpcode::kTuple: case HloOpcode::kWhile: break; @@ -1153,7 +1154,6 @@ StatusOr HloVerifier::Run(HloModule* module) { TF_RETURN_IF_ERROR(VerifyHloStructure(module)); TF_RETURN_IF_ERROR(VerifySendsAndRecvs(*module)); - for (auto* computation : module->computations()) { std::unique_ptr shape_verifier = shape_verifier_factory_(); TF_RETURN_IF_ERROR(computation->Accept(shape_verifier.get())); -- cgit v1.2.3