diff options
author | 2018-08-20 09:07:12 -0700 | |
---|---|---|
committer | 2018-08-20 09:11:11 -0700 | |
commit | b7127df9da79b8c3c017f5de1b6f571eb3ff487b (patch) | |
tree | fb4aac4af9bede8628352cfa7cd398df9c13fbf9 | |
parent | c7d36f4fb2e074f76c7a5869d96b9067bead6909 (diff) |
[XLA] add SortKeyVal to the local Python client.
This operation corresponds to the version of the Sort HLO with three arguments, but we give it a separate name (SortKeyVal instead of Sort) for compatibility with SWIG.
PiperOrigin-RevId: 209427551
4 files changed, 24 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index 8246f76d34..212439dec8 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -575,6 +575,16 @@ StatusOr<bool> LocalComputationBuilder::IsConstant(const LocalOp& operand) { return builder_.IsConstant(operand.op()); } +LocalOp LocalComputationBuilder::Sort(const LocalOp& operand, int64 dimension) { + return xla::Sort(operand.op(), tensorflow::gtl::nullopt, dimension); +} + +LocalOp LocalComputationBuilder::SortKeyVal(const LocalOp& keys, + const LocalOp& values, + int64 dimension) { + return xla::Sort(keys.op(), values.op(), dimension); +} + StatusOr<LocalComputation*> LocalComputationBuilder::BuildConstantSubGraph( const LocalOp& operand) { TF_ASSIGN_OR_RETURN(XlaComputation computation, @@ -640,7 +650,6 @@ _FORWARD_UNOP(Sin) _FORWARD_UNOP(Tanh) _FORWARD_UNOP(IsFinite) _FORWARD_UNOP(Neg) -_FORWARD_UNOP(Sort) _FORWARD_UNOP(Sqrt) _FORWARD_UNOP(Rsqrt) _FORWARD_UNOP(Square) diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index a568c24c63..5f9078ab84 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -301,6 +301,11 @@ class LocalComputationBuilder { StatusOr<bool> IsConstant(const LocalOp& operand); + LocalOp Sort(const LocalOp& operand, int64 dimension); + + LocalOp SortKeyVal(const LocalOp& keys, const LocalOp& values, + int64 dimension); + StatusOr<LocalComputation*> BuildConstantSubGraph(const LocalOp& operand); #define _FORWARD(method_name, return_sig, args_sig) \ @@ -357,7 +362,6 @@ class LocalComputationBuilder { _FORWARD_UNOP(Tanh) _FORWARD_UNOP(IsFinite) _FORWARD_UNOP(Neg) - _FORWARD_UNOP(Sort) _FORWARD_UNOP(Sqrt) _FORWARD_UNOP(Rsqrt) _FORWARD_UNOP(Square) diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index 5d5a955bfe..fa5d75908f 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -1011,6 +1011,7 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::Pow; %unignore xla::swig::LocalComputationBuilder::Neg; %unignore xla::swig::LocalComputationBuilder::Sort; +%unignore xla::swig::LocalComputationBuilder::SortKeyVal; %unignore xla::swig::LocalComputationBuilder::Sqrt; %unignore xla::swig::LocalComputationBuilder::Rsqrt; %unignore xla::swig::LocalComputationBuilder::Square; diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index a2c6fc344d..fa4366ff07 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -105,7 +105,6 @@ _UNARY_OPS = [ 'Square', 'Reciprocal', 'Neg', - 'Sort', 'Erf', 'Erfc', 'ErfInv', @@ -1218,6 +1217,14 @@ class ComputationBuilder(object): lhs_dilation, rhs_dilation, dimension_numbers) + def Sort(self, operand, dimension=-1): + """Enqueues a sort operation onto the computation.""" + return self._client.Sort(operand, dimension) + + def SortKeyVal(self, keys, values, dimension=-1): + """Enqueues a key-value sort operation onto the computation.""" + return self._client.SortKeyVal(keys, values, dimension) + def _forward_methods_to_local_builder(): """Forward remaining ComputationBuilder methods to the C API. |