aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/bfloat16_propagation.cc
diff options
context:
space:
mode:
authorGravatar Yuanzhong Xu <yuanzx@google.com>2018-06-27 15:05:11 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-27 15:08:36 -0700
commit394add116efd9839e1be5342c085e6510c265687 (patch)
tree3b208b40424bec1479fa9fb8647f16c48c1969c1 /tensorflow/compiler/xla/service/bfloat16_propagation.cc
parenteb15c736ce07a92b02ba579e83733b909020828b (diff)
[XLA] Use subshape pointers as map keys in BFloat16Propagation.
Using simple keys is more efficient. PiperOrigin-RevId: 202377039
Diffstat (limited to 'tensorflow/compiler/xla/service/bfloat16_propagation.cc')
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_propagation.cc21
1 files changed, 11 insertions, 10 deletions
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc
index ee6b6f69b9..ff6d5027ef 100644
--- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc
@@ -85,9 +85,9 @@ void BFloat16Propagation::RevertIfFusionInternalBF16Changes(
auto root_changes_it = changes_to_bf16_.find(root);
if (root_changes_it != changes_to_bf16_.end()) {
- for (const auto& index : root_changes_it->second) {
+ for (const auto& entry : root_changes_it->second) {
for (const HloValue* value :
- dataflow_->GetValueSet(root, index).values()) {
+ dataflow_->GetValueSet(root, entry.second).values()) {
changed_root_buffers.insert(value);
}
}
@@ -802,9 +802,8 @@ StatusOr<bool> BFloat16Propagation::Run(HloModule* module) {
// Apply the changes in changes_to_bf16_.
for (auto& change : changes_to_bf16_) {
- auto shape = change.first->mutable_shape();
- for (const auto& index : change.second) {
- auto subshape = ShapeUtil::GetMutableSubshape(shape, index);
+ for (const auto& entry : change.second) {
+ auto subshape = entry.first;
CHECK_EQ(subshape->element_type(), F32);
subshape->set_element_type(BF16);
changed_ = true;
@@ -833,8 +832,8 @@ StatusOr<bool> BFloat16Propagation::Run(HloModule* module) {
PrimitiveType BFloat16Propagation::OutputTypeAfterChange(
HloInstruction* hlo, const ShapeIndex& index) const {
- PrimitiveType type_on_hlo =
- ShapeUtil::GetSubshape(hlo->shape(), index).element_type();
+ Shape* subshape = ShapeUtil::GetMutableSubshape(hlo->mutable_shape(), index);
+ const PrimitiveType type_on_hlo = subshape->element_type();
if (type_on_hlo != F32) {
return type_on_hlo;
}
@@ -842,7 +841,7 @@ PrimitiveType BFloat16Propagation::OutputTypeAfterChange(
if (it == changes_to_bf16_.end()) {
return type_on_hlo;
}
- return ContainsKey(it->second, index) ? BF16 : F32;
+ return ContainsKey(it->second, subshape) ? BF16 : F32;
}
PrimitiveType BFloat16Propagation::ValueTypeAfterChange(
@@ -856,14 +855,16 @@ void BFloat16Propagation::AddToOrRemoveFromBF16ChangeSet(
HloInstruction* hlo, const ShapeIndex& index, PrimitiveType target_type) {
if (target_type == BF16) {
auto& entry = changes_to_bf16_[hlo];
- entry.insert(index);
+ entry.emplace(ShapeUtil::GetMutableSubshape(hlo->mutable_shape(), index),
+ index);
} else {
CHECK_EQ(target_type, F32);
auto it = changes_to_bf16_.find(hlo);
if (it == changes_to_bf16_.end()) {
return;
}
- it->second.erase(index);
+ it->second.erase(
+ ShapeUtil::GetMutableSubshape(hlo->mutable_shape(), index));
}
}