aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_instructions.cc
diff options
context:
space:
mode:
authorGravatar Yuanzhong Xu <yuanzx@google.com>2018-06-25 17:27:30 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-25 17:30:25 -0700
commit89013c6f76568736cd6d8395f73db53045303412 (patch)
treeaa16b75243df34b120eb0799d8c7534b494fcfad /tensorflow/compiler/xla/service/hlo_instructions.cc
parentee8703f342269dca881c17c6db3177355fcd18c7 (diff)
[XLA] Avoid fusion nodes to have duplicate operands during replacing uses.
PiperOrigin-RevId: 202049336
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instructions.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.cc21
1 files changed, 21 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index a015d791ce..e2f43f5810 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/window_util.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
namespace xla {
namespace {
@@ -1208,6 +1209,26 @@ std::unique_ptr<HloInstruction> HloFusionInstruction::CloneWithNewOperandsImpl(
new_fused_computation);
}
+Status HloFusionInstruction::DeduplicateFusionOperands() {
+ tensorflow::gtl::FlatMap<const HloInstruction*, int> operand_indices;
+ std::vector<int> operands_to_remove;
+ for (int i = 0; i < operand_count(); ++i) {
+ auto emplace_result = operand_indices.emplace(operand(i), i);
+ if (!emplace_result.second) {
+ TF_RETURN_IF_ERROR(fused_parameter(i)->ReplaceAllUsesWith(
+ fused_parameter(emplace_result.first->second)));
+ operands_to_remove.push_back(i);
+ }
+ }
+ if (operands_to_remove.empty()) {
+ return Status::OK();
+ }
+ TF_RETURN_IF_ERROR(
+ fused_instructions_computation()->RemoveUnusedParameters());
+ RemoveOperandsAtAscendingIndices(operands_to_remove);
+ return Status::OK();
+}
+
HloRngInstruction::HloRngInstruction(
const Shape& shape, RandomDistribution distribution,
tensorflow::gtl::ArraySlice<HloInstruction*> parameters)