aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-10-05 09:29:00 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-05 09:32:36 -0700
commit5a43e01ef0f8cb86d836a4d1c08a246630e26f8c (patch)
tree20af6c3c98b4527e7c9d38909b900b86cf395e52 /tensorflow/compiler
parentd258207f1583df4faa452265b051879af6c15dac (diff)
Update XlaSort to match the underlying HLO.
PiperOrigin-RevId: 215917470
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--tensorflow/compiler/tests/sort_ops_test.py18
-rw-r--r--tensorflow/compiler/tf2xla/kernels/sort_ops.cc17
-rw-r--r--tensorflow/compiler/tf2xla/ops/xla_ops.cc23
-rw-r--r--tensorflow/compiler/tf2xla/python/xla.py12
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.cc2
5 files changed, 63 insertions, 9 deletions
diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py
index dbf4beb693..57f0ab7a9e 100644
--- a/tensorflow/compiler/tests/sort_ops_test.py
+++ b/tensorflow/compiler/tests/sort_ops_test.py
@@ -48,13 +48,29 @@ class XlaSortOpTest(xla_test.XLATestCase):
self.assertAllClose(v, result, rtol=1e-3)
def testSort(self):
- supported_types = set([dtypes.bfloat16.as_numpy_dtype, np.float32])
+ supported_types = set(
+ [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32])
for dtype in supported_types.intersection(self.numeric_types):
x = np.arange(101, dtype=dtype)
np.random.shuffle(x)
self._assertOpOutputMatchesExpected(
xla.sort, [x], expected=[np.arange(101, dtype=dtype)])
+ def testKeyValueSort(self):
+ supported_types = set(
+ [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32])
+ for key_type in supported_types.intersection(self.numeric_types):
+ for value_type in supported_types.intersection(self.numeric_types):
+ x = np.arange(101, dtype=key_type)
+ np.random.shuffle(x)
+ y = (-x).astype(value_type)
+ self._assertOpOutputMatchesExpected(
+ xla.key_value_sort, [x, y],
+ expected=[
+ np.arange(101, dtype=key_type),
+ -np.arange(101, dtype=value_type)
+ ])
+
def testTopK(self):
supported_types = set(
[dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32])
diff --git a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc
index aaeeae01cc..45f03d8c21 100644
--- a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc
@@ -25,11 +25,26 @@ class XlaSortOp : public XlaOpKernel {
explicit XlaSortOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
void Compile(XlaOpKernelContext* context) override {
- context->SetOutput(0, xla::Sort(context->Input(0)));
+ context->SetOutput(0, xla::Sort(context->Input("input")));
}
};
REGISTER_XLA_OP(Name("XlaSort"), XlaSortOp);
+class XlaKeyValueSortOp : public XlaOpKernel {
+ public:
+ explicit XlaKeyValueSortOp(OpKernelConstruction* context)
+ : XlaOpKernel(context) {}
+
+ void Compile(XlaOpKernelContext* context) override {
+ xla::XlaOp result =
+ xla::Sort(context->Input("keys"), context->Input("values"));
+ context->SetOutput(0, xla::GetTupleElement(result, 0));
+ context->SetOutput(1, xla::GetTupleElement(result, 1));
+ }
+};
+
+REGISTER_XLA_OP(Name("XlaKeyValueSort"), XlaKeyValueSortOp);
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
index 733eeed3c6..557911553d 100644
--- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc
+++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
@@ -354,12 +354,33 @@ Wraps the XLA Sort operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#sort
.
-Sorts a tensor. Currently only rank 1 sorts in ascending order are supported.
+Sorts a tensor. Currently only sorts in ascending order are supported.
input: A `Tensor` of type T.
output: A `Tensor` of type T.
)doc");
+REGISTER_OP("XlaKeyValueSort")
+ .Input("keys: K")
+ .Input("values: V")
+ .Output("sorted_keys: K")
+ .Output("sorted_values: V")
+ .Attr("K: realnumbertype")
+ .Attr("V: type")
+ .SetShapeFn(shape_inference::UnchangedShape)
+ .Doc(R"doc(
+Wraps the XLA Sort operator, documented at
+ https://www.tensorflow.org/performance/xla/operation_semantics#sort
+.
+
+Sorts a tensor. Currently only sorts in ascending order are supported.
+
+keys: A `Tensor` of type K.
+values: A `Tensor` of type V.
+sorted_keys: A `Tensor` of type K.
+sorted_values: A `Tensor` of type V.
+)doc");
+
// TODO(b/37549631) setting the While Op to always be stateful is too
// conservative.
REGISTER_OP("XlaWhile")
diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py
index 27dd18a9bb..bc7924c371 100644
--- a/tensorflow/compiler/tf2xla/python/xla.py
+++ b/tensorflow/compiler/tf2xla/python/xla.py
@@ -212,9 +212,9 @@ bitcast_convert_type = array_ops.bitcast
def broadcast(x, dims, name=None):
x = ops.convert_to_tensor(x)
- shape = array_ops.concat(
- [constant_op.constant(dims),
- array_ops.shape(x)], axis=0)
+ shape = array_ops.concat([constant_op.constant(dims),
+ array_ops.shape(x)],
+ axis=0)
return array_ops.broadcast_to(x, shape, name=name)
@@ -332,12 +332,13 @@ def reduce_window(operand,
init: a scalar tensor representing the initial value for the reduction
reducer: a reduction function that combines a pair of scalars.
window_dimensions: shape of the window, as a list of integers
- window_strides: inter-window strides, as a list of integers. Optional;
- if omitted, defaults to strides of 1.
+ window_strides: inter-window strides, as a list of integers. Optional; if
+ omitted, defaults to strides of 1.
padding: padding to apply to 'operand'. List of (low, high) pairs of
integers that specify the padding to apply before and after each
dimension. Optional; if omitted, defaults to no padding.
name: the operator name, or None.
+
Returns:
A tensor that represents the output of the reduce_window operator.
"""
@@ -377,4 +378,5 @@ def slice(x, start_dims, limit_dims, strides):
sort = gen_xla_ops.xla_sort
+key_value_sort = gen_xla_ops.xla_key_value_sort
while_loop = gen_xla_ops.xla_while
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index b5498bb936..c22ee03388 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -548,6 +548,7 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) {
case HloOpcode::kTupleSelect:
case HloOpcode::kSend:
case HloOpcode::kSendDone:
+ case HloOpcode::kSort:
case HloOpcode::kTuple:
case HloOpcode::kWhile:
break;
@@ -1153,7 +1154,6 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
TF_RETURN_IF_ERROR(VerifyHloStructure(module));
TF_RETURN_IF_ERROR(VerifySendsAndRecvs(*module));
-
for (auto* computation : module->computations()) {
std::unique_ptr<ShapeVerifier> shape_verifier = shape_verifier_factory_();
TF_RETURN_IF_ERROR(computation->Accept(shape_verifier.get()));