diff options
author | 2018-06-25 17:27:30 -0700 | |
---|---|---|
committer | 2018-06-25 17:30:25 -0700 | |
commit | 89013c6f76568736cd6d8395f73db53045303412 (patch) | |
tree | aa16b75243df34b120eb0799d8c7534b494fcfad /tensorflow/compiler/xla/service/hlo_instructions.cc | |
parent | ee8703f342269dca881c17c6db3177355fcd18c7 (diff) |
[XLA] Prevent fusion nodes from having duplicate operands when replacing uses.
PiperOrigin-RevId: 202049336
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instructions.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instructions.cc | 21 |
1 file changed, 21 insertions, 0 deletions
```diff
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index a015d791ce..e2f43f5810 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/window_util.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
 
 namespace xla {
 namespace {
@@ -1208,6 +1209,26 @@ std::unique_ptr<HloInstruction> HloFusionInstruction::CloneWithNewOperandsImpl(
       new_fused_computation);
 }
 
+Status HloFusionInstruction::DeduplicateFusionOperands() {
+  tensorflow::gtl::FlatMap<const HloInstruction*, int> operand_indices;
+  std::vector<int> operands_to_remove;
+  for (int i = 0; i < operand_count(); ++i) {
+    auto emplace_result = operand_indices.emplace(operand(i), i);
+    if (!emplace_result.second) {
+      TF_RETURN_IF_ERROR(fused_parameter(i)->ReplaceAllUsesWith(
+          fused_parameter(emplace_result.first->second)));
+      operands_to_remove.push_back(i);
+    }
+  }
+  if (operands_to_remove.empty()) {
+    return Status::OK();
+  }
+  TF_RETURN_IF_ERROR(
+      fused_instructions_computation()->RemoveUnusedParameters());
+  RemoveOperandsAtAscendingIndices(operands_to_remove);
+  return Status::OK();
+}
+
 HloRngInstruction::HloRngInstruction(
     const Shape& shape, RandomDistribution distribution,
     tensorflow::gtl::ArraySlice<HloInstruction*> parameters)
```