diff options
author | 2018-06-27 15:05:11 -0700 | |
---|---|---|
committer | 2018-06-27 15:08:36 -0700 | |
commit | 394add116efd9839e1be5342c085e6510c265687 (patch) | |
tree | 3b208b40424bec1479fa9fb8647f16c48c1969c1 /tensorflow/compiler/xla/service/bfloat16_propagation.cc | |
parent | eb15c736ce07a92b02ba579e83733b909020828b (diff) |
[XLA] Use subshape pointers as map keys in BFloat16Propagation.
Using simple keys is more efficient.
PiperOrigin-RevId: 202377039
Diffstat (limited to 'tensorflow/compiler/xla/service/bfloat16_propagation.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/bfloat16_propagation.cc | 21 |
1 files changed, 11 insertions, 10 deletions
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc index ee6b6f69b9..ff6d5027ef 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc @@ -85,9 +85,9 @@ void BFloat16Propagation::RevertIfFusionInternalBF16Changes( auto root_changes_it = changes_to_bf16_.find(root); if (root_changes_it != changes_to_bf16_.end()) { - for (const auto& index : root_changes_it->second) { + for (const auto& entry : root_changes_it->second) { for (const HloValue* value : - dataflow_->GetValueSet(root, index).values()) { + dataflow_->GetValueSet(root, entry.second).values()) { changed_root_buffers.insert(value); } } @@ -802,9 +802,8 @@ StatusOr<bool> BFloat16Propagation::Run(HloModule* module) { // Apply the changes in changes_to_bf16_. for (auto& change : changes_to_bf16_) { - auto shape = change.first->mutable_shape(); - for (const auto& index : change.second) { - auto subshape = ShapeUtil::GetMutableSubshape(shape, index); + for (const auto& entry : change.second) { + auto subshape = entry.first; CHECK_EQ(subshape->element_type(), F32); subshape->set_element_type(BF16); changed_ = true; @@ -833,8 +832,8 @@ StatusOr<bool> BFloat16Propagation::Run(HloModule* module) { PrimitiveType BFloat16Propagation::OutputTypeAfterChange( HloInstruction* hlo, const ShapeIndex& index) const { - PrimitiveType type_on_hlo = - ShapeUtil::GetSubshape(hlo->shape(), index).element_type(); + Shape* subshape = ShapeUtil::GetMutableSubshape(hlo->mutable_shape(), index); + const PrimitiveType type_on_hlo = subshape->element_type(); if (type_on_hlo != F32) { return type_on_hlo; } @@ -842,7 +841,7 @@ PrimitiveType BFloat16Propagation::OutputTypeAfterChange( if (it == changes_to_bf16_.end()) { return type_on_hlo; } - return ContainsKey(it->second, index) ? BF16 : F32; + return ContainsKey(it->second, subshape) ? BF16 : F32; } PrimitiveType BFloat16Propagation::ValueTypeAfterChange( @@ -856,14 +855,16 @@ void BFloat16Propagation::AddToOrRemoveFromBF16ChangeSet( HloInstruction* hlo, const ShapeIndex& index, PrimitiveType target_type) { if (target_type == BF16) { auto& entry = changes_to_bf16_[hlo]; - entry.insert(index); + entry.emplace(ShapeUtil::GetMutableSubshape(hlo->mutable_shape(), index), + index); } else { CHECK_EQ(target_type, F32); auto it = changes_to_bf16_.find(hlo); if (it == changes_to_bf16_.end()) { return; } - it->second.erase(index); + it->second.erase( + ShapeUtil::GetMutableSubshape(hlo->mutable_shape(), index)); } } |