diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h | 286 |
1 files changed, 154 insertions, 132 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index 8b08756c64..d5b4be7e12 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/core/lib/core/casts.h" @@ -34,6 +35,37 @@ using is_complex_t = std::is_same<T, complex64>; template <typename T> using is_complex64_t = std::is_same<T, complex64>; +// It's UB to use std::sort with std::less<float>, because of NaNs. Define +// "safe" less functions which are actually strict weak orders. +template < + typename NativeT, + typename std::enable_if<std::is_integral<NativeT>::value>::type* = nullptr> +bool SafeLess(const NativeT& a, const NativeT& b) { + return a < b; +} + +template <typename NativeT, + typename std::enable_if< + std::is_floating_point<NativeT>::value || + std::is_same<NativeT, bfloat16>::value>::type* = nullptr> +bool SafeLess(const NativeT& a, const NativeT& b) { + if (std::isnan(b)) { + return !std::isnan(a); + } else { + return a < b; + } +} + +template <typename NativeT, typename std::enable_if<std::is_same< + NativeT, Eigen::half>::value>::type* = nullptr> +bool SafeLess(const NativeT& a, const NativeT& b) { + if (Eigen::half_impl::isnan(b)) { + return !Eigen::half_impl::isnan(a); + } else { + return a < b; + } +} + // Templated DfsHloVisitor for use by HloEvaluator. // // Typically ReturnT here indicates the resulting literal type of each evaluated @@ -269,6 +301,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return HandleFloor<ReturnT>(floor); } + Status HandleImag(HloInstruction* imag) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[imag], + ElementWiseUnaryOp(imag, [](ElementwiseT elem_operand) { + return std::imag(elem_operand); + })); + return Status::OK(); + } + Status HandleLog(HloInstruction* log) override { TF_ASSIGN_OR_RETURN(parent_->evaluated_[log], ElementWiseUnaryOp(log, [](ElementwiseT elem_operand) { @@ -572,6 +612,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + Status HandleReal(HloInstruction* real) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[real], + ElementWiseUnaryOp(real, [](ElementwiseT elem_operand) { + return std::real(elem_operand); + })); + return Status::OK(); + } + template < typename NativeT, typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr> @@ -1025,83 +1073,47 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { CHECK_EQ(dnums.lhs_batch_dimensions_size(), dnums.rhs_batch_dimensions_size()); - std::vector<int64> lhs_non_contracting_dims; + DimensionVector lhs_index(lhs_rank); + DimensionVector rhs_index(rhs_rank); + + // result_index_locations[i] contains one or two pointers to the locations + // in lhs_index or rhs_index where the i'th result index should go. + tensorflow::gtl::InlinedVector<std::pair<int64*, int64*>, kInlineRank> + result_index_locations; + result_index_locations.reserve(lhs_rank + rhs_rank - 2); + + // The first components in the output shape are the LHS and RHS batch + // dimensions: + for (int64 i = 0; i < dnums.lhs_batch_dimensions_size(); i++) { + result_index_locations.push_back( + {&lhs_index[dnums.lhs_batch_dimensions(i)], + &rhs_index[dnums.rhs_batch_dimensions(i)]}); + } + + // Then we have the LHS and RHS non-contracting dimensions, if any: for (int64 i = 0; i < lhs_rank; i++) { - if (i != lhs_contracting_dimension) { - lhs_non_contracting_dims.push_back(i); + if (i != lhs_contracting_dimension && + !ArrayContains(AsInt64Slice(dnums.lhs_batch_dimensions()), i)) { + result_index_locations.push_back({&lhs_index[i], nullptr}); } } - - std::vector<int64> rhs_non_batch_non_contracting_dims; - tensorflow::gtl::FlatSet<int64> batch_dims_set( - dnums.rhs_batch_dimensions().begin(), - dnums.rhs_batch_dimensions().end()); for (int64 i = 0; i < rhs_rank; i++) { - if (i != rhs_contracting_dimension && batch_dims_set.count(i) == 0) { - rhs_non_batch_non_contracting_dims.push_back(i); + if (i != rhs_contracting_dimension && + !ArrayContains(AsInt64Slice(dnums.rhs_batch_dimensions()), i)) { + result_index_locations.push_back({&rhs_index[i], nullptr}); } } - const int64 batch_dim_size = dnums.lhs_batch_dimensions_size(); - const int64 lhs_non_contracting_size = lhs_non_contracting_dims.size(); - - DimensionVector lhs_index(lhs_rank); - DimensionVector rhs_index(rhs_rank); auto result = MakeUnique<Literal>(dot->shape()); TF_RETURN_IF_ERROR(result->Populate<ReturnT>( [&](tensorflow::gtl::ArraySlice<int64> result_index) { ElementwiseT result_val = static_cast<ElementwiseT>(0); - // Find the corresponding non-contracting indices for lhs and rhs. - // - // For `result_index`, its batch dimension, if exists, will be at the - // same dimension as the batch dimension of lhs and rhs. More - // specifically: - // - For lhs, the non-contracting dimensions, including the batch - // dimension have the same index as the `result_index`. - // - For rhs, the batch dimension is set seperately from other - // non-contracting dimensions, since these other non-contracting - // dimensions in rhs follow the non-contracting dimensions of lhs in - // the resulting index. - // - // As an example, for a resulting index: - // result_index [result_batch, result_x, result_y] - // the effecting lhs and rhs indices are: - // lhs [result_batch, lhs_non_contracting_dim, contracting_dim - // rhs [result_batch, contracting_dim, rhs_non_contracting_dim] - // `result_x` is only affected by the lhs_non_contracting_dim and - // likewise `result_y` only depends on rhs_non_contracting_dim. - // - // so we can look up the lhs and rhs indices by: - // - // lhs: - // batch index is the same as `result_batch`. - // non-contracting dimension is the same as - // result_index[lhs_non_contracting_dim] - // rhs: - // batch index: the same as `result_batch`. - // non-contracting dimension index: *not* the same as - // result_index[rhs_non_contractng_dim], since the - // non-contracting dimensions of lhs are included in the - // result_index first. Instead, the non_contracting_dim of rhs must - // be calculated as following: - // lhs_non_contracting_dimensions_size + - // (rhs_non_batch_non_contracting_dim - batch_dim_size) - 1 - // - // Note that (rhs_non_batch_contracting_dim - batch_dim_size) is - // the index offset to the result_index that only depends on - // the non_batch and non-contracting dimensions of rhs. -1 at the - // end translates size to index. - for (auto i : lhs_non_contracting_dims) { - lhs_index[i] = result_index[i]; - } - for (auto i : dnums.rhs_batch_dimensions()) { - rhs_index[i] = result_index[i]; - } - for (auto i : rhs_non_batch_non_contracting_dims) { - const int64 rhs_non_batch_non_contracting_dim = - lhs_non_contracting_size + (i - batch_dim_size) - 1; - rhs_index[i] = result_index[rhs_non_batch_non_contracting_dim]; + for (int64 i = 0; i < result_index.size(); i++) { + *result_index_locations[i].first = result_index[i]; + if (result_index_locations[i].second) { + *result_index_locations[i].second = result_index[i]; + } } // Accumulates resulting product along the contracted dimension. @@ -1321,7 +1333,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { parent_->GetEvaluatedLiteralFor(operand); auto curr_val = arg_literal.Get<NativeT>(multi_index); - auto curr_val_literal = Literal::CreateR0<NativeT>(curr_val); + auto curr_val_literal = LiteralUtil::CreateR0<NativeT>(curr_val); arg_literals.push_back(std::move(curr_val_literal)); } @@ -1402,24 +1414,49 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { !is_complex_t<NativeT>::value && !std::is_same<NativeT, bool>::value>::type* = nullptr> Status HandleSort(HloInstruction* sort) { - TF_RET_CHECK(ShapeUtil::Rank(sort->shape()) == 1) - << "Sort is only supported for R1 shapes"; - - auto arg = sort->operand(0); - const Literal& arg_literal = parent_->GetEvaluatedLiteralFor(arg); - VLOG(3) << "HandleSort arg_literal: " << arg_literal.ToString(); - const auto& arg_data = arg_literal.data<ReturnT>(); + auto keys = sort->operand(0); + auto rank = ShapeUtil::Rank(keys->shape()); + TF_RET_CHECK(rank > 0 && rank <= 2) + << "Sort is only supported for R1 and R2 shapes"; + TF_RET_CHECK(sort->operand_count() == 1) + << "Typed visitor does not support key-value sort"; + + const Literal& keys_literal = parent_->GetEvaluatedLiteralFor(keys); + + auto sort_r1 = [this](const Literal& keys_literal) { + VLOG(3) << "HandleSort keys_literal: " << keys_literal.ToString(); + const auto& keys_data = keys_literal.data<ReturnT>(); + + std::vector<ReturnT> result_data(keys_data.begin(), keys_data.end()); + std::sort(result_data.begin(), result_data.end(), + [](const ReturnT& a, const ReturnT& b) { + return SafeLess<ReturnT>(a, b); + }); + auto result_literal = MakeUnique<Literal>(keys_literal.shape()); + result_literal->PopulateR1( + tensorflow::gtl::ArraySlice<ReturnT>(result_data)); + VLOG(3) << "HandleSort result_literal: " << result_literal->ToString(); + return result_literal; + }; - std::vector<ReturnT> return_data(arg_data.begin(), arg_data.end()); - std::sort(return_data.begin(), return_data.end(), - [](const ReturnT& a, const ReturnT& b) { - return SafeLess<ReturnT>(a, b); - }); - auto result_literal = MakeUnique<Literal>(sort->shape()); - result_literal->PopulateR1( - tensorflow::gtl::ArraySlice<ReturnT>(return_data)); - VLOG(3) << "HandleSort result_literal: " << result_literal->ToString(); - parent_->evaluated_[sort] = std::move(result_literal); + if (rank == 1) { + parent_->evaluated_[sort] = std::move(sort_r1(keys_literal)); + } else { + // For R2 sort, the desired semantics are to sort each matrix row + // independently. + auto result_literal = MakeUnique<Literal>(keys_literal.shape()); + int64 r1_length = keys->shape().dimensions(1); + for (int64 row = 0; row < keys->shape().dimensions(0); ++row) { + TF_ASSIGN_OR_RETURN(auto r1_slice, + keys_literal.Slice({row, 0}, {row + 1, r1_length}) + ->Reshape({r1_length})); + auto r1_result = sort_r1(*r1_slice); + TF_ASSIGN_OR_RETURN(r1_result, r1_result->Reshape({1, r1_length})); + TF_RETURN_IF_ERROR(result_literal->CopySliceFrom( + *r1_result, {0, 0}, {row, 0}, {1, r1_length})); + } + parent_->evaluated_[sort] = std::move(result_literal); + } return Status::OK(); } @@ -1507,8 +1544,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { auto curr_val = arg_literal.Get<ReturnT>(input_index); // Evaluate computation with specified literal operands. - auto curr_val_literal = Literal::CreateR0<ReturnT>(curr_val); - auto result_val_literal = Literal::CreateR0<ReturnT>(result_val); + auto curr_val_literal = LiteralUtil::CreateR0<ReturnT>(curr_val); + auto result_val_literal = + LiteralUtil::CreateR0<ReturnT>(result_val); std::unique_ptr<Literal> computed_result = embedded_evaluator @@ -1586,10 +1624,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Used in the dual IterateThroughWindow lambdas below. Hoisted to avoid // dynamic memory allocations. - auto curr_val_literal = Literal::CreateR0<ReturnT>(ReturnT()); - auto selected_val_literal = Literal::CreateR0<ReturnT>(ReturnT()); - auto source_literal_scatter = Literal::CreateR0<ReturnT>(ReturnT()); - auto scattered_literal = Literal::CreateR0<ReturnT>(ReturnT()); + auto curr_val_literal = LiteralUtil::CreateR0<ReturnT>(ReturnT()); + auto selected_val_literal = LiteralUtil::CreateR0<ReturnT>(ReturnT()); + auto source_literal_scatter = LiteralUtil::CreateR0<ReturnT>(ReturnT()); + auto scattered_literal = LiteralUtil::CreateR0<ReturnT>(ReturnT()); do { // For each element in `source`, we place a window in `operand`. For each // window placement, we iterate inside the window twice: @@ -1710,9 +1748,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Evaluate computation with specified literal operands. const auto curr_val_literal = - Literal::CreateR0<ReturnT>(curr_val); + LiteralUtil::CreateR0<ReturnT>(curr_val); const auto result_val_literal = - Literal::CreateR0<ReturnT>(result_val); + LiteralUtil::CreateR0<ReturnT>(result_val); std::unique_ptr<Literal> computed_result = embedded_evaluator .Evaluate<const Literal*>( @@ -1757,7 +1795,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return operand_literal.Get<ReturnT>(operand_index); }; - auto result = Literal::CreateFromDimensions( + auto result = LiteralUtil::CreateFromDimensions( shape.element_type(), AsInt64Slice(shape.dimensions())); TF_RETURN_IF_ERROR(result->Populate<ReturnT>(func)); parent_->evaluated_[slice] = std::move(result); @@ -1959,6 +1997,30 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return HandleReducePrecision<ElementwiseT>(reduce_precision); } + template <typename NativeT, + typename std::enable_if< + std::is_same<NativeT, float>::value || + std::is_same<NativeT, int32>::value || + std::is_same<NativeT, uint32>::value>::type* = nullptr> + Status HandleIota(HloInstruction* iota) { + auto result = MakeUnique<Literal>(iota->shape()); + auto data = result->data<ReturnT>(); + std::iota(data.begin(), data.end(), 0); + parent_->evaluated_[iota] = std::move(result); + return Status::OK(); + } + template <typename NativeT, + typename std::enable_if< + !(std::is_same<NativeT, float>::value || + std::is_same<NativeT, int32>::value || + std::is_same<NativeT, uint32>::value)>::type* = nullptr> + Status HandleIota(HloInstruction* iota) { + return InvalidArgument("Unsupported type for iota"); + } + Status HandleIota(HloInstruction* iota) override { + return HandleIota<ReturnT>(iota); + } + private: // Creates a vector of multipliers which can be used to create a linear index // into shape. @@ -2016,10 +2078,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { start_indices_typed.end()); // Clamp the start indices so the slice is in-bounds w.r.t the operand. - - // TODO(b/74360564): This is implementation defined behavior, but is - // currently respected by all implementations. Change this if we ever decide - // to officially document different behavior. for (int64 i = 0; i < start.size(); ++i) { start[i] = std::min<int64>( std::max(int64{0}, start[i]), @@ -2053,10 +2111,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { start_indices_typed.end()); // Clamp the update start indices so the slice is in-bounds w.r.t the // operand. - - // TODO(b/74360564): This is implementation defined behavior, but is - // currently respected by all implementations. Change this if we ever decide - // to oficially document different behavior. for (int64 i = 0; i < rank; ++i) { start[i] = std::min<int64>( std::max<int64>(0, start[i]), @@ -2175,38 +2229,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return rhs_unsigned >= lhs_size_unsigned; } - // It's UB to use std::sort with std::less<float>, because of NaNs. Define - // "safe" less functions which are actually strict weak orders. - template <typename NativeT, - typename std::enable_if<std::is_integral<NativeT>::value>::type* = - nullptr> - static bool SafeLess(const NativeT& a, const NativeT& b) { - return a < b; - } - - template <typename NativeT, - typename std::enable_if< - std::is_floating_point<NativeT>::value || - std::is_same<NativeT, bfloat16>::value>::type* = nullptr> - static bool SafeLess(const NativeT& a, const NativeT& b) { - if (std::isnan(b)) { - return !std::isnan(a); - } else { - return a < b; - } - } - - template <typename NativeT, - typename std::enable_if< - std::is_same<NativeT, Eigen::half>::value>::type* = nullptr> - static bool SafeLess(const NativeT& a, const NativeT& b) { - if (Eigen::half_impl::isnan(b)) { - return !Eigen::half_impl::isnan(a); - } else { - return a < b; - } - } - HloEvaluator* parent_; }; |