diff options
author | Adrian Kuegel <akuegel@google.com> | 2018-10-09 00:56:23 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-09 01:01:04 -0700 |
commit | a198ca7d9bbc752a322c59b9a30519eab1b6730c (patch) | |
tree | f2fa6dfb5760b228daa7a80c534e819ccf08cd75 | |
parent | 129bb5e845ccb2ab6339e85d39545800dac6ca33 (diff) |
Enable support for PRED values in KeyValueSort for the HloEvaluator.
PiperOrigin-RevId: 216315110
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_evaluator.cc | 8 |
1 files changed, 7 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index eec8d242fa..6cba46135c 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -24,6 +24,7 @@ limitations under the License. #include <vector> #include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/index_util.h" @@ -1279,7 +1280,9 @@ StatusOr<Literal> EvaluateSortInternal(HloInstruction* sort, return SafeLess<KeyType>(a.first, b.first); }); std::vector<KeyType> result_keys; - std::vector<ValueType> result_values; + // We use a InlinedVector here because we need to convert it to an + // absl::Span later, and this would not work with std::vector<bool>. + absl::InlinedVector<ValueType, 10> result_values; for (const auto& key_value : key_value_vector) { result_keys.push_back(key_value.first); result_values.push_back(key_value.second); @@ -1316,6 +1319,9 @@ StatusOr<Literal> EvaluateSortCurried(HloInstruction* sort, const Literal& keys_literal, const Literal& values_literal) { switch (sort->operand(1)->shape().element_type()) { + case PRED: + return EvaluateSortInternal<KeyType, bool>(sort, keys_literal, + values_literal); case F32: return EvaluateSortInternal<KeyType, float>(sort, keys_literal, values_literal); |