author | Adrian Kuegel <akuegel@google.com> | 2018-07-31 05:37:41 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-31 05:42:10 -0700 |
commit | 2826d123a017bc5f1a2cc7b969e275c5a63c326c (patch) | |
tree | 8204b3e1a0103a7a183a28e1976fc68491a8634b /tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc | |
parent | 3bec2640dcbd251f4eb2517d9ae7d8909886375d (diff) |
Allow Sort to share the buffer with the operand if it is the only user.
The BitonicSort algorithm works in-place, so we can make use of that.
On GPU, so far we copied the operand to the output and then performed the algorithm in-place.
Now, we may not need to do this anymore if we see that the buffer is shared.
Also, we now only need device-to-device copies in case the buffer is not shared,
because constants are now also assigned a buffer.
PiperOrigin-RevId: 206745686
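For readers unfamiliar with why an in-place sort makes buffer sharing possible, here is a minimal CPU-side sketch of a bitonic sorting network. It is not the XLA GPU implementation; the function name `BitonicSortInPlace`, the use of `int` keys, and the power-of-two size restriction are assumptions made for illustration. The point is that every step is a compare-and-swap inside the same array, so no second buffer is needed.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Illustrative only: sorts `data` ascending using the iterative bitonic
// sorting network. Assumption: data.size() is zero or a power of two.
void BitonicSortInPlace(std::vector<int>& data) {
  const std::size_t n = data.size();
  assert((n & (n - 1)) == 0 && "size must be a power of two");
  // k is the size of the bitonic sequences being merged in this stage.
  for (std::size_t k = 2; k <= n; k <<= 1) {
    // j is the distance between the two elements of each compare-and-swap.
    for (std::size_t j = k >> 1; j > 0; j >>= 1) {
      for (std::size_t i = 0; i < n; ++i) {
        const std::size_t partner = i ^ j;
        if (partner <= i) continue;  // Handle each pair exactly once.
        // Blocks of size k alternate between ascending and descending order.
        const bool ascending = (i & k) == 0;
        const bool out_of_order = ascending ? data[i] > data[partner]
                                            : data[i] < data[partner];
        if (out_of_order) {
          std::swap(data[i], data[partner]);  // Swap within the same buffer.
        }
      }
    }
  }
}
```

Because the network only ever swaps elements of `data`, the output can live in the operand's buffer whenever nothing else still needs the unsorted values.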
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc | 15 |
1 files changed, 15 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
index 1abfcb7703..bbfb0c253f 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
@@ -1084,6 +1084,21 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
     std::vector<int64> operand_indices = user->OperandIndices(operand);
     return operand_indices.size() == 1 && operand_indices[0] == 0;
   }
+  if (user->opcode() == HloOpcode::kSort) {
+    // Only valid if there are no other users.
+    if (operand->users().size() != 1) {
+      return false;
+    }
+    // If we only sort keys, the output of sort is not a tuple, so we can always
+    // share the buffer.
+    if (user->operand_count() == 1) {
+      return true;
+    }
+    CHECK(!user_index.empty());
+    // Only share with the right tuple element buffer.
+    std::vector<int64> operand_indices = user->OperandIndices(operand);
+    return operand_indices.size() == 1 && user_index[0] == operand_indices[0];
+  }
   if (user->opcode() == HloOpcode::kCall) {
     // Get all uses of value defined by 'operand' at 'operand_index'.
     const auto& uses = GetValueDefinedAt(operand, operand_index).uses();
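The decision made by the new `kSort` branch can be modeled in isolation. The sketch below is a standalone approximation, not XLA code: the `SortUse` struct and the `CanShareSortBuffer` function are hypothetical stand-ins for the values the real code reads from `operand`, `user`, and `user_index`, and a plain `assert` takes the place of `CHECK`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical snapshot of the facts the kSort branch inspects.
struct SortUse {
  std::size_t operand_user_count;             // operand->users().size()
  std::size_t sort_operand_count;             // user->operand_count()
  std::vector<std::int64_t> operand_indices;  // user->OperandIndices(operand)
  std::vector<std::int64_t> user_index;       // index into the sort's output
};

// Mirrors the control flow of the added branch under the assumptions above.
bool CanShareSortBuffer(const SortUse& use) {
  // Sharing is only valid if the sort is the operand's sole user.
  if (use.operand_user_count != 1) {
    return false;
  }
  // Keys-only sort: the output is not a tuple, so sharing is always fine.
  if (use.sort_operand_count == 1) {
    return true;
  }
  // Key/value sort: the output is a tuple, so an element index is required.
  assert(!use.user_index.empty());
  // Share only with the tuple element that corresponds to this operand.
  return use.operand_indices.size() == 1 &&
         use.user_index[0] == use.operand_indices[0];
}
```

For example, a keys-only sort whose operand has a single user yields `true`, while a key/value sort in which the values operand (index 1) is matched against output element 0 yields `false`, because in-place sorting would write that operand's data into the wrong tuple element.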