aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-20 09:07:12 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-20 09:11:11 -0700
commitb7127df9da79b8c3c017f5de1b6f571eb3ff487b (patch)
treefb4aac4af9bede8628352cfa7cd398df9c13fbf9
parentc7d36f4fb2e074f76c7a5869d96b9067bead6909 (diff)
[XLA] add SortKeyVal to the local Python client.
This operation corresponds to the version of the Sort HLO with three arguments, but we give it a separate name (SortKeyVal instead of Sort) for compatibility with SWIG. PiperOrigin-RevId: 209427551
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.cc11
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.h6
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.i1
-rw-r--r--tensorflow/compiler/xla/python/xla_client.py9
4 files changed, 24 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc
index 8246f76d34..212439dec8 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.cc
+++ b/tensorflow/compiler/xla/python/local_computation_builder.cc
@@ -575,6 +575,16 @@ StatusOr<bool> LocalComputationBuilder::IsConstant(const LocalOp& operand) {
return builder_.IsConstant(operand.op());
}
+LocalOp LocalComputationBuilder::Sort(const LocalOp& operand, int64 dimension) {
+ return xla::Sort(operand.op(), tensorflow::gtl::nullopt, dimension);
+}
+
+LocalOp LocalComputationBuilder::SortKeyVal(const LocalOp& keys,
+ const LocalOp& values,
+ int64 dimension) {
+ return xla::Sort(keys.op(), values.op(), dimension);
+}
+
StatusOr<LocalComputation*> LocalComputationBuilder::BuildConstantSubGraph(
const LocalOp& operand) {
TF_ASSIGN_OR_RETURN(XlaComputation computation,
@@ -640,7 +650,6 @@ _FORWARD_UNOP(Sin)
_FORWARD_UNOP(Tanh)
_FORWARD_UNOP(IsFinite)
_FORWARD_UNOP(Neg)
-_FORWARD_UNOP(Sort)
_FORWARD_UNOP(Sqrt)
_FORWARD_UNOP(Rsqrt)
_FORWARD_UNOP(Square)
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h
index a568c24c63..5f9078ab84 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.h
+++ b/tensorflow/compiler/xla/python/local_computation_builder.h
@@ -301,6 +301,11 @@ class LocalComputationBuilder {
StatusOr<bool> IsConstant(const LocalOp& operand);
+ LocalOp Sort(const LocalOp& operand, int64 dimension);
+
+ LocalOp SortKeyVal(const LocalOp& keys, const LocalOp& values,
+ int64 dimension);
+
StatusOr<LocalComputation*> BuildConstantSubGraph(const LocalOp& operand);
#define _FORWARD(method_name, return_sig, args_sig) \
@@ -357,7 +362,6 @@ class LocalComputationBuilder {
_FORWARD_UNOP(Tanh)
_FORWARD_UNOP(IsFinite)
_FORWARD_UNOP(Neg)
- _FORWARD_UNOP(Sort)
_FORWARD_UNOP(Sqrt)
_FORWARD_UNOP(Rsqrt)
_FORWARD_UNOP(Square)
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i
index 5d5a955bfe..fa5d75908f 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.i
+++ b/tensorflow/compiler/xla/python/local_computation_builder.i
@@ -1011,6 +1011,7 @@ tensorflow::ImportNumpy();
%unignore xla::swig::LocalComputationBuilder::Pow;
%unignore xla::swig::LocalComputationBuilder::Neg;
%unignore xla::swig::LocalComputationBuilder::Sort;
+%unignore xla::swig::LocalComputationBuilder::SortKeyVal;
%unignore xla::swig::LocalComputationBuilder::Sqrt;
%unignore xla::swig::LocalComputationBuilder::Rsqrt;
%unignore xla::swig::LocalComputationBuilder::Square;
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py
index a2c6fc344d..fa4366ff07 100644
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -105,7 +105,6 @@ _UNARY_OPS = [
'Square',
'Reciprocal',
'Neg',
- 'Sort',
'Erf',
'Erfc',
'ErfInv',
@@ -1218,6 +1217,14 @@ class ComputationBuilder(object):
lhs_dilation, rhs_dilation,
dimension_numbers)
+ def Sort(self, operand, dimension=-1):
+ """Enqueues a sort operation onto the computation."""
+ return self._client.Sort(operand, dimension)
+
+ def SortKeyVal(self, keys, values, dimension=-1):
+ """Enqueues a key-value sort operation onto the computation."""
+ return self._client.SortKeyVal(keys, values, dimension)
+
def _forward_methods_to_local_builder():
"""Forward remaining ComputationBuilder methods to the C API.