diff options
author | 2018-02-21 10:29:47 -0800 | |
---|---|---|
committer | 2018-02-21 10:34:01 -0800 | |
commit | f604ba67ef3340c29afac74162659f1cf0c9d557 (patch) | |
tree | 2f33573882306b8b004ebe08f8af47226ad4bd07 /tensorflow/compiler/xla/util.h | |
parent | 113fce8885c80d6897a58dd8e0747b964e8cb113 (diff) |
[XLA] Add FindInstruction and FindComputation helpers to HloTestBase.
These are useful for tests that create HLOs and then search for a
particular computation/instruction.
While we're at it, add a c_find_if utility and fix up the (lack of)
perfect forwarding in some of our other c_foo utilities.
PiperOrigin-RevId: 186482111
Diffstat (limited to 'tensorflow/compiler/xla/util.h')
-rw-r--r-- | tensorflow/compiler/xla/util.h | 15 |
1 files changed, 11 insertions, 4 deletions
diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index 46ec7af542..e14c8cefa1 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -427,8 +427,9 @@ std::vector<std::pair<int64, int64>> CommonFactors( string SanitizeFileName(string file_name); template <typename Container, typename Predicate> -bool c_all_of(Container container, Predicate predicate) { - return std::all_of(std::begin(container), std::end(container), predicate); +bool c_all_of(Container container, Predicate&& predicate) { + return std::all_of(std::begin(container), std::end(container), + std::forward<Predicate>(predicate)); } template <typename InputContainer, typename OutputIterator, @@ -461,8 +462,9 @@ void c_sort(InputContainer& input_container) { } template <class InputContainer, class Comparator> -void c_sort(InputContainer& input_container, Comparator comparator) { - std::sort(std::begin(input_container), std::end(input_container), comparator); +void c_sort(InputContainer& input_container, Comparator&& comparator) { + std::sort(std::begin(input_container), std::end(input_container), + std::forward<Comparator>(comparator)); } template <typename Sequence, typename T> @@ -480,6 +482,11 @@ template <typename C> auto c_adjacent_find(const C& c) -> decltype(std::begin(c)) { return std::adjacent_find(std::begin(c), std::end(c)); } + +template <typename C, typename Pred> +auto c_find_if(const C& c, Pred&& pred) -> decltype(std::begin(c)) { + return std::find_if(std::begin(c), std::end(c), std::forward<Pred>(pred)); +} } // namespace xla #define XLA_LOG_LINES(SEV, STRING) \ |