diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-03-14 17:56:30 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-03-14 19:11:27 -0700 |
commit | e0d0c676ec111c711099bf89eb51278bc4493678 (patch) | |
tree | c15420e4b83c79f620d6b9f9c35bce9d3305e16a /tensorflow/compiler/xla/service/buffer_liveness.cc | |
parent | 830cde8776d9adb6bdbb2e0b3173d16780d52df7 (diff) |
Refactor logic from buffer_liveness to use in HeapSimulator.
Also added some simple tests.
Change: 150144113
Diffstat (limited to 'tensorflow/compiler/xla/service/buffer_liveness.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/buffer_liveness.cc | 128 |
1 files changed, 3 insertions, 125 deletions
diff --git a/tensorflow/compiler/xla/service/buffer_liveness.cc b/tensorflow/compiler/xla/service/buffer_liveness.cc index b5a2936b67..0fe6e37c00 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness.cc @@ -17,11 +17,11 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_liveness.h" -#include <set> #include <utility> #include <vector> #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/liveness_util.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -92,128 +92,6 @@ string BufferLiveness::ToString() const { return tensorflow::str_util::Join(pieces, "\n"); } -namespace { - -// Returns false if 'user' cannot possibly use the buffer at 'index' in -// 'operand'. Returns true otherwise. -// Precondition: 'operand' is an operand of 'user'. -bool MayUseBufferInOperand(HloInstruction* operand, const ShapeIndex& index, - HloInstruction* user, - const TuplePointsToAnalysis& points_to_analysis) { - if (user->opcode() == HloOpcode::kGetTupleElement && !index.empty()) { - // GetTupleElement instructions only access the top-level buffer of their - // operand. - return false; - } else if (user->opcode() == HloOpcode::kFusion && - user->fusion_kind() == HloInstruction::FusionKind::kLoop) { - // Find fusion parameter associated with 'operand'. - auto it = std::find_if( - user->fused_parameters().begin(), user->fused_parameters().end(), - [=](HloInstruction* fused_param) { - return user->operand(fused_param->parameter_number()) == operand; - }); - CHECK(it != user->fused_parameters().end()); - // Iterate through all users of all buffer aliases of the buffer in the - // points-to set of fusion parameter at 'index'. - // Return true if any uses are detected at 'index', returns false otherwise. - const LogicalBuffer* buffer = - points_to_analysis.GetBufferDefinedAt(*it, index).ValueOrDie(); - for (const BufferAlias& alias : - points_to_analysis.GetBufferAliases(*buffer)) { - for (HloInstruction* alias_user : alias.instruction()->users()) { - if (!MayUseBufferInOperand(alias.instruction(), alias.index(), - alias_user, points_to_analysis)) { - continue; - } - // Return true: use detected at 'buffer' -> 'alias' -> 'alias_user'. - return true; - } - } - // Return false: found no uses of 'operand' at 'index' in 'user'. - return false; - } - return true; -} - -// Returns all uses of all aliases of 'instruction' at 'index' in 'uses'. -// Each use in 'uses' is a pair (HloInstruction* user, int64 operand_index) -// where 'user' is a user of an alias of 'intruction' at 'index', and -// 'operand_index' is the operand index at which the alias appears in the -// operand list of 'user'. -std::vector<std::pair<HloInstruction*, int64>> GetAllUsesOfInstructionAtIndex( - HloInstruction* instruction, const ShapeIndex& index, - const TuplePointsToAnalysis& points_to_analysis) { - std::vector<std::pair<HloInstruction*, int64>> uses; - const std::vector<const LogicalBuffer*>& points_to = - points_to_analysis.GetPointsToSet(instruction).element(index); - for (const LogicalBuffer* buffer : points_to) { - for (const BufferAlias& alias : - points_to_analysis.GetBufferAliases(*buffer)) { - for (HloInstruction* alias_user : alias.instruction()->users()) { - if (!MayUseBufferInOperand(alias.instruction(), alias.index(), - alias_user, points_to_analysis)) { - continue; - } - for (int64 op_idx : alias_user->OperandIndices(alias.instruction())) { - uses.emplace_back(alias_user, op_idx); - } - } - } - } - return uses; -} - -// Returns true if 'user' (at 'user_index') can share a buffer with its operand -// 'operand' (at 'operand_index'). -// Returns false otherwise. -// User and operand can share buffers iff both instructions emit the same shape -// and layout, and 'user' meets one of the following two qualifications: -// *) Is element-wise. -// *) Is a loop fusion instruction where the only use of 'operand' at 'index' -// in the set 'user.fused_instructions' is a DynamicUpdateSlice fused root -// at operand 0. -bool CanShareOperandBufferWithUser( - HloInstruction* operand, const ShapeIndex& operand_index, - HloInstruction* user, const ShapeIndex& user_index, - const TuplePointsToAnalysis& points_to_analysis) { - Shape operand_subshape = - ShapeUtil::GetSubshape(operand->shape(), operand_index); - Shape user_subshape = ShapeUtil::GetSubshape(user->shape(), user_index); - // Check that operand and user emit the same shape and layout. - if (!ShapeUtil::Equal(operand_subshape, user_subshape)) { - return false; - } - // Check if 'user' is a loop fusion instruction with a kDynamicUpdateSlice - // fused root instruction. - if (user->opcode() == HloOpcode::kFusion && - user->fusion_kind() == HloInstruction::FusionKind::kLoop && - user->fused_expression_root()->opcode() == - HloOpcode::kDynamicUpdateSlice) { - for (auto& fused_param : user->fused_parameters()) { - // Find fusion parameter associated with 'operand'. - if (user->operand(fused_param->parameter_number()) != operand) { - continue; - } - // Get all uses of 'operand' at 'index' from 'user.fused_instructions'. - auto fused_param_uses = GetAllUsesOfInstructionAtIndex( - fused_param, operand_index, points_to_analysis); - // Return true iff there is exactly one use of 'operand' at 'index', and - // this singleton use is the fused root at operand index 0. - if (fused_param_uses.size() == 1 && - fused_param_uses[0].first == user->fused_expression_root() && - fused_param_uses[0].second == 0) { - return true; - } - break; - } - return false; - } - // Check if 'user' is element-wise. - return user->IsElementwise(); -} - -} // anonymous namespace - bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a, const LogicalBuffer& b) const { TF_CHECK_OK(points_to_analysis_->VerifyBuffer(a)); @@ -226,8 +104,8 @@ bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a, // Every user of 'a' must be a predecessor of 'b' or 'b' itself. for (const BufferAlias& alias : points_to_analysis_->GetBufferAliases(a)) { for (auto user : alias.instruction()->users()) { - if (!MayUseBufferInOperand(alias.instruction(), alias.index(), user, - points_to_analysis())) { + if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(), user, + points_to_analysis())) { continue; } if (user != b.instruction() && |