diff options
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_evaluator.cc | 144 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h | 123 |
2 files changed, 141 insertions, 126 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 06b6d5b559..b91b2406e2 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -1173,80 +1173,85 @@ StatusOr<Literal> EvaluateSortInternal(HloInstruction* sort, TF_RET_CHECK( ShapeUtil::SameDimensions(keys_literal.shape(), values_literal.shape())) << "Sort keys and values must have the same dimensions"; - TF_RET_CHECK(rank > 0 && rank <= 2) - << "Sort is only supported for rank-1 and rank-2 shapes, rank is: " - << rank; TF_RET_CHECK(sort->operand_count() == 2) << "Expected key-value sort"; - // We need to sort and array of keys and an array of values, where the + // We need to sort an array of keys and an array of values, where the // sorted order of the values is determined by the keys. The simplest(?) // way to do this is to go to an array-of-pairs representation, sort the // array using the keys, and then go back to pair-of-arrays. VLOG(3) << "HandleSort keys_literal: " << keys_literal.ToString(); VLOG(3) << "HandleSort values_literal: " << values_literal.ToString(); - auto sort_r1 = [](const Literal& keys_literal, - const Literal& values_literal) { - const auto& keys_data = keys_literal.data<KeyType>(); - const auto& values_data = values_literal.data<ValueType>(); - - using kv_pair = std::pair<KeyType, ValueType>; - std::vector<kv_pair> key_value_vector; - CHECK_EQ(keys_data.size(), values_data.size()); - key_value_vector.reserve(keys_data.size()); - for (int i = 0; i < keys_data.size(); ++i) { - key_value_vector.push_back(std::make_pair(keys_data[i], values_data[i])); - } - std::sort(key_value_vector.begin(), key_value_vector.end(), - [](const kv_pair& a, const kv_pair& b) { - return SafeLess<KeyType>(a.first, b.first); - }); - std::vector<KeyType> result_keys; - std::vector<ValueType> result_values; - for (const auto& key_value : key_value_vector) { - result_keys.push_back(key_value.first); - result_values.push_back(key_value.second); - } - Literal result_keys_literal(keys_literal.shape()); - result_keys_literal.PopulateR1(absl::Span<const KeyType>(result_keys)); - Literal result_values_literal(values_literal.shape()); - result_values_literal.PopulateR1( - absl::Span<const ValueType>(result_values)); - return std::make_pair(std::move(result_keys_literal), - std::move(result_values_literal)); - }; - - Literal result_tuple; - if (rank == 1) { - auto result_pair = sort_r1(keys_literal, values_literal); - result_tuple = - LiteralUtil::MakeTuple({&result_pair.first, &result_pair.second}); - } else { - // For R2 sort, the desired semantics are to sort each matrix row - // independently. - Literal keys_result_literal(keys_literal.shape()); - Literal values_result_literal(values_literal.shape()); - int64 r1_length = keys_literal.shape().dimensions(1); - for (int64 row = 0; row < keys_literal.shape().dimensions(0); ++row) { - TF_ASSIGN_OR_RETURN(auto keys_r1_slice, - keys_literal.Slice({row, 0}, {row + 1, r1_length}) - .Reshape({r1_length})); - TF_ASSIGN_OR_RETURN(auto values_r1_slice, - values_literal.Slice({row, 0}, {row + 1, r1_length}) - .Reshape({r1_length})); - auto r1_result_pair = sort_r1(keys_r1_slice, values_r1_slice); - TF_ASSIGN_OR_RETURN(auto sorted_keys, - r1_result_pair.first.Reshape({1, r1_length})); - TF_ASSIGN_OR_RETURN(auto sorted_values, - r1_result_pair.second.Reshape({1, r1_length})); - TF_RETURN_IF_ERROR(keys_result_literal.CopySliceFrom( - sorted_keys, {0, 0}, {row, 0}, {1, r1_length})); - TF_RETURN_IF_ERROR(values_result_literal.CopySliceFrom( - sorted_values, {0, 0}, {row, 0}, {1, r1_length})); - } - result_tuple = - LiteralUtil::MakeTuple({&keys_result_literal, &values_result_literal}); + if (rank == 0) { + // Nothing to sort. + return LiteralUtil::MakeTuple({&keys_literal, &values_literal}); } + Literal keys_result_literal(keys_literal.shape()); + Literal values_result_literal(values_literal.shape()); + std::vector<int64> zero_base(rank, 0); + std::vector<int64> increment(rank, 1); + int64 sort_dim = sort->dimensions(0); + int64 sort_dim_elements = keys_literal.shape().dimensions(sort_dim); + increment[sort_dim] = sort_dim_elements; + // Iterate through each dimension except 'sort_dim'. + TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( + keys_literal.shape(), zero_base, + AsInt64Slice(keys_literal.shape().dimensions()), increment, + [&](absl::Span<const int64> indices) -> StatusOr<bool> { + // Extract a slice from the keys and values literals that correspond to + // exactly the row in dimension 'sort_dim'. + std::vector<int64> limit_indices(indices.begin(), indices.end()); + std::for_each(limit_indices.begin(), limit_indices.end(), + [](int64& index) { ++index; }); + limit_indices[sort_dim] = sort_dim_elements; + TF_ASSIGN_OR_RETURN(auto keys_to_sort, + keys_literal.Slice(indices, limit_indices) + .Reshape({sort_dim_elements})); + const auto& keys_data = keys_to_sort.data<KeyType>(); + TF_ASSIGN_OR_RETURN(auto values_to_sort, + values_literal.Slice(indices, limit_indices) + .Reshape({sort_dim_elements})); + const auto& values_data = values_to_sort.data<ValueType>(); + using kv_pair = std::pair<KeyType, ValueType>; + std::vector<kv_pair> key_value_vector; + key_value_vector.reserve(keys_data.size()); + for (int i = 0; i < keys_data.size(); ++i) { + key_value_vector.push_back( + std::make_pair(keys_data[i], values_data[i])); + } + std::sort(key_value_vector.begin(), key_value_vector.end(), + [](const kv_pair& a, const kv_pair& b) { + return SafeLess<KeyType>(a.first, b.first); + }); + std::vector<KeyType> result_keys; + std::vector<ValueType> result_values; + for (const auto& key_value : key_value_vector) { + result_keys.push_back(key_value.first); + result_values.push_back(key_value.second); + } + Literal sorted_keys(ShapeUtil::MakeShape( + keys_literal.shape().element_type(), {sort_dim_elements})); + sorted_keys.PopulateR1(absl::Span<const KeyType>(result_keys)); + Literal sorted_values(ShapeUtil::MakeShape( + values_literal.shape().element_type(), {sort_dim_elements})); + sorted_values.PopulateR1(absl::Span<const ValueType>(result_values)); + std::vector<int64> slice_dimensions(rank, 1); + slice_dimensions[sort_dim] = sort_dim_elements; + std::vector<int64> start_indices(rank, 0); + TF_ASSIGN_OR_RETURN(auto sorted_keys_reshaped, + sorted_keys.Reshape(slice_dimensions)); + TF_RETURN_IF_ERROR(keys_result_literal.CopySliceFrom( + sorted_keys_reshaped, start_indices, indices, slice_dimensions)); + TF_ASSIGN_OR_RETURN(auto sorted_values_reshaped, + sorted_values.Reshape(slice_dimensions)); + TF_RETURN_IF_ERROR(values_result_literal.CopySliceFrom( + sorted_values_reshaped, start_indices, indices, slice_dimensions)); + return true; + })); + + Literal result_tuple; + result_tuple = + LiteralUtil::MakeTuple({&keys_result_literal, &values_result_literal}); VLOG(3) << "HandleSort result_tuple: " << result_tuple.ToString(); return std::move(result_tuple); } @@ -1292,15 +1297,6 @@ StatusOr<Literal> EvaluateSort(HloInstruction* sort, } // namespace Status HloEvaluator::HandleSort(HloInstruction* sort) { - const int64 sort_dim = sort->dimensions(0); - const int64 rank = ShapeUtil::Rank(sort->operand(0)->shape()); - if (sort_dim != rank - 1) { - return Unimplemented( - "Trying to sort along dimension %d, which is not the last " - "dimension", - sort_dim); - } - if (!ShapeUtil::IsTuple(sort->shape())) { return DefaultAction(sort); } else { diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index 8fb17a0033..35391ecf8a 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ +#include <cmath> + #include "absl/algorithm/container.h" #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" @@ -41,7 +43,9 @@ template <typename T> using is_complex64_t = std::is_same<T, complex64>; // It's UB to use std::sort with std::less<float>, because of NaNs. Define -// "safe" less functions which are actually strict weak orders. +// "safe" less functions which are actually strict weak orders. -NaN and NaN +// should appear at the beginning and end of the ordering, and -0.0 should +// appear before 0.0. template < typename NativeT, typename std::enable_if<std::is_integral<NativeT>::value>::type* = nullptr> @@ -49,26 +53,33 @@ bool SafeLess(const NativeT& a, const NativeT& b) { return a < b; } -template <typename NativeT, - typename std::enable_if< - std::is_floating_point<NativeT>::value || - std::is_same<NativeT, bfloat16>::value>::type* = nullptr> +template <typename NativeT, typename std::enable_if<std::is_floating_point< + NativeT>::value>::type* = nullptr> bool SafeLess(const NativeT& a, const NativeT& b) { - if (std::isnan(b)) { - return !std::isnan(a); - } else { - return a < b; + bool lhs_is_negative = std::signbit(a); + bool rhs_is_negative = std::signbit(b); + // If the signs are different, we can just compare the signs. + if (lhs_is_negative != rhs_is_negative) { + return lhs_is_negative && !rhs_is_negative; + } + bool lhs_nan = std::isnan(a); + bool rhs_nan = std::isnan(b); + // Exactly one number is nan? + if (lhs_nan != rhs_nan) { + if (lhs_nan) { + return lhs_is_negative; + } + return !rhs_is_negative; } + return a < b; } -template <typename NativeT, typename std::enable_if<std::is_same< - NativeT, Eigen::half>::value>::type* = nullptr> +template <typename NativeT, + typename std::enable_if< + std::is_same<NativeT, bfloat16>::value || + std::is_same<NativeT, Eigen::half>::value>::type* = nullptr> bool SafeLess(const NativeT& a, const NativeT& b) { - if (Eigen::half_impl::isnan(b)) { - return !Eigen::half_impl::isnan(a); - } else { - return a < b; - } + return SafeLess(static_cast<float>(a), static_cast<float>(b)); } // Templated DfsHloVisitor for use by HloEvaluator. @@ -1527,47 +1538,55 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { !std::is_same<NativeT, bool>::value>::type* = nullptr> Status HandleSort(HloInstruction* sort) { auto keys = sort->operand(0); - auto rank = ShapeUtil::Rank(keys->shape()); - TF_RET_CHECK(rank > 0 && rank <= 2) - << "Sort is only supported for R1 and R2 shapes"; TF_RET_CHECK(sort->operand_count() == 1) << "Typed visitor does not support key-value sort"; const Literal& keys_literal = parent_->GetEvaluatedLiteralFor(keys); - - auto sort_r1 = [this](const Literal& keys_literal) { - VLOG(3) << "HandleSort keys_literal: " << keys_literal.ToString(); - const auto& keys_data = keys_literal.data<ReturnT>(); - - std::vector<ReturnT> result_data(keys_data.begin(), keys_data.end()); - std::sort(result_data.begin(), result_data.end(), - [](const ReturnT& a, const ReturnT& b) { - return SafeLess<ReturnT>(a, b); - }); - Literal result_literal(keys_literal.shape()); - result_literal.PopulateR1(absl::Span<const ReturnT>(result_data)); - VLOG(3) << "HandleSort result_literal: " << result_literal.ToString(); - return result_literal; - }; - - if (rank == 1) { - parent_->evaluated_[sort] = std::move(sort_r1(keys_literal)); - } else { - // For R2 sort, the desired semantics are to sort each matrix row - // independently. - Literal result_literal(keys_literal.shape()); - int64 r1_length = keys->shape().dimensions(1); - for (int64 row = 0; row < keys->shape().dimensions(0); ++row) { - TF_ASSIGN_OR_RETURN(auto r1_slice, - keys_literal.Slice({row, 0}, {row + 1, r1_length}) - .Reshape({r1_length})); - auto r1_result = sort_r1(r1_slice); - TF_ASSIGN_OR_RETURN(r1_result, r1_result.Reshape({1, r1_length})); - TF_RETURN_IF_ERROR(result_literal.CopySliceFrom( - r1_result, {0, 0}, {row, 0}, {1, r1_length})); - } - parent_->evaluated_[sort] = std::move(result_literal); + int64 sort_dim = sort->dimensions(0); + int64 sort_dim_elements = keys->shape().dimensions(sort_dim); + int64 rank = ShapeUtil::Rank(keys->shape()); + if (rank == 0) { + // Nothing to sort. + parent_->evaluated_[sort] = keys_literal.Clone(); + return Status::OK(); } + Literal result_literal(keys_literal.shape()); + std::vector<int64> zero_base(rank, 0); + std::vector<int64> increment(rank, 1); + increment[sort_dim] = sort_dim_elements; + // Iterate through each dimension except 'sort_dim'. + TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( + keys->shape(), zero_base, AsInt64Slice(keys->shape().dimensions()), + increment, [&](absl::Span<const int64> indices) -> StatusOr<bool> { + // Extract a slice from the literal that corresponds to exactly the + // row in dimension 'sort_dim'. + std::vector<int64> limit_indices(indices.begin(), indices.end()); + std::for_each(limit_indices.begin(), limit_indices.end(), + [](int64& index) { ++index; }); + limit_indices[sort_dim] = sort_dim_elements; + TF_ASSIGN_OR_RETURN(auto row_to_sort, + keys_literal.Slice(indices, limit_indices) + .Reshape({sort_dim_elements})); + const auto& row_data = row_to_sort.data<NativeT>(); + + std::vector<NativeT> result_data(row_data.begin(), row_data.end()); + std::sort(result_data.begin(), result_data.end(), + [](const NativeT& a, const NativeT& b) { + return SafeLess<NativeT>(a, b); + }); + Literal sorted_row(ShapeUtil::MakeShape(keys->shape().element_type(), + {sort_dim_elements})); + sorted_row.PopulateR1(absl::Span<const NativeT>(result_data)); + std::vector<int64> slice_dimensions(rank, 1); + slice_dimensions[sort_dim] = sort_dim_elements; + TF_ASSIGN_OR_RETURN(auto sorted_row_reshaped, + sorted_row.Reshape(slice_dimensions)); + std::vector<int64> start_indices(rank, 0); + TF_RETURN_IF_ERROR(result_literal.CopySliceFrom( + sorted_row_reshaped, start_indices, indices, slice_dimensions)); + return true; + })); + parent_->evaluated_[sort] = std::move(result_literal); return Status::OK(); } |