aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Adrian Kuegel <akuegel@google.com>2018-10-09 00:56:23 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-09 01:01:04 -0700
commita198ca7d9bbc752a322c59b9a30519eab1b6730c (patch)
treef2fa6dfb5760b228daa7a80c534e819ccf08cd75
parent129bb5e845ccb2ab6339e85d39545800dac6ca33 (diff)
Enable support for PRED values in KeyValueSort for the HloEvaluator.
PiperOrigin-RevId: 216315110
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.cc8
1 files changed, 7 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index eec8d242fa..6cba46135c 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include <vector>
#include "absl/algorithm/container.h"
+#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/index_util.h"
@@ -1279,7 +1280,9 @@ StatusOr<Literal> EvaluateSortInternal(HloInstruction* sort,
return SafeLess<KeyType>(a.first, b.first);
});
std::vector<KeyType> result_keys;
- std::vector<ValueType> result_values;
+ // We use a InlinedVector here because we need to convert it to an
+ // absl::Span later, and this would not work with std::vector<bool>.
+ absl::InlinedVector<ValueType, 10> result_values;
for (const auto& key_value : key_value_vector) {
result_keys.push_back(key_value.first);
result_values.push_back(key_value.second);
@@ -1316,6 +1319,9 @@ StatusOr<Literal> EvaluateSortCurried(HloInstruction* sort,
const Literal& keys_literal,
const Literal& values_literal) {
switch (sort->operand(1)->shape().element_type()) {
+ case PRED:
+ return EvaluateSortInternal<KeyType, bool>(sort, keys_literal,
+ values_literal);
case F32:
return EvaluateSortInternal<KeyType, float>(sort, keys_literal,
values_literal);