aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Justin Lebar <jlebar@google.com>2017-07-11 12:28:08 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-11 12:32:01 -0700
commit75b936e4c467af836623c7c72ff84fb0d458e5e6 (patch)
tree3241d91e184ff02f5fc3de7231f6e81dfffd618d
parentc3322543862e14482b3e108fb1b2d466641fd714 (diff)
Speed up HeapSimulator's UniqueOperandSourceBuffers.
Instead of constructing a temporary set, adding all its members to our big set, then flattening into a vector, add the members to our vector directly, then sort it and remove duplicates at the end. PiperOrigin-RevId: 161565289
-rw-r--r--tensorflow/compiler/xla/service/heap_simulator.cc22
1 files changed, 15 insertions, 7 deletions
diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc
index c662cec9c7..840be603bf 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator.cc
@@ -35,18 +35,26 @@ namespace {
std::vector<const LogicalBuffer*> UniqueOperandSourceBuffers(
const HloInstruction* instruction,
const TuplePointsToAnalysis& points_to_analysis) {
- FlatSet<const LogicalBuffer*> buffers;
+ std::vector<const LogicalBuffer*> buffers;
for (const HloInstruction* operand : instruction->operands()) {
- FlatSet<const LogicalBuffer*> sources =
- points_to_analysis.GetPointsToSet(operand).CreateFlattenedSet();
- buffers.insert(sources.begin(), sources.end());
+ points_to_analysis.GetPointsToSet(operand).ForEachElement(
+ [&](const ShapeIndex& /*index*/,
+ const std::vector<const LogicalBuffer*>& points_to) {
+ buffers.insert(buffers.end(), points_to.begin(), points_to.end());
+ });
}
- std::vector<const LogicalBuffer*> sorted(buffers.begin(), buffers.end());
- std::sort(sorted.begin(), sorted.end(),
+
+ // Sort and then remove duplicates from buffers.
+ std::sort(buffers.begin(), buffers.end(),
[](const LogicalBuffer* a, const LogicalBuffer* b) {
return a->id() < b->id();
});
- return sorted;
+ buffers.erase(std::unique(buffers.begin(), buffers.end(),
+ [](const LogicalBuffer* a, const LogicalBuffer* b) {
+ return a->id() == b->id();
+ }),
+ buffers.end());
+ return buffers;
}
} // namespace