about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
diff options
context:
space:
mode:
author: Adrian Kuegel <akuegel@google.com>, 2018-09-24 03:19:11 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org>, 2018-09-24 03:23:55 -0700
commit: 379ca4afe9e31f550cd04451af04150b6bbecf78 (patch)
tree: fa3a9097489fd510d70aac020c02b4b8f16ea916 /tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
parent: b57bdf414edb27b82a95c5f4e2729fafd4cf2dc7 (diff)
Generalize sort implementation in the HloEvaluator.
Previously the evaluator's sort only worked for operands of rank 1 or 2, and only when the dimension to sort was the most minor dimension. Also fix the SafeLess function so that the SortExtremeValues() test passes. PiperOrigin-RevId: 214239560
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h')
-rw-r--r-- tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h | 123
1 files changed, 71 insertions, 52 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index 8fb17a0033..35391ecf8a 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -16,6 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_
+#include <cmath>
+
#include "absl/algorithm/container.h"
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
@@ -41,7 +43,9 @@ template <typename T>
using is_complex64_t = std::is_same<T, complex64>;
// It's UB to use std::sort with std::less<float>, because of NaNs. Define
-// "safe" less functions which are actually strict weak orders.
+// "safe" less functions which are actually strict weak orders. -NaN and NaN
+// should appear at the beginning and end of the ordering, and -0.0 should
+// appear before 0.0.
template <
typename NativeT,
typename std::enable_if<std::is_integral<NativeT>::value>::type* = nullptr>
@@ -49,26 +53,33 @@ bool SafeLess(const NativeT& a, const NativeT& b) {
return a < b;  // Integral overload: there is no NaN to special-case, so operator< is used directly.
}
-template <typename NativeT,
- typename std::enable_if<
- std::is_floating_point<NativeT>::value ||
- std::is_same<NativeT, bfloat16>::value>::type* = nullptr>
+template <typename NativeT, typename std::enable_if<std::is_floating_point<
+ NativeT>::value>::type* = nullptr>
bool SafeLess(const NativeT& a, const NativeT& b) {  // Floating-point overload: returns whether a orders strictly before b.
- if (std::isnan(b)) {
- return !std::isnan(a);
- } else {
- return a < b;
+ bool lhs_is_negative = std::signbit(a);  // signbit is also set for -0.0 and for NaNs carrying a sign bit.
+ bool rhs_is_negative = std::signbit(b);
+ // If the signs differ, the operand with the sign bit set orders first (this is what puts -0.0 before 0.0).
+ if (lhs_is_negative != rhs_is_negative) {
+ return lhs_is_negative && !rhs_is_negative;
+ }
+ bool lhs_nan = std::isnan(a);
+ bool rhs_nan = std::isnan(b);
+ // Exactly one operand is NaN? (Both operands have the same sign at this point.)
+ if (lhs_nan != rhs_nan) {
+ if (lhs_nan) {
+ return lhs_is_negative;  // a is NaN: -NaN orders before the other negative value, +NaN never orders first.
+ }
+ return !rhs_is_negative;  // b is NaN: a orders before +NaN, but not before -NaN.
}
+ return a < b;  // Same sign and either both NaN (comparison yields false) or neither NaN.
}
-template <typename NativeT, typename std::enable_if<std::is_same<
- NativeT, Eigen::half>::value>::type* = nullptr>
+template <typename NativeT,
+ typename std::enable_if<
+ std::is_same<NativeT, bfloat16>::value ||
+ std::is_same<NativeT, Eigen::half>::value>::type* = nullptr>
bool SafeLess(const NativeT& a, const NativeT& b) {  // Overload selected for bfloat16 and Eigen::half.
- if (Eigen::half_impl::isnan(b)) {
- return !Eigen::half_impl::isnan(a);
- } else {
- return a < b;
- }
+ return SafeLess(static_cast<float>(a), static_cast<float>(b));  // Convert both operands to float and reuse the floating-point overload.
}
// Templated DfsHloVisitor for use by HloEvaluator.
@@ -1527,47 +1538,55 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
!std::is_same<NativeT, bool>::value>::type* = nullptr>
Status HandleSort(HloInstruction* sort) {  // Sorts operand 0 along dimension sort->dimensions(0) and stores the result in parent_->evaluated_[sort].
auto keys = sort->operand(0);
- auto rank = ShapeUtil::Rank(keys->shape());
- TF_RET_CHECK(rank > 0 && rank <= 2)
- << "Sort is only supported for R1 and R2 shapes";
TF_RET_CHECK(sort->operand_count() == 1)
<< "Typed visitor does not support key-value sort";
const Literal& keys_literal = parent_->GetEvaluatedLiteralFor(keys);
-
- auto sort_r1 = [this](const Literal& keys_literal) {
- VLOG(3) << "HandleSort keys_literal: " << keys_literal.ToString();
- const auto& keys_data = keys_literal.data<ReturnT>();
-
- std::vector<ReturnT> result_data(keys_data.begin(), keys_data.end());
- std::sort(result_data.begin(), result_data.end(),
- [](const ReturnT& a, const ReturnT& b) {
- return SafeLess<ReturnT>(a, b);
- });
- Literal result_literal(keys_literal.shape());
- result_literal.PopulateR1(absl::Span<const ReturnT>(result_data));
- VLOG(3) << "HandleSort result_literal: " << result_literal.ToString();
- return result_literal;
- };
-
- if (rank == 1) {
- parent_->evaluated_[sort] = std::move(sort_r1(keys_literal));
- } else {
- // For R2 sort, the desired semantics are to sort each matrix row
- // independently.
- Literal result_literal(keys_literal.shape());
- int64 r1_length = keys->shape().dimensions(1);
- for (int64 row = 0; row < keys->shape().dimensions(0); ++row) {
- TF_ASSIGN_OR_RETURN(auto r1_slice,
- keys_literal.Slice({row, 0}, {row + 1, r1_length})
- .Reshape({r1_length}));
- auto r1_result = sort_r1(r1_slice);
- TF_ASSIGN_OR_RETURN(r1_result, r1_result.Reshape({1, r1_length}));
- TF_RETURN_IF_ERROR(result_literal.CopySliceFrom(
- r1_result, {0, 0}, {row, 0}, {1, r1_length}));
- }
- parent_->evaluated_[sort] = std::move(result_literal);
+ int64 sort_dim = sort->dimensions(0);  // Dimension along which the keys are sorted.
+ int64 sort_dim_elements = keys->shape().dimensions(sort_dim);  // NOTE(review): read before the rank == 0 check below — confirm this is safe for scalar operands.
+ int64 rank = ShapeUtil::Rank(keys->shape());
+ if (rank == 0) {
+ // A rank-0 operand has nothing to sort; the result is a copy of the input.
+ parent_->evaluated_[sort] = keys_literal.Clone();
+ return Status::OK();
}
+ Literal result_literal(keys_literal.shape());
+ std::vector<int64> zero_base(rank, 0);
+ std::vector<int64> increment(rank, 1);
+ increment[sort_dim] = sort_dim_elements;  // Presumably makes the iteration visit only index 0 along sort_dim. NOTE(review): this is 0 when the sort dimension is empty — confirm ForEachIndexWithStatus handles that.
+ // Iterate through each dimension except 'sort_dim'.
+ TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
+ keys->shape(), zero_base, AsInt64Slice(keys->shape().dimensions()),
+ increment, [&](absl::Span<const int64> indices) -> StatusOr<bool> {
+ // Extract a slice from the literal that corresponds to exactly the
+ // row in dimension 'sort_dim'.
+ std::vector<int64> limit_indices(indices.begin(), indices.end());
+ std::for_each(limit_indices.begin(), limit_indices.end(),
+ [](int64& index) { ++index; });  // Slice limit is start + 1 in every dimension...
+ limit_indices[sort_dim] = sort_dim_elements;  // ...except sort_dim, where the slice covers the full extent.
+ TF_ASSIGN_OR_RETURN(auto row_to_sort,
+ keys_literal.Slice(indices, limit_indices)
+ .Reshape({sort_dim_elements}));
+ const auto& row_data = row_to_sort.data<NativeT>();
+
+ std::vector<NativeT> result_data(row_data.begin(), row_data.end());  // Copy the row into a vector that std::sort can reorder.
+ std::sort(result_data.begin(), result_data.end(),
+ [](const NativeT& a, const NativeT& b) {
+ return SafeLess<NativeT>(a, b);  // SafeLess rather than operator<, which is not a strict weak order when NaNs are present.
+ });
+ Literal sorted_row(ShapeUtil::MakeShape(keys->shape().element_type(),
+ {sort_dim_elements}));
+ sorted_row.PopulateR1(absl::Span<const NativeT>(result_data));
+ std::vector<int64> slice_dimensions(rank, 1);  // Shape of one row inside the result: extent 1 everywhere...
+ slice_dimensions[sort_dim] = sort_dim_elements;  // ...except along sort_dim.
+ TF_ASSIGN_OR_RETURN(auto sorted_row_reshaped,
+ sorted_row.Reshape(slice_dimensions));
+ std::vector<int64> start_indices(rank, 0);  // All-zero start indices, passed alongside `indices` to CopySliceFrom below.
+ TF_RETURN_IF_ERROR(result_literal.CopySliceFrom(
+ sorted_row_reshaped, start_indices, indices, slice_dimensions));
+ return true;  // Presumably tells ForEachIndexWithStatus to keep iterating — TODO confirm.
+ }));
+ parent_->evaluated_[sort] = std::move(result_literal);
return Status::OK();
}