aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler
diff options
context:
space:
mode:
authorGravatar Michael Kuperstein <mkuper@google.com>2018-09-25 10:14:53 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-25 10:18:50 -0700
commitfaee2023f9764de44a804c3208be6f68dac04917 (patch)
treecde9610e0649cf68f0fc39f9f5104b478b550541 /tensorflow/compiler
parent954d6a0ace9b96cdd54659b99e9378a1138a7266 (diff)
[XLA] Make HloComputation::instruction_count() constant-time.
* Use a FlatMap for instruction_iterators_, and actually remove elements from it (which is cheap for a FlatMap). * Use the size of the map (which is O(1)) rather than the size of the list (which is O(n)) for instruction_count(). PiperOrigin-RevId: 214459259
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.cc9
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.h4
2 files changed, 7 insertions, 6 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index e9e70b2c57..0e5920af7a 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -272,10 +272,11 @@ Status HloComputation::RemoveInstruction(HloInstruction* instruction) {
<< "instruction " << instruction->name()
<< " has control successors and cannot be removed";
- TF_RET_CHECK(instruction_iterators_.count(instruction) != 0);
- auto inst_it = instruction_iterators_.at(instruction);
- (*inst_it)->set_parent(nullptr);
- instructions_.erase(inst_it);
+ auto inst_it = instruction_iterators_.find(instruction);
+ TF_RET_CHECK(inst_it != instruction_iterators_.end());
+ (*inst_it->second)->set_parent(nullptr);
+ instructions_.erase(inst_it->second);
+ instruction_iterators_.erase(inst_it);
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index e7c98aae23..936a53bd7e 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -227,7 +227,7 @@ class HloComputation {
void UpdateReachabilityThroughInstruction(
const HloInstruction* instruction, HloReachabilityMap* reachability_map);
- int64 instruction_count() const { return instructions_.size(); }
+ int64 instruction_count() const { return instruction_iterators_.size(); }
// Creates and returns a list of the embedded computations called by this
// computation. This includes all embedded computations called directly or
@@ -439,7 +439,7 @@ class HloComputation {
// instruction pointer to location in the list for fast lookup.
using InstructionList = std::list<std::unique_ptr<HloInstruction>>;
InstructionList instructions_;
- std::unordered_map<const HloInstruction*, InstructionList::iterator>
+ tensorflow::gtl::FlatMap<const HloInstruction*, InstructionList::iterator>
instruction_iterators_;
std::vector<HloInstruction*> param_instructions_;