-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator.cc               | 144
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h  | 123
2 files changed, 141 insertions(+), 126 deletions(-)
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index 06b6d5b559..b91b2406e2 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -1173,80 +1173,85 @@ StatusOr<Literal> EvaluateSortInternal(HloInstruction* sort,
TF_RET_CHECK(
ShapeUtil::SameDimensions(keys_literal.shape(), values_literal.shape()))
<< "Sort keys and values must have the same dimensions";
- TF_RET_CHECK(rank > 0 && rank <= 2)
- << "Sort is only supported for rank-1 and rank-2 shapes, rank is: "
- << rank;
TF_RET_CHECK(sort->operand_count() == 2) << "Expected key-value sort";
- // We need to sort and array of keys and an array of values, where the
+ // We need to sort an array of keys and an array of values, where the
// sorted order of the values is determined by the keys. The simplest(?)
// way to do this is to go to an array-of-pairs representation, sort the
// array using the keys, and then go back to pair-of-arrays.
VLOG(3) << "HandleSort keys_literal: " << keys_literal.ToString();
VLOG(3) << "HandleSort values_literal: " << values_literal.ToString();
- auto sort_r1 = [](const Literal& keys_literal,
- const Literal& values_literal) {
- const auto& keys_data = keys_literal.data<KeyType>();
- const auto& values_data = values_literal.data<ValueType>();
-
- using kv_pair = std::pair<KeyType, ValueType>;
- std::vector<kv_pair> key_value_vector;
- CHECK_EQ(keys_data.size(), values_data.size());
- key_value_vector.reserve(keys_data.size());
- for (int i = 0; i < keys_data.size(); ++i) {
- key_value_vector.push_back(std::make_pair(keys_data[i], values_data[i]));
- }
- std::sort(key_value_vector.begin(), key_value_vector.end(),
- [](const kv_pair& a, const kv_pair& b) {
- return SafeLess<KeyType>(a.first, b.first);
- });
- std::vector<KeyType> result_keys;
- std::vector<ValueType> result_values;
- for (const auto& key_value : key_value_vector) {
- result_keys.push_back(key_value.first);
- result_values.push_back(key_value.second);
- }
- Literal result_keys_literal(keys_literal.shape());
- result_keys_literal.PopulateR1(absl::Span<const KeyType>(result_keys));
- Literal result_values_literal(values_literal.shape());
- result_values_literal.PopulateR1(
- absl::Span<const ValueType>(result_values));
- return std::make_pair(std::move(result_keys_literal),
- std::move(result_values_literal));
- };
-
- Literal result_tuple;
- if (rank == 1) {
- auto result_pair = sort_r1(keys_literal, values_literal);
- result_tuple =
- LiteralUtil::MakeTuple({&result_pair.first, &result_pair.second});
- } else {
- // For R2 sort, the desired semantics are to sort each matrix row
- // independently.
- Literal keys_result_literal(keys_literal.shape());
- Literal values_result_literal(values_literal.shape());
- int64 r1_length = keys_literal.shape().dimensions(1);
- for (int64 row = 0; row < keys_literal.shape().dimensions(0); ++row) {
- TF_ASSIGN_OR_RETURN(auto keys_r1_slice,
- keys_literal.Slice({row, 0}, {row + 1, r1_length})
- .Reshape({r1_length}));
- TF_ASSIGN_OR_RETURN(auto values_r1_slice,
- values_literal.Slice({row, 0}, {row + 1, r1_length})
- .Reshape({r1_length}));
- auto r1_result_pair = sort_r1(keys_r1_slice, values_r1_slice);
- TF_ASSIGN_OR_RETURN(auto sorted_keys,
- r1_result_pair.first.Reshape({1, r1_length}));
- TF_ASSIGN_OR_RETURN(auto sorted_values,
- r1_result_pair.second.Reshape({1, r1_length}));
- TF_RETURN_IF_ERROR(keys_result_literal.CopySliceFrom(
- sorted_keys, {0, 0}, {row, 0}, {1, r1_length}));
- TF_RETURN_IF_ERROR(values_result_literal.CopySliceFrom(
- sorted_values, {0, 0}, {row, 0}, {1, r1_length}));
- }
- result_tuple =
- LiteralUtil::MakeTuple({&keys_result_literal, &values_result_literal});
+ if (rank == 0) {
+ // Nothing to sort.
+ return LiteralUtil::MakeTuple({&keys_literal, &values_literal});
}
+ Literal keys_result_literal(keys_literal.shape());
+ Literal values_result_literal(values_literal.shape());
+ std::vector<int64> zero_base(rank, 0);
+ std::vector<int64> increment(rank, 1);
+ int64 sort_dim = sort->dimensions(0);
+ int64 sort_dim_elements = keys_literal.shape().dimensions(sort_dim);
+ increment[sort_dim] = sort_dim_elements;
+ // Iterate through each dimension except 'sort_dim'.
+ TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
+ keys_literal.shape(), zero_base,
+ AsInt64Slice(keys_literal.shape().dimensions()), increment,
+ [&](absl::Span<const int64> indices) -> StatusOr<bool> {
+ // Extract a slice from the keys and values literals that correspond to
+ // exactly the row in dimension 'sort_dim'.
+ std::vector<int64> limit_indices(indices.begin(), indices.end());
+ std::for_each(limit_indices.begin(), limit_indices.end(),
+ [](int64& index) { ++index; });
+ limit_indices[sort_dim] = sort_dim_elements;
+ TF_ASSIGN_OR_RETURN(auto keys_to_sort,
+ keys_literal.Slice(indices, limit_indices)
+ .Reshape({sort_dim_elements}));
+ const auto& keys_data = keys_to_sort.data<KeyType>();
+ TF_ASSIGN_OR_RETURN(auto values_to_sort,
+ values_literal.Slice(indices, limit_indices)
+ .Reshape({sort_dim_elements}));
+ const auto& values_data = values_to_sort.data<ValueType>();
+ using kv_pair = std::pair<KeyType, ValueType>;
+ std::vector<kv_pair> key_value_vector;
+ key_value_vector.reserve(keys_data.size());
+ for (int i = 0; i < keys_data.size(); ++i) {
+ key_value_vector.push_back(
+ std::make_pair(keys_data[i], values_data[i]));
+ }
+ std::sort(key_value_vector.begin(), key_value_vector.end(),
+ [](const kv_pair& a, const kv_pair& b) {
+ return SafeLess<KeyType>(a.first, b.first);
+ });
+ std::vector<KeyType> result_keys;
+ std::vector<ValueType> result_values;
+ for (const auto& key_value : key_value_vector) {
+ result_keys.push_back(key_value.first);
+ result_values.push_back(key_value.second);
+ }
+ Literal sorted_keys(ShapeUtil::MakeShape(
+ keys_literal.shape().element_type(), {sort_dim_elements}));
+ sorted_keys.PopulateR1(absl::Span<const KeyType>(result_keys));
+ Literal sorted_values(ShapeUtil::MakeShape(
+ values_literal.shape().element_type(), {sort_dim_elements}));
+ sorted_values.PopulateR1(absl::Span<const ValueType>(result_values));
+ std::vector<int64> slice_dimensions(rank, 1);
+ slice_dimensions[sort_dim] = sort_dim_elements;
+ std::vector<int64> start_indices(rank, 0);
+ TF_ASSIGN_OR_RETURN(auto sorted_keys_reshaped,
+ sorted_keys.Reshape(slice_dimensions));
+ TF_RETURN_IF_ERROR(keys_result_literal.CopySliceFrom(
+ sorted_keys_reshaped, start_indices, indices, slice_dimensions));
+ TF_ASSIGN_OR_RETURN(auto sorted_values_reshaped,
+ sorted_values.Reshape(slice_dimensions));
+ TF_RETURN_IF_ERROR(values_result_literal.CopySliceFrom(
+ sorted_values_reshaped, start_indices, indices, slice_dimensions));
+ return true;
+ }));
+
+ Literal result_tuple;
+ result_tuple =
+ LiteralUtil::MakeTuple({&keys_result_literal, &values_result_literal});
VLOG(3) << "HandleSort result_tuple: " << result_tuple.ToString();
return std::move(result_tuple);
}
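
Note: the array-of-pairs approach described in the comment above can be sketched in plain C++ outside of XLA. This is only an illustration with hypothetical names (SortByKey is not the evaluator's helper), and it uses ordinary operator< instead of SafeLess to stay self-contained:

#include <algorithm>
#include <utility>
#include <vector>

// Sort 'values' according to 'keys': zip into pairs, sort by the key
// component, then unzip back into the two arrays.
template <typename KeyType, typename ValueType>
void SortByKey(std::vector<KeyType>& keys, std::vector<ValueType>& values) {
  std::vector<std::pair<KeyType, ValueType>> zipped;
  zipped.reserve(keys.size());
  for (size_t i = 0; i < keys.size(); ++i) {
    zipped.emplace_back(keys[i], values[i]);
  }
  std::sort(zipped.begin(), zipped.end(),
            [](const auto& a, const auto& b) { return a.first < b.first; });
  for (size_t i = 0; i < zipped.size(); ++i) {
    keys[i] = zipped[i].first;
    values[i] = zipped[i].second;
  }
}

The real code sorts with SafeLess<KeyType> rather than operator< so that NaN keys still yield a strict weak order.
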
@@ -1292,15 +1297,6 @@ StatusOr<Literal> EvaluateSort(HloInstruction* sort,
} // namespace
Status HloEvaluator::HandleSort(HloInstruction* sort) {
- const int64 sort_dim = sort->dimensions(0);
- const int64 rank = ShapeUtil::Rank(sort->operand(0)->shape());
- if (sort_dim != rank - 1) {
- return Unimplemented(
- "Trying to sort along dimension %d, which is not the last "
- "dimension",
- sort_dim);
- }
-
if (!ShapeUtil::IsTuple(sort->shape())) {
return DefaultAction(sort);
} else {
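
Note: with the "must sort along the last dimension" check removed above, sorting along an arbitrary dimension works by visiting every index whose sort-dimension coordinate is zero and sorting the one-dimensional slice that starts there; the diff does this with ShapeUtil::ForEachIndexWithStatus and an increment of sort_dim_elements along sort_dim. A standalone sketch of the same idea for a row-major array, with hypothetical names and assuming a plain std::vector buffer, looks roughly like this:

#include <algorithm>
#include <cstdint>
#include <vector>

// Sort a row-major array of shape 'dims' along dimension 'sort_dim'.
// Hypothetical helper, shown only to illustrate the iteration pattern.
void SortAlongDim(std::vector<float>& data, const std::vector<int64_t>& dims,
                  int sort_dim) {
  if (dims.empty()) return;  // Rank 0: nothing to sort.
  // Row-major strides: strides[i] = product of dims[i+1..].
  std::vector<int64_t> strides(dims.size(), 1);
  for (int i = static_cast<int>(dims.size()) - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * dims[i + 1];
  }
  const int64_t n = dims[sort_dim];
  const int64_t stride = strides[sort_dim];
  std::vector<int64_t> index(dims.size(), 0);  // index[sort_dim] stays 0.
  while (true) {
    // Gather the 1-D slice starting at 'index', sort it, scatter it back.
    int64_t base = 0;
    for (size_t i = 0; i < dims.size(); ++i) base += index[i] * strides[i];
    std::vector<float> slice(n);
    for (int64_t k = 0; k < n; ++k) slice[k] = data[base + k * stride];
    std::sort(slice.begin(), slice.end());
    for (int64_t k = 0; k < n; ++k) data[base + k * stride] = slice[k];
    // Advance to the next multi-index, skipping the sort dimension.
    int i = static_cast<int>(dims.size()) - 1;
    for (; i >= 0; --i) {
      if (i == sort_dim) continue;
      if (++index[i] < dims[i]) break;
      index[i] = 0;
    }
    if (i < 0) break;
  }
}
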
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index 8fb17a0033..35391ecf8a 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -16,6 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_
+#include <cmath>
+
#include "absl/algorithm/container.h"
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
@@ -41,7 +43,9 @@ template <typename T>
using is_complex64_t = std::is_same<T, complex64>;
// It's UB to use std::sort with std::less<float>, because of NaNs. Define
-// "safe" less functions which are actually strict weak orders.
+// "safe" less functions which are actually strict weak orders. -NaN and NaN
+// should appear at the beginning and end of the ordering, and -0.0 should
+// appear before 0.0.
template <
typename NativeT,
typename std::enable_if<std::is_integral<NativeT>::value>::type* = nullptr>
@@ -49,26 +53,33 @@ bool SafeLess(const NativeT& a, const NativeT& b) {
return a < b;
}
-template <typename NativeT,
- typename std::enable_if<
- std::is_floating_point<NativeT>::value ||
- std::is_same<NativeT, bfloat16>::value>::type* = nullptr>
+template <typename NativeT, typename std::enable_if<std::is_floating_point<
+ NativeT>::value>::type* = nullptr>
bool SafeLess(const NativeT& a, const NativeT& b) {
- if (std::isnan(b)) {
- return !std::isnan(a);
- } else {
- return a < b;
+ bool lhs_is_negative = std::signbit(a);
+ bool rhs_is_negative = std::signbit(b);
+ // If the signs are different, we can just compare the signs.
+ if (lhs_is_negative != rhs_is_negative) {
+ return lhs_is_negative && !rhs_is_negative;
+ }
+ bool lhs_nan = std::isnan(a);
+ bool rhs_nan = std::isnan(b);
+ // Exactly one number is nan?
+ if (lhs_nan != rhs_nan) {
+ if (lhs_nan) {
+ return lhs_is_negative;
+ }
+ return !rhs_is_negative;
}
+ return a < b;
}
-template <typename NativeT, typename std::enable_if<std::is_same<
- NativeT, Eigen::half>::value>::type* = nullptr>
+template <typename NativeT,
+ typename std::enable_if<
+ std::is_same<NativeT, bfloat16>::value ||
+ std::is_same<NativeT, Eigen::half>::value>::type* = nullptr>
bool SafeLess(const NativeT& a, const NativeT& b) {
- if (Eigen::half_impl::isnan(b)) {
- return !Eigen::half_impl::isnan(a);
- } else {
- return a < b;
- }
+ return SafeLess(static_cast<float>(a), static_cast<float>(b));
}
// Templated DfsHloVisitor for use by HloEvaluator.
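
Note: the intent of the floating-point SafeLess above is a strict weak order in which -NaN sorts first, NaN sorts last, and -0.0 sorts before 0.0. A minimal standalone version of the same comparison (hypothetical name TotalLess, plain C++ rather than the templated XLA helper) is:

#include <cmath>

bool TotalLess(float a, float b) {
  const bool a_neg = std::signbit(a);
  const bool b_neg = std::signbit(b);
  if (a_neg != b_neg) {
    // Different sign bits: anything with the sign bit set (including -0.0
    // and -NaN) orders before anything without it.
    return a_neg;
  }
  const bool a_nan = std::isnan(a);
  const bool b_nan = std::isnan(b);
  if (a_nan != b_nan) {
    // Exactly one operand is NaN: a negative NaN is the smallest value,
    // a positive NaN the largest.
    return a_nan ? a_neg : !b_neg;
  }
  return a < b;  // Two NaNs compare equal here; otherwise ordinary <.
}

Sorting {NaN, -0.0, 1, -NaN, 0.0, -1} with this comparator yields -NaN, -1, -0.0, 0.0, 1, NaN. As in the diff, bfloat16 and half values can be handled by converting to float before comparing.
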
@@ -1527,47 +1538,55 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
!std::is_same<NativeT, bool>::value>::type* = nullptr>
Status HandleSort(HloInstruction* sort) {
auto keys = sort->operand(0);
- auto rank = ShapeUtil::Rank(keys->shape());
- TF_RET_CHECK(rank > 0 && rank <= 2)
- << "Sort is only supported for R1 and R2 shapes";
TF_RET_CHECK(sort->operand_count() == 1)
<< "Typed visitor does not support key-value sort";
const Literal& keys_literal = parent_->GetEvaluatedLiteralFor(keys);
-
- auto sort_r1 = [this](const Literal& keys_literal) {
- VLOG(3) << "HandleSort keys_literal: " << keys_literal.ToString();
- const auto& keys_data = keys_literal.data<ReturnT>();
-
- std::vector<ReturnT> result_data(keys_data.begin(), keys_data.end());
- std::sort(result_data.begin(), result_data.end(),
- [](const ReturnT& a, const ReturnT& b) {
- return SafeLess<ReturnT>(a, b);
- });
- Literal result_literal(keys_literal.shape());
- result_literal.PopulateR1(absl::Span<const ReturnT>(result_data));
- VLOG(3) << "HandleSort result_literal: " << result_literal.ToString();
- return result_literal;
- };
-
- if (rank == 1) {
- parent_->evaluated_[sort] = std::move(sort_r1(keys_literal));
- } else {
- // For R2 sort, the desired semantics are to sort each matrix row
- // independently.
- Literal result_literal(keys_literal.shape());
- int64 r1_length = keys->shape().dimensions(1);
- for (int64 row = 0; row < keys->shape().dimensions(0); ++row) {
- TF_ASSIGN_OR_RETURN(auto r1_slice,
- keys_literal.Slice({row, 0}, {row + 1, r1_length})
- .Reshape({r1_length}));
- auto r1_result = sort_r1(r1_slice);
- TF_ASSIGN_OR_RETURN(r1_result, r1_result.Reshape({1, r1_length}));
- TF_RETURN_IF_ERROR(result_literal.CopySliceFrom(
- r1_result, {0, 0}, {row, 0}, {1, r1_length}));
- }
- parent_->evaluated_[sort] = std::move(result_literal);
+ int64 sort_dim = sort->dimensions(0);
+ int64 sort_dim_elements = keys->shape().dimensions(sort_dim);
+ int64 rank = ShapeUtil::Rank(keys->shape());
+ if (rank == 0) {
+ // Nothing to sort.
+ parent_->evaluated_[sort] = keys_literal.Clone();
+ return Status::OK();
}
+ Literal result_literal(keys_literal.shape());
+ std::vector<int64> zero_base(rank, 0);
+ std::vector<int64> increment(rank, 1);
+ increment[sort_dim] = sort_dim_elements;
+ // Iterate through each dimension except 'sort_dim'.
+ TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
+ keys->shape(), zero_base, AsInt64Slice(keys->shape().dimensions()),
+ increment, [&](absl::Span<const int64> indices) -> StatusOr<bool> {
+ // Extract a slice from the literal that corresponds to exactly the
+ // row in dimension 'sort_dim'.
+ std::vector<int64> limit_indices(indices.begin(), indices.end());
+ std::for_each(limit_indices.begin(), limit_indices.end(),
+ [](int64& index) { ++index; });
+ limit_indices[sort_dim] = sort_dim_elements;
+ TF_ASSIGN_OR_RETURN(auto row_to_sort,
+ keys_literal.Slice(indices, limit_indices)
+ .Reshape({sort_dim_elements}));
+ const auto& row_data = row_to_sort.data<NativeT>();
+
+ std::vector<NativeT> result_data(row_data.begin(), row_data.end());
+ std::sort(result_data.begin(), result_data.end(),
+ [](const NativeT& a, const NativeT& b) {
+ return SafeLess<NativeT>(a, b);
+ });
+ Literal sorted_row(ShapeUtil::MakeShape(keys->shape().element_type(),
+ {sort_dim_elements}));
+ sorted_row.PopulateR1(absl::Span<const NativeT>(result_data));
+ std::vector<int64> slice_dimensions(rank, 1);
+ slice_dimensions[sort_dim] = sort_dim_elements;
+ TF_ASSIGN_OR_RETURN(auto sorted_row_reshaped,
+ sorted_row.Reshape(slice_dimensions));
+ std::vector<int64> start_indices(rank, 0);
+ TF_RETURN_IF_ERROR(result_literal.CopySliceFrom(
+ sorted_row_reshaped, start_indices, indices, slice_dimensions));
+ return true;
+ }));
+ parent_->evaluated_[sort] = std::move(result_literal);
return Status::OK();
}
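
Note: as a concrete illustration of what the typed visitor now handles, sorting along dimension 0 of a 2x3 row-major array sorts each column independently. A small self-contained sketch (plain C++, not XLA code):

#include <algorithm>
#include <vector>

int main() {
  // Shape {2, 3}, row-major: element (r, c) lives at data[r * 3 + c].
  std::vector<float> data = {3.f, 0.f, 5.f,
                             1.f, 2.f, 4.f};
  const int rows = 2, cols = 3;
  std::vector<float> column(rows);
  for (int c = 0; c < cols; ++c) {
    // The slice along dimension 0 is the column {data[c], data[cols + c]}.
    for (int r = 0; r < rows; ++r) column[r] = data[r * cols + c];
    std::sort(column.begin(), column.end());
    for (int r = 0; r < rows; ++r) data[r * cols + c] = column[r];
  }
  // data is now {1, 0, 4,
  //              3, 2, 5}: each column is sorted independently.
  return 0;
}
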