aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/util.h
diff options
context:
space:
mode:
authorGravatar Justin Lebar <jlebar@google.com>2018-02-21 10:29:47 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-21 10:34:01 -0800
commitf604ba67ef3340c29afac74162659f1cf0c9d557 (patch)
tree2f33573882306b8b004ebe08f8af47226ad4bd07 /tensorflow/compiler/xla/util.h
parent113fce8885c80d6897a58dd8e0747b964e8cb113 (diff)
[XLA] Add FindInstruction and FindComputation helpers to HloTestBase.
These are useful for tests that create HLOs and then search for a particular computation/instruction. While we're at it, add a c_find_if utility and fix up the (lack of) perfect forwarding in some of our other c_foo utilities. PiperOrigin-RevId: 186482111
Diffstat (limited to 'tensorflow/compiler/xla/util.h')
-rw-r--r--tensorflow/compiler/xla/util.h15
1 files changed, 11 insertions, 4 deletions
diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h
index 46ec7af542..e14c8cefa1 100644
--- a/tensorflow/compiler/xla/util.h
+++ b/tensorflow/compiler/xla/util.h
@@ -427,8 +427,9 @@ std::vector<std::pair<int64, int64>> CommonFactors(
string SanitizeFileName(string file_name);
template <typename Container, typename Predicate>
-bool c_all_of(Container container, Predicate predicate) {
- return std::all_of(std::begin(container), std::end(container), predicate);
+bool c_all_of(Container container, Predicate&& predicate) {
+ return std::all_of(std::begin(container), std::end(container),
+ std::forward<Predicate>(predicate));
}
template <typename InputContainer, typename OutputIterator,
@@ -461,8 +462,9 @@ void c_sort(InputContainer& input_container) {
}
template <class InputContainer, class Comparator>
-void c_sort(InputContainer& input_container, Comparator comparator) {
- std::sort(std::begin(input_container), std::end(input_container), comparator);
+void c_sort(InputContainer& input_container, Comparator&& comparator) {
+ std::sort(std::begin(input_container), std::end(input_container),
+ std::forward<Comparator>(comparator));
}
template <typename Sequence, typename T>
@@ -480,6 +482,11 @@ template <typename C>
auto c_adjacent_find(const C& c) -> decltype(std::begin(c)) {
return std::adjacent_find(std::begin(c), std::end(c));
}
+
+template <typename C, typename Pred>
+auto c_find_if(const C& c, Pred&& pred) -> decltype(std::begin(c)) {
+ return std::find_if(std::begin(c), std::end(c), std::forward<Pred>(pred));
+}
} // namespace xla
#define XLA_LOG_LINES(SEV, STRING) \