diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/sort_ops.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/sort_ops.cc | 17 |
1 files changed, 16 insertions, 1 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc index aaeeae01cc..45f03d8c21 100644 --- a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc @@ -25,11 +25,26 @@ class XlaSortOp : public XlaOpKernel { explicit XlaSortOp(OpKernelConstruction* context) : XlaOpKernel(context) {} void Compile(XlaOpKernelContext* context) override { - context->SetOutput(0, xla::Sort(context->Input(0))); + context->SetOutput(0, xla::Sort(context->Input("input"))); } }; REGISTER_XLA_OP(Name("XlaSort"), XlaSortOp); +class XlaKeyValueSortOp : public XlaOpKernel { + public: + explicit XlaKeyValueSortOp(OpKernelConstruction* context) + : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + xla::XlaOp result = + xla::Sort(context->Input("keys"), context->Input("values")); + context->SetOutput(0, xla::GetTupleElement(result, 0)); + context->SetOutput(1, xla::GetTupleElement(result, 1)); + } +}; + +REGISTER_XLA_OP(Name("XlaKeyValueSort"), XlaKeyValueSortOp); + } // namespace } // namespace tensorflow |