about summary refs log tree commit diff homepage
path: root/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h')
-rw-r--r-- tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h | 286
1 file changed, 154 insertions(+), 132 deletions(-)
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index 8b08756c64..d5b4be7e12 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_
+#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/core/lib/core/casts.h"
@@ -34,6 +35,37 @@ using is_complex_t = std::is_same<T, complex64>;
template <typename T>
using is_complex64_t = std::is_same<T, complex64>;
+// It's UB to use std::sort with std::less<float>, because of NaNs. Define
+// "safe" less functions which are actually strict weak orders.
+template <
+ typename NativeT,
+ typename std::enable_if<std::is_integral<NativeT>::value>::type* = nullptr>
+bool SafeLess(const NativeT& a, const NativeT& b) {
+ return a < b;
+}
+
+template <typename NativeT,
+ typename std::enable_if<
+ std::is_floating_point<NativeT>::value ||
+ std::is_same<NativeT, bfloat16>::value>::type* = nullptr>
+bool SafeLess(const NativeT& a, const NativeT& b) {
+ if (std::isnan(b)) {
+ return !std::isnan(a);
+ } else {
+ return a < b;
+ }
+}
+
+template <typename NativeT, typename std::enable_if<std::is_same<
+ NativeT, Eigen::half>::value>::type* = nullptr>
+bool SafeLess(const NativeT& a, const NativeT& b) {
+ if (Eigen::half_impl::isnan(b)) {
+ return !Eigen::half_impl::isnan(a);
+ } else {
+ return a < b;
+ }
+}
+
// Templated DfsHloVisitor for use by HloEvaluator.
//
// Typically ReturnT here indicates the resulting literal type of each evaluated
@@ -269,6 +301,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return HandleFloor<ReturnT>(floor);
}
+ Status HandleImag(HloInstruction* imag) override {
+ TF_ASSIGN_OR_RETURN(parent_->evaluated_[imag],
+ ElementWiseUnaryOp(imag, [](ElementwiseT elem_operand) {
+ return std::imag(elem_operand);
+ }));
+ return Status::OK();
+ }
+
Status HandleLog(HloInstruction* log) override {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[log],
ElementWiseUnaryOp(log, [](ElementwiseT elem_operand) {
@@ -572,6 +612,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return Status::OK();
}
+ Status HandleReal(HloInstruction* real) override {
+ TF_ASSIGN_OR_RETURN(parent_->evaluated_[real],
+ ElementWiseUnaryOp(real, [](ElementwiseT elem_operand) {
+ return std::real(elem_operand);
+ }));
+ return Status::OK();
+ }
+
template <
typename NativeT,
typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
@@ -1025,83 +1073,47 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
CHECK_EQ(dnums.lhs_batch_dimensions_size(),
dnums.rhs_batch_dimensions_size());
- std::vector<int64> lhs_non_contracting_dims;
+ DimensionVector lhs_index(lhs_rank);
+ DimensionVector rhs_index(rhs_rank);
+
+ // result_index_locations[i] contains one or two pointers to the locations
+ // in lhs_index or rhs_index where the i'th result index should go.
+ tensorflow::gtl::InlinedVector<std::pair<int64*, int64*>, kInlineRank>
+ result_index_locations;
+ result_index_locations.reserve(lhs_rank + rhs_rank - 2);
+
+ // The first components in the output shape are the LHS and RHS batch
+ // dimensions:
+ for (int64 i = 0; i < dnums.lhs_batch_dimensions_size(); i++) {
+ result_index_locations.push_back(
+ {&lhs_index[dnums.lhs_batch_dimensions(i)],
+ &rhs_index[dnums.rhs_batch_dimensions(i)]});
+ }
+
+ // Then we have the LHS and RHS non-contracting dimensions, if any:
for (int64 i = 0; i < lhs_rank; i++) {
- if (i != lhs_contracting_dimension) {
- lhs_non_contracting_dims.push_back(i);
+ if (i != lhs_contracting_dimension &&
+ !ArrayContains(AsInt64Slice(dnums.lhs_batch_dimensions()), i)) {
+ result_index_locations.push_back({&lhs_index[i], nullptr});
}
}
-
- std::vector<int64> rhs_non_batch_non_contracting_dims;
- tensorflow::gtl::FlatSet<int64> batch_dims_set(
- dnums.rhs_batch_dimensions().begin(),
- dnums.rhs_batch_dimensions().end());
for (int64 i = 0; i < rhs_rank; i++) {
- if (i != rhs_contracting_dimension && batch_dims_set.count(i) == 0) {
- rhs_non_batch_non_contracting_dims.push_back(i);
+ if (i != rhs_contracting_dimension &&
+ !ArrayContains(AsInt64Slice(dnums.rhs_batch_dimensions()), i)) {
+ result_index_locations.push_back({&rhs_index[i], nullptr});
}
}
- const int64 batch_dim_size = dnums.lhs_batch_dimensions_size();
- const int64 lhs_non_contracting_size = lhs_non_contracting_dims.size();
-
- DimensionVector lhs_index(lhs_rank);
- DimensionVector rhs_index(rhs_rank);
auto result = MakeUnique<Literal>(dot->shape());
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
[&](tensorflow::gtl::ArraySlice<int64> result_index) {
ElementwiseT result_val = static_cast<ElementwiseT>(0);
- // Find the corresponding non-contracting indices for lhs and rhs.
- //
- // For `result_index`, its batch dimension, if exists, will be at the
- // same dimension as the batch dimension of lhs and rhs. More
- // specifically:
- // - For lhs, the non-contracting dimensions, including the batch
- // dimension have the same index as the `result_index`.
- // - For rhs, the batch dimension is set seperately from other
- // non-contracting dimensions, since these other non-contracting
- // dimensions in rhs follow the non-contracting dimensions of lhs in
- // the resulting index.
- //
- // As an example, for a resulting index:
- // result_index [result_batch, result_x, result_y]
- // the effecting lhs and rhs indices are:
- // lhs [result_batch, lhs_non_contracting_dim, contracting_dim
- // rhs [result_batch, contracting_dim, rhs_non_contracting_dim]
- // `result_x` is only affected by the lhs_non_contracting_dim and
- // likewise `result_y` only depends on rhs_non_contracting_dim.
- //
- // so we can look up the lhs and rhs indices by:
- //
- // lhs:
- // batch index is the same as `result_batch`.
- // non-contracting dimension is the same as
- // result_index[lhs_non_contracting_dim]
- // rhs:
- // batch index: the same as `result_batch`.
- // non-contracting dimension index: *not* the same as
- // result_index[rhs_non_contractng_dim], since the
- // non-contracting dimensions of lhs are included in the
- // result_index first. Instead, the non_contracting_dim of rhs must
- // be calculated as following:
- // lhs_non_contracting_dimensions_size +
- // (rhs_non_batch_non_contracting_dim - batch_dim_size) - 1
- //
- // Note that (rhs_non_batch_contracting_dim - batch_dim_size) is
- // the index offset to the result_index that only depends on
- // the non_batch and non-contracting dimensions of rhs. -1 at the
- // end translates size to index.
- for (auto i : lhs_non_contracting_dims) {
- lhs_index[i] = result_index[i];
- }
- for (auto i : dnums.rhs_batch_dimensions()) {
- rhs_index[i] = result_index[i];
- }
- for (auto i : rhs_non_batch_non_contracting_dims) {
- const int64 rhs_non_batch_non_contracting_dim =
- lhs_non_contracting_size + (i - batch_dim_size) - 1;
- rhs_index[i] = result_index[rhs_non_batch_non_contracting_dim];
+ for (int64 i = 0; i < result_index.size(); i++) {
+ *result_index_locations[i].first = result_index[i];
+ if (result_index_locations[i].second) {
+ *result_index_locations[i].second = result_index[i];
+ }
}
// Accumulates resulting product along the contracted dimension.
@@ -1321,7 +1333,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
parent_->GetEvaluatedLiteralFor(operand);
auto curr_val = arg_literal.Get<NativeT>(multi_index);
- auto curr_val_literal = Literal::CreateR0<NativeT>(curr_val);
+ auto curr_val_literal = LiteralUtil::CreateR0<NativeT>(curr_val);
arg_literals.push_back(std::move(curr_val_literal));
}
@@ -1402,24 +1414,49 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
!is_complex_t<NativeT>::value &&
!std::is_same<NativeT, bool>::value>::type* = nullptr>
Status HandleSort(HloInstruction* sort) {
- TF_RET_CHECK(ShapeUtil::Rank(sort->shape()) == 1)
- << "Sort is only supported for R1 shapes";
-
- auto arg = sort->operand(0);
- const Literal& arg_literal = parent_->GetEvaluatedLiteralFor(arg);
- VLOG(3) << "HandleSort arg_literal: " << arg_literal.ToString();
- const auto& arg_data = arg_literal.data<ReturnT>();
+ auto keys = sort->operand(0);
+ auto rank = ShapeUtil::Rank(keys->shape());
+ TF_RET_CHECK(rank > 0 && rank <= 2)
+ << "Sort is only supported for R1 and R2 shapes";
+ TF_RET_CHECK(sort->operand_count() == 1)
+ << "Typed visitor does not support key-value sort";
+
+ const Literal& keys_literal = parent_->GetEvaluatedLiteralFor(keys);
+
+ auto sort_r1 = [this](const Literal& keys_literal) {
+ VLOG(3) << "HandleSort keys_literal: " << keys_literal.ToString();
+ const auto& keys_data = keys_literal.data<ReturnT>();
+
+ std::vector<ReturnT> result_data(keys_data.begin(), keys_data.end());
+ std::sort(result_data.begin(), result_data.end(),
+ [](const ReturnT& a, const ReturnT& b) {
+ return SafeLess<ReturnT>(a, b);
+ });
+ auto result_literal = MakeUnique<Literal>(keys_literal.shape());
+ result_literal->PopulateR1(
+ tensorflow::gtl::ArraySlice<ReturnT>(result_data));
+ VLOG(3) << "HandleSort result_literal: " << result_literal->ToString();
+ return result_literal;
+ };
- std::vector<ReturnT> return_data(arg_data.begin(), arg_data.end());
- std::sort(return_data.begin(), return_data.end(),
- [](const ReturnT& a, const ReturnT& b) {
- return SafeLess<ReturnT>(a, b);
- });
- auto result_literal = MakeUnique<Literal>(sort->shape());
- result_literal->PopulateR1(
- tensorflow::gtl::ArraySlice<ReturnT>(return_data));
- VLOG(3) << "HandleSort result_literal: " << result_literal->ToString();
- parent_->evaluated_[sort] = std::move(result_literal);
+ if (rank == 1) {
+ parent_->evaluated_[sort] = std::move(sort_r1(keys_literal));
+ } else {
+ // For R2 sort, the desired semantics are to sort each matrix row
+ // independently.
+ auto result_literal = MakeUnique<Literal>(keys_literal.shape());
+ int64 r1_length = keys->shape().dimensions(1);
+ for (int64 row = 0; row < keys->shape().dimensions(0); ++row) {
+ TF_ASSIGN_OR_RETURN(auto r1_slice,
+ keys_literal.Slice({row, 0}, {row + 1, r1_length})
+ ->Reshape({r1_length}));
+ auto r1_result = sort_r1(*r1_slice);
+ TF_ASSIGN_OR_RETURN(r1_result, r1_result->Reshape({1, r1_length}));
+ TF_RETURN_IF_ERROR(result_literal->CopySliceFrom(
+ *r1_result, {0, 0}, {row, 0}, {1, r1_length}));
+ }
+ parent_->evaluated_[sort] = std::move(result_literal);
+ }
return Status::OK();
}
@@ -1507,8 +1544,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
auto curr_val = arg_literal.Get<ReturnT>(input_index);
// Evaluate computation with specified literal operands.
- auto curr_val_literal = Literal::CreateR0<ReturnT>(curr_val);
- auto result_val_literal = Literal::CreateR0<ReturnT>(result_val);
+ auto curr_val_literal = LiteralUtil::CreateR0<ReturnT>(curr_val);
+ auto result_val_literal =
+ LiteralUtil::CreateR0<ReturnT>(result_val);
std::unique_ptr<Literal> computed_result =
embedded_evaluator
@@ -1586,10 +1624,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// Used in the dual IterateThroughWindow lambdas below. Hoisted to avoid
// dynamic memory allocations.
- auto curr_val_literal = Literal::CreateR0<ReturnT>(ReturnT());
- auto selected_val_literal = Literal::CreateR0<ReturnT>(ReturnT());
- auto source_literal_scatter = Literal::CreateR0<ReturnT>(ReturnT());
- auto scattered_literal = Literal::CreateR0<ReturnT>(ReturnT());
+ auto curr_val_literal = LiteralUtil::CreateR0<ReturnT>(ReturnT());
+ auto selected_val_literal = LiteralUtil::CreateR0<ReturnT>(ReturnT());
+ auto source_literal_scatter = LiteralUtil::CreateR0<ReturnT>(ReturnT());
+ auto scattered_literal = LiteralUtil::CreateR0<ReturnT>(ReturnT());
do {
// For each element in `source`, we place a window in `operand`. For each
// window placement, we iterate inside the window twice:
@@ -1710,9 +1748,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// Evaluate computation with specified literal operands.
const auto curr_val_literal =
- Literal::CreateR0<ReturnT>(curr_val);
+ LiteralUtil::CreateR0<ReturnT>(curr_val);
const auto result_val_literal =
- Literal::CreateR0<ReturnT>(result_val);
+ LiteralUtil::CreateR0<ReturnT>(result_val);
std::unique_ptr<Literal> computed_result =
embedded_evaluator
.Evaluate<const Literal*>(
@@ -1757,7 +1795,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return operand_literal.Get<ReturnT>(operand_index);
};
- auto result = Literal::CreateFromDimensions(
+ auto result = LiteralUtil::CreateFromDimensions(
shape.element_type(), AsInt64Slice(shape.dimensions()));
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(func));
parent_->evaluated_[slice] = std::move(result);
@@ -1959,6 +1997,30 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return HandleReducePrecision<ElementwiseT>(reduce_precision);
}
+ template <typename NativeT,
+ typename std::enable_if<
+ std::is_same<NativeT, float>::value ||
+ std::is_same<NativeT, int32>::value ||
+ std::is_same<NativeT, uint32>::value>::type* = nullptr>
+ Status HandleIota(HloInstruction* iota) {
+ auto result = MakeUnique<Literal>(iota->shape());
+ auto data = result->data<ReturnT>();
+ std::iota(data.begin(), data.end(), 0);
+ parent_->evaluated_[iota] = std::move(result);
+ return Status::OK();
+ }
+ template <typename NativeT,
+ typename std::enable_if<
+ !(std::is_same<NativeT, float>::value ||
+ std::is_same<NativeT, int32>::value ||
+ std::is_same<NativeT, uint32>::value)>::type* = nullptr>
+ Status HandleIota(HloInstruction* iota) {
+ return InvalidArgument("Unsupported type for iota");
+ }
+ Status HandleIota(HloInstruction* iota) override {
+ return HandleIota<ReturnT>(iota);
+ }
+
private:
// Creates a vector of multipliers which can be used to create a linear index
// into shape.
@@ -2016,10 +2078,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
start_indices_typed.end());
// Clamp the start indices so the slice is in-bounds w.r.t the operand.
-
- // TODO(b/74360564): This is implementation defined behavior, but is
- // currently respected by all implementations. Change this if we ever decide
- // to officially document different behavior.
for (int64 i = 0; i < start.size(); ++i) {
start[i] = std::min<int64>(
std::max(int64{0}, start[i]),
@@ -2053,10 +2111,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
start_indices_typed.end());
// Clamp the update start indices so the slice is in-bounds w.r.t the
// operand.
-
- // TODO(b/74360564): This is implementation defined behavior, but is
- // currently respected by all implementations. Change this if we ever decide
- // to oficially document different behavior.
for (int64 i = 0; i < rank; ++i) {
start[i] = std::min<int64>(
std::max<int64>(0, start[i]),
@@ -2175,38 +2229,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return rhs_unsigned >= lhs_size_unsigned;
}
- // It's UB to use std::sort with std::less<float>, because of NaNs. Define
- // "safe" less functions which are actually strict weak orders.
- template <typename NativeT,
- typename std::enable_if<std::is_integral<NativeT>::value>::type* =
- nullptr>
- static bool SafeLess(const NativeT& a, const NativeT& b) {
- return a < b;
- }
-
- template <typename NativeT,
- typename std::enable_if<
- std::is_floating_point<NativeT>::value ||
- std::is_same<NativeT, bfloat16>::value>::type* = nullptr>
- static bool SafeLess(const NativeT& a, const NativeT& b) {
- if (std::isnan(b)) {
- return !std::isnan(a);
- } else {
- return a < b;
- }
- }
-
- template <typename NativeT,
- typename std::enable_if<
- std::is_same<NativeT, Eigen::half>::value>::type* = nullptr>
- static bool SafeLess(const NativeT& a, const NativeT& b) {
- if (Eigen::half_impl::isnan(b)) {
- return !Eigen::half_impl::isnan(a);
- } else {
- return a < b;
- }
- }
-
HloEvaluator* parent_;
};