From a198ca7d9bbc752a322c59b9a30519eab1b6730c Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Tue, 9 Oct 2018 00:56:23 -0700 Subject: Enable support for PRED values in KeyValueSort for the HloEvaluator. PiperOrigin-RevId: 216315110 --- tensorflow/compiler/xla/service/hlo_evaluator.cc | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index eec8d242fa..6cba46135c 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -24,6 +24,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/index_util.h" @@ -1279,7 +1280,9 @@ StatusOr EvaluateSortInternal(HloInstruction* sort, return SafeLess(a.first, b.first); }); std::vector result_keys; - std::vector result_values; + // We use a InlinedVector here because we need to convert it to an + // absl::Span later, and this would not work with std::vector. + absl::InlinedVector result_values; for (const auto& key_value : key_value_vector) { result_keys.push_back(key_value.first); result_values.push_back(key_value.second); @@ -1316,6 +1319,9 @@ StatusOr EvaluateSortCurried(HloInstruction* sort, const Literal& keys_literal, const Literal& values_literal) { switch (sort->operand(1)->shape().element_type()) { + case PRED: + return EvaluateSortInternal(sort, keys_literal, + values_literal); case F32: return EvaluateSortInternal(sort, keys_literal, values_literal); -- cgit v1.2.3