diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h | 91 |
1 files changed, 77 insertions, 14 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index 2ae5f8bf36..f5e477e115 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -301,6 +301,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return HandleFloor<ReturnT>(floor); } + Status HandleImag(HloInstruction* imag) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[imag], + ElementWiseUnaryOp(imag, [](ElementwiseT elem_operand) { + return std::imag(elem_operand); + })); + return Status::OK(); + } + Status HandleLog(HloInstruction* log) override { TF_ASSIGN_OR_RETURN(parent_->evaluated_[log], ElementWiseUnaryOp(log, [](ElementwiseT elem_operand) { @@ -604,6 +612,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + Status HandleReal(HloInstruction* real) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[real], + ElementWiseUnaryOp(real, [](ElementwiseT elem_operand) { + return std::real(elem_operand); + })); + return Status::OK(); + } + template < typename NativeT, typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr> @@ -1399,25 +1415,48 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { !std::is_same<NativeT, bool>::value>::type* = nullptr> Status HandleSort(HloInstruction* sort) { auto keys = sort->operand(0); - TF_RET_CHECK(ShapeUtil::Rank(keys->shape()) == 1) - << "Sort is only supported for R1 shapes"; + auto rank = ShapeUtil::Rank(keys->shape()); + TF_RET_CHECK(rank > 0 && rank <= 2) + << "Sort is only supported for R1 and R2 shapes"; TF_RET_CHECK(sort->operand_count() == 1) << "Typed visitor does not support key-value sort"; const Literal& keys_literal = parent_->GetEvaluatedLiteralFor(keys); - VLOG(3) << "HandleSort keys_literal: " << keys_literal.ToString(); - const auto& keys_data = keys_literal.data<ReturnT>(); - std::vector<ReturnT> result_data(keys_data.begin(), keys_data.end()); - std::sort(result_data.begin(), result_data.end(), - [](const ReturnT& a, const ReturnT& b) { - return SafeLess<ReturnT>(a, b); - }); - auto result_literal = MakeUnique<Literal>(sort->shape()); - result_literal->PopulateR1( - tensorflow::gtl::ArraySlice<ReturnT>(result_data)); - VLOG(3) << "HandleSort result_literal: " << result_literal->ToString(); - parent_->evaluated_[sort] = std::move(result_literal); + auto sort_r1 = [this](const Literal& keys_literal) { + VLOG(3) << "HandleSort keys_literal: " << keys_literal.ToString(); + const auto& keys_data = keys_literal.data<ReturnT>(); + + std::vector<ReturnT> result_data(keys_data.begin(), keys_data.end()); + std::sort(result_data.begin(), result_data.end(), + [](const ReturnT& a, const ReturnT& b) { + return SafeLess<ReturnT>(a, b); + }); + auto result_literal = MakeUnique<Literal>(keys_literal.shape()); + result_literal->PopulateR1( + tensorflow::gtl::ArraySlice<ReturnT>(result_data)); + VLOG(3) << "HandleSort result_literal: " << result_literal->ToString(); + return result_literal; + }; + + if (rank == 1) { + parent_->evaluated_[sort] = std::move(sort_r1(keys_literal)); + } else { + // For R2 sort, the desired semantics are to sort each matrix row + // independently. + auto result_literal = MakeUnique<Literal>(keys_literal.shape()); + int64 r1_length = keys->shape().dimensions(1); + for (int64 row = 0; row < keys->shape().dimensions(0); ++row) { + TF_ASSIGN_OR_RETURN(auto r1_slice, + keys_literal.Slice({row, 0}, {row + 1, r1_length}) + ->Reshape({r1_length})); + auto r1_result = sort_r1(*r1_slice); + TF_ASSIGN_OR_RETURN(r1_result, r1_result->Reshape({1, r1_length})); + TF_RETURN_IF_ERROR(result_literal->CopySliceFrom( + *r1_result, {0, 0}, {row, 0}, {1, r1_length})); + } + parent_->evaluated_[sort] = std::move(result_literal); + } return Status::OK(); } @@ -1958,6 +1997,30 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return HandleReducePrecision<ElementwiseT>(reduce_precision); } + template <typename NativeT, + typename std::enable_if< + std::is_same<NativeT, float>::value || + std::is_same<NativeT, int32>::value || + std::is_same<NativeT, uint32>::value>::type* = nullptr> + Status HandleIota(HloInstruction* iota) { + auto result = MakeUnique<Literal>(iota->shape()); + auto data = result->data<ReturnT>(); + std::iota(data.begin(), data.end(), 0); + parent_->evaluated_[iota] = std::move(result); + return Status::OK(); + } + template <typename NativeT, + typename std::enable_if< + !(std::is_same<NativeT, float>::value || + std::is_same<NativeT, int32>::value || + std::is_same<NativeT, uint32>::value)>::type* = nullptr> + Status HandleIota(HloInstruction* iota) { + return InvalidArgument("Unsupported type for iota"); + } + Status HandleIota(HloInstruction* iota) override { + return HandleIota<ReturnT>(iota); + } + private: // Creates a vector of multipliers which can be used to create a linear index // into shape. |