diff options
author | 2018-03-09 22:49:30 -0800 | |
---|---|---|
committer | 2018-03-09 22:53:21 -0800 | |
commit | 2cd50a9fd2900c2bf7e74a7795823254d5383fb4 (patch) | |
tree | 8f08295309b573b8ec3db864ae94a1dc92cceebd | |
parent | 3b0a27549dd2f1a32526cb77ec7ff407d0fc315f (diff) |
[XLA] Speed up colocated buffer merging.
PiperOrigin-RevId: 188581202
-rw-r--r-- | tensorflow/compiler/xla/service/buffer_assignment.cc | 33 |
1 files changed, 21 insertions, 12 deletions
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index fb18c9d828..dbe45e932c 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -1339,26 +1339,35 @@ BufferAssigner::MergeColocatedBufferSets( auto cannot_merge_buffer_sets = [&colocated_buffer_sets, &buffer_liveness, &buffer_size, &is_entry_parameter](int64 i, int64 j) { - for (auto& buffer_a : colocated_buffer_sets[i]) { - for (auto& buffer_b : colocated_buffer_sets[j]) { - // Do not merge if the set includes live outs or entry parameters. - if (buffer_liveness.MaybeLiveOut(*buffer_a) || - is_entry_parameter(*buffer_a) || - buffer_liveness.MaybeLiveOut(*buffer_b) || - is_entry_parameter(*buffer_b)) { + // Do not merge if one of the sets includes live outs or entry parameters. + for (int64 key : {i, j}) { + for (auto& buffer : colocated_buffer_sets[key]) { + if (buffer_liveness.MaybeLiveOut(*buffer) || + is_entry_parameter(*buffer)) { return true; } - // Do not merge if the buffers interfere with each other. + } + } + + // Colocated sets satisfy the invariant that all buffers within a set have + // the same size. That means we need to check whether the size is the same + // between the two sets, but also that it's enough to look at just one + // buffer within each set. + if (buffer_size(**colocated_buffer_sets[i].begin()) != + buffer_size(**colocated_buffer_sets[j].begin())) { + return true; + } + + // Do not merge if some pair of buffers interferes with each other. + for (auto& buffer_a : colocated_buffer_sets[i]) { + for (auto& buffer_b : colocated_buffer_sets[j]) { if (buffer_a->id() != buffer_b->id() && buffer_liveness.MayInterfere(*buffer_a, *buffer_b)) { return true; } - // Do not merge if the buffer sizes are different. - if (buffer_size(*buffer_a) != buffer_size(*buffer_b)) { - return true; - } } } + return false; }; |