aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h')
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h91
1 files changed, 77 insertions, 14 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index 2ae5f8bf36..f5e477e115 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -301,6 +301,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return HandleFloor<ReturnT>(floor);
}
+ Status HandleImag(HloInstruction* imag) override {
+ TF_ASSIGN_OR_RETURN(parent_->evaluated_[imag],
+ ElementWiseUnaryOp(imag, [](ElementwiseT elem_operand) {
+ return std::imag(elem_operand);
+ }));
+ return Status::OK();
+ }
+
Status HandleLog(HloInstruction* log) override {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[log],
ElementWiseUnaryOp(log, [](ElementwiseT elem_operand) {
@@ -604,6 +612,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return Status::OK();
}
+ Status HandleReal(HloInstruction* real) override {
+ TF_ASSIGN_OR_RETURN(parent_->evaluated_[real],
+ ElementWiseUnaryOp(real, [](ElementwiseT elem_operand) {
+ return std::real(elem_operand);
+ }));
+ return Status::OK();
+ }
+
template <
typename NativeT,
typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
@@ -1399,25 +1415,48 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
!std::is_same<NativeT, bool>::value>::type* = nullptr>
Status HandleSort(HloInstruction* sort) {
auto keys = sort->operand(0);
- TF_RET_CHECK(ShapeUtil::Rank(keys->shape()) == 1)
- << "Sort is only supported for R1 shapes";
+ auto rank = ShapeUtil::Rank(keys->shape());
+ TF_RET_CHECK(rank > 0 && rank <= 2)
+ << "Sort is only supported for R1 and R2 shapes";
TF_RET_CHECK(sort->operand_count() == 1)
<< "Typed visitor does not support key-value sort";
const Literal& keys_literal = parent_->GetEvaluatedLiteralFor(keys);
- VLOG(3) << "HandleSort keys_literal: " << keys_literal.ToString();
- const auto& keys_data = keys_literal.data<ReturnT>();
- std::vector<ReturnT> result_data(keys_data.begin(), keys_data.end());
- std::sort(result_data.begin(), result_data.end(),
- [](const ReturnT& a, const ReturnT& b) {
- return SafeLess<ReturnT>(a, b);
- });
- auto result_literal = MakeUnique<Literal>(sort->shape());
- result_literal->PopulateR1(
- tensorflow::gtl::ArraySlice<ReturnT>(result_data));
- VLOG(3) << "HandleSort result_literal: " << result_literal->ToString();
- parent_->evaluated_[sort] = std::move(result_literal);
+ auto sort_r1 = [this](const Literal& keys_literal) {
+ VLOG(3) << "HandleSort keys_literal: " << keys_literal.ToString();
+ const auto& keys_data = keys_literal.data<ReturnT>();
+
+ std::vector<ReturnT> result_data(keys_data.begin(), keys_data.end());
+ std::sort(result_data.begin(), result_data.end(),
+ [](const ReturnT& a, const ReturnT& b) {
+ return SafeLess<ReturnT>(a, b);
+ });
+ auto result_literal = MakeUnique<Literal>(keys_literal.shape());
+ result_literal->PopulateR1(
+ tensorflow::gtl::ArraySlice<ReturnT>(result_data));
+ VLOG(3) << "HandleSort result_literal: " << result_literal->ToString();
+ return result_literal;
+ };
+
+ if (rank == 1) {
+ parent_->evaluated_[sort] = std::move(sort_r1(keys_literal));
+ } else {
+ // For R2 sort, the desired semantics are to sort each matrix row
+ // independently.
+ auto result_literal = MakeUnique<Literal>(keys_literal.shape());
+ int64 r1_length = keys->shape().dimensions(1);
+ for (int64 row = 0; row < keys->shape().dimensions(0); ++row) {
+ TF_ASSIGN_OR_RETURN(auto r1_slice,
+ keys_literal.Slice({row, 0}, {row + 1, r1_length})
+ ->Reshape({r1_length}));
+ auto r1_result = sort_r1(*r1_slice);
+ TF_ASSIGN_OR_RETURN(r1_result, r1_result->Reshape({1, r1_length}));
+ TF_RETURN_IF_ERROR(result_literal->CopySliceFrom(
+ *r1_result, {0, 0}, {row, 0}, {1, r1_length}));
+ }
+ parent_->evaluated_[sort] = std::move(result_literal);
+ }
return Status::OK();
}
@@ -1958,6 +1997,30 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return HandleReducePrecision<ElementwiseT>(reduce_precision);
}
+ template <typename NativeT,
+ typename std::enable_if<
+ std::is_same<NativeT, float>::value ||
+ std::is_same<NativeT, int32>::value ||
+ std::is_same<NativeT, uint32>::value>::type* = nullptr>
+ Status HandleIota(HloInstruction* iota) {
+ auto result = MakeUnique<Literal>(iota->shape());
+ auto data = result->data<ReturnT>();
+ std::iota(data.begin(), data.end(), 0);
+ parent_->evaluated_[iota] = std::move(result);
+ return Status::OK();
+ }
+ template <typename NativeT,
+ typename std::enable_if<
+ !(std::is_same<NativeT, float>::value ||
+ std::is_same<NativeT, int32>::value ||
+ std::is_same<NativeT, uint32>::value)>::type* = nullptr>
+ Status HandleIota(HloInstruction* iota) {
+ return InvalidArgument("Unsupported type for iota");
+ }
+ Status HandleIota(HloInstruction* iota) override {
+ return HandleIota<ReturnT>(iota);
+ }
+
private:
// Creates a vector of multipliers which can be used to create a linear index
// into shape.