aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/kernels/sort_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/sort_ops.cc')
-rw-r--r--tensorflow/compiler/tf2xla/kernels/sort_ops.cc17
1 files changed, 16 insertions, 1 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc
index aaeeae01cc..45f03d8c21 100644
--- a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc
@@ -25,11 +25,26 @@ class XlaSortOp : public XlaOpKernel {
explicit XlaSortOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
void Compile(XlaOpKernelContext* context) override {
- context->SetOutput(0, xla::Sort(context->Input(0)));
+ context->SetOutput(0, xla::Sort(context->Input("input")));
}
};
REGISTER_XLA_OP(Name("XlaSort"), XlaSortOp);
+class XlaKeyValueSortOp : public XlaOpKernel {
+ public:
+ explicit XlaKeyValueSortOp(OpKernelConstruction* context)
+ : XlaOpKernel(context) {}
+
+ void Compile(XlaOpKernelContext* context) override {
+ xla::XlaOp result =
+ xla::Sort(context->Input("keys"), context->Input("values"));
+ context->SetOutput(0, xla::GetTupleElement(result, 0));
+ context->SetOutput(1, xla::GetTupleElement(result, 1));
+ }
+};
+
+REGISTER_XLA_OP(Name("XlaKeyValueSort"), XlaKeyValueSortOp);
+
} // namespace
} // namespace tensorflow