diff options
author | Mark Heffernan <meheff@google.com> | 2017-09-26 15:50:19 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-09-26 15:54:45 -0700 |
commit | 079061306d4f58295e48b452818875c6a9bdbfaa (patch) | |
tree | c7e66135710fb9311a93f25048deb8e3fa7bb38f /tensorflow/compiler/xla/service/tuple_simplifier.cc | |
parent | 725206e677a9f1e343319293a347862335ff776b (diff) |
Add TupleSimplifier pass which collapses structures of Tuple and GetTupleElement instructions.
PiperOrigin-RevId: 170122192
Diffstat (limited to 'tensorflow/compiler/xla/service/tuple_simplifier.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/tuple_simplifier.cc | 126 |
1 file changed, 126 insertions(+), 0 deletions(-)
diff --git a/tensorflow/compiler/xla/service/tuple_simplifier.cc b/tensorflow/compiler/xla/service/tuple_simplifier.cc new file mode 100644 index 0000000000..f92116ec19 --- /dev/null +++ b/tensorflow/compiler/xla/service/tuple_simplifier.cc @@ -0,0 +1,126 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/tuple_simplifier.h" + +#include <queue> + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +StatusOr<bool> TupleSimplifier::Run(HloModule* module) { + // Initially add all GTE and Tuple instructions to the worklist. 
+ std::queue<HloInstruction*> worklist; + for (auto& computation : module->computations()) { + for (auto& instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kTuple || + instruction->opcode() == HloOpcode::kGetTupleElement) { + worklist.push(instruction.get()); + } + } + } + + bool changed = false; + while (!worklist.empty()) { + HloInstruction* instruction = worklist.front(); + worklist.pop(); + + if (instruction->user_count() == 0 && + instruction != instruction->parent()->root_instruction()) { + // Tuple simplification works by replacing users of optimized away + // instructions with a simpler form. If there is no user of the + // instruction (including being the root), then there is nothing to do. + continue; + } + + if (instruction->opcode() == HloOpcode::kTuple) { + // Collapse the following structure into just 'Tuple-shaped Op': + // + // Tuple-shaped Op + // | + // +-----+-----+ + // | | | + // GTE GTE GTE + // | | | + // +-----+-----+ + // | + // Tuple + // + HloInstruction* top_tuple = nullptr; + bool can_simplify = true; + for (int64 operand_number = 0; + operand_number < instruction->operand_count(); ++operand_number) { + HloInstruction* operand = instruction->mutable_operand(operand_number); + if (operand->opcode() != HloOpcode::kGetTupleElement || + operand->tuple_index() != operand_number) { + can_simplify = false; + break; + } + + if (top_tuple == nullptr) { + top_tuple = operand->mutable_operand(0); + } else if (top_tuple != operand->operand(0)) { + can_simplify = false; + break; + } + } + if (can_simplify && top_tuple != nullptr) { + changed = true; + TF_RETURN_IF_ERROR(instruction->parent()->ReplaceUsesOfInstruction( + instruction, top_tuple)); + // No need to add anything to the worklist. + } + } else { + CHECK_EQ(instruction->opcode(), HloOpcode::kGetTupleElement); + // If possible replace a GTE with the operation which produces the + // element. 
For example, replace uses of GTE with below with just 'Op' + // (assuming 'Op' is at the index of the GTE instruction): + // + // ... Op ... + // \ | / + // Tuple + // | + // GTE + if (instruction->operand(0)->opcode() == HloOpcode::kTuple) { + changed = true; + HloInstruction* element_source = + instruction->mutable_operand(0)->mutable_operand( + instruction->tuple_index()); + TF_RETURN_IF_ERROR(instruction->parent()->ReplaceUsesOfInstruction( + instruction, element_source)); + for (HloInstruction* user : element_source->users()) { + if (user->opcode() == HloOpcode::kTuple || + user->opcode() == HloOpcode::kGetTupleElement) { + worklist.push(user); + } + } + } + } + } + + return changed; +} + +} // namespace xla |