aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
diff options
context:
space:
mode:
authorGravatar Adrian Kuegel <akuegel@google.com>2018-07-31 05:37:41 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-31 05:42:10 -0700
commit2826d123a017bc5f1a2cc7b969e275c5a63c326c (patch)
tree8204b3e1a0103a7a183a28e1976fc68491a8634b /tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
parent3bec2640dcbd251f4eb2517d9ae7d8909886375d (diff)
Allow Sort to share the buffer with the operand if it is the only user.
The BitonicSort algorithm works in-place, so we can make use of that. On GPU, so far we copied the operand to the output and then performed the algorithm in-place. Now, we may not need to do this anymore if we see that the buffer is shared. Also, we now only need device-to-device copies in case the buffer is not shared, because constants are now also assigned a buffer. PiperOrigin-RevId: 206745686
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc15
1 files changed, 15 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
index 1abfcb7703..bbfb0c253f 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
@@ -1084,6 +1084,21 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
std::vector<int64> operand_indices = user->OperandIndices(operand);
return operand_indices.size() == 1 && operand_indices[0] == 0;
}
+ if (user->opcode() == HloOpcode::kSort) {
+ // Only valid if there are no other users.
+ if (operand->users().size() != 1) {
+ return false;
+ }
+ // If we only sort keys, the output of sort is not a tuple, so we can always
+ // share the buffer.
+ if (user->operand_count() == 1) {
+ return true;
+ }
+ CHECK(!user_index.empty());
+ // Only share with the right tuple element buffer.
+ std::vector<int64> operand_indices = user->OperandIndices(operand);
+ return operand_indices.size() == 1 && user_index[0] == operand_indices[0];
+ }
if (user->opcode() == HloOpcode::kCall) {
// Get all uses of value defined by 'operand' at 'operand_index'.
const auto& uses = GetValueDefinedAt(operand, operand_index).uses();