diff options
author | 2018-09-24 03:19:11 -0700 | |
---|---|---|
committer | 2018-09-24 03:23:55 -0700 | |
commit | 379ca4afe9e31f550cd04451af04150b6bbecf78 (patch) | |
tree | fa3a9097489fd510d70aac020c02b4b8f16ea916 /tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h | |
parent | b57bdf414edb27b82a95c5f4e2729fafd4cf2dc7 (diff) |
Generalize sort implementation in the HloEvaluator.
Previously, the implementation only worked for operands of rank 1 or 2, and only if the dimension to sort was the most minor dimension.
Also fix the SafeLess function so that the SortExtremeValues() test passes.
PiperOrigin-RevId: 214239560
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h | 123 |
1 files changed, 71 insertions, 52 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index 8fb17a0033..35391ecf8a 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ +#include <cmath> + #include "absl/algorithm/container.h" #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" @@ -41,7 +43,9 @@ template <typename T> using is_complex64_t = std::is_same<T, complex64>; // It's UB to use std::sort with std::less<float>, because of NaNs. Define -// "safe" less functions which are actually strict weak orders. +// "safe" less functions which are actually strict weak orders. -NaN and NaN +// should appear at the beginning and end of the ordering, and -0.0 should +// appear before 0.0. template < typename NativeT, typename std::enable_if<std::is_integral<NativeT>::value>::type* = nullptr> @@ -49,26 +53,33 @@ bool SafeLess(const NativeT& a, const NativeT& b) { return a < b; } -template <typename NativeT, - typename std::enable_if< - std::is_floating_point<NativeT>::value || - std::is_same<NativeT, bfloat16>::value>::type* = nullptr> +template <typename NativeT, typename std::enable_if<std::is_floating_point< + NativeT>::value>::type* = nullptr> bool SafeLess(const NativeT& a, const NativeT& b) { - if (std::isnan(b)) { - return !std::isnan(a); - } else { - return a < b; + bool lhs_is_negative = std::signbit(a); + bool rhs_is_negative = std::signbit(b); + // If the signs are different, we can just compare the signs. + if (lhs_is_negative != rhs_is_negative) { + return lhs_is_negative && !rhs_is_negative; + } + bool lhs_nan = std::isnan(a); + bool rhs_nan = std::isnan(b); + // Exactly one number is nan? 
+ if (lhs_nan != rhs_nan) { + if (lhs_nan) { + return lhs_is_negative; + } + return !rhs_is_negative; } + return a < b; } -template <typename NativeT, typename std::enable_if<std::is_same< - NativeT, Eigen::half>::value>::type* = nullptr> +template <typename NativeT, + typename std::enable_if< + std::is_same<NativeT, bfloat16>::value || + std::is_same<NativeT, Eigen::half>::value>::type* = nullptr> bool SafeLess(const NativeT& a, const NativeT& b) { - if (Eigen::half_impl::isnan(b)) { - return !Eigen::half_impl::isnan(a); - } else { - return a < b; - } + return SafeLess(static_cast<float>(a), static_cast<float>(b)); } // Templated DfsHloVisitor for use by HloEvaluator. @@ -1527,47 +1538,55 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { !std::is_same<NativeT, bool>::value>::type* = nullptr> Status HandleSort(HloInstruction* sort) { auto keys = sort->operand(0); - auto rank = ShapeUtil::Rank(keys->shape()); - TF_RET_CHECK(rank > 0 && rank <= 2) - << "Sort is only supported for R1 and R2 shapes"; TF_RET_CHECK(sort->operand_count() == 1) << "Typed visitor does not support key-value sort"; const Literal& keys_literal = parent_->GetEvaluatedLiteralFor(keys); - - auto sort_r1 = [this](const Literal& keys_literal) { - VLOG(3) << "HandleSort keys_literal: " << keys_literal.ToString(); - const auto& keys_data = keys_literal.data<ReturnT>(); - - std::vector<ReturnT> result_data(keys_data.begin(), keys_data.end()); - std::sort(result_data.begin(), result_data.end(), - [](const ReturnT& a, const ReturnT& b) { - return SafeLess<ReturnT>(a, b); - }); - Literal result_literal(keys_literal.shape()); - result_literal.PopulateR1(absl::Span<const ReturnT>(result_data)); - VLOG(3) << "HandleSort result_literal: " << result_literal.ToString(); - return result_literal; - }; - - if (rank == 1) { - parent_->evaluated_[sort] = std::move(sort_r1(keys_literal)); - } else { - // For R2 sort, the desired semantics are to sort each matrix row - // independently. 
- Literal result_literal(keys_literal.shape()); - int64 r1_length = keys->shape().dimensions(1); - for (int64 row = 0; row < keys->shape().dimensions(0); ++row) { - TF_ASSIGN_OR_RETURN(auto r1_slice, - keys_literal.Slice({row, 0}, {row + 1, r1_length}) - .Reshape({r1_length})); - auto r1_result = sort_r1(r1_slice); - TF_ASSIGN_OR_RETURN(r1_result, r1_result.Reshape({1, r1_length})); - TF_RETURN_IF_ERROR(result_literal.CopySliceFrom( - r1_result, {0, 0}, {row, 0}, {1, r1_length})); - } - parent_->evaluated_[sort] = std::move(result_literal); + int64 sort_dim = sort->dimensions(0); + int64 sort_dim_elements = keys->shape().dimensions(sort_dim); + int64 rank = ShapeUtil::Rank(keys->shape()); + if (rank == 0) { + // Nothing to sort. + parent_->evaluated_[sort] = keys_literal.Clone(); + return Status::OK(); } + Literal result_literal(keys_literal.shape()); + std::vector<int64> zero_base(rank, 0); + std::vector<int64> increment(rank, 1); + increment[sort_dim] = sort_dim_elements; + // Iterate through each dimension except 'sort_dim'. + TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( + keys->shape(), zero_base, AsInt64Slice(keys->shape().dimensions()), + increment, [&](absl::Span<const int64> indices) -> StatusOr<bool> { + // Extract a slice from the literal that corresponds to exactly the + // row in dimension 'sort_dim'. 
+ std::vector<int64> limit_indices(indices.begin(), indices.end()); + std::for_each(limit_indices.begin(), limit_indices.end(), + [](int64& index) { ++index; }); + limit_indices[sort_dim] = sort_dim_elements; + TF_ASSIGN_OR_RETURN(auto row_to_sort, + keys_literal.Slice(indices, limit_indices) + .Reshape({sort_dim_elements})); + const auto& row_data = row_to_sort.data<NativeT>(); + + std::vector<NativeT> result_data(row_data.begin(), row_data.end()); + std::sort(result_data.begin(), result_data.end(), + [](const NativeT& a, const NativeT& b) { + return SafeLess<NativeT>(a, b); + }); + Literal sorted_row(ShapeUtil::MakeShape(keys->shape().element_type(), + {sort_dim_elements})); + sorted_row.PopulateR1(absl::Span<const NativeT>(result_data)); + std::vector<int64> slice_dimensions(rank, 1); + slice_dimensions[sort_dim] = sort_dim_elements; + TF_ASSIGN_OR_RETURN(auto sorted_row_reshaped, + sorted_row.Reshape(slice_dimensions)); + std::vector<int64> start_indices(rank, 0); + TF_RETURN_IF_ERROR(result_literal.CopySliceFrom( + sorted_row_reshaped, start_indices, indices, slice_dimensions)); + return true; + })); + parent_->evaluated_[sort] = std::move(result_literal); return Status::OK(); } |