diff options
author | Kay Zhu <kayzhu@google.com> | 2018-09-25 20:35:05 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-25 20:39:41 -0700 |
commit | 6666516f390f125ed70ddbd4e6f89b83d953c408 (patch) | |
tree | f63ebb3fd2ce283c8fa13e9a29f38e2a94b367dd /tensorflow/compiler/xla/service/hlo_evaluator.cc | |
parent | 7f1d70d97f543d69a9f02cd6df0964f22f9278f3 (diff) |
[XLA] In HloEvaluator, fix an issue where the return type and native type are assumed to be the same for HandleImag and HandleReal, when in fact they should be float and complex64 (or float for HandleReal's case), respectively.
PiperOrigin-RevId: 214548051
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_evaluator.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_evaluator.cc | 55 |
1 files changed, 55 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index b91b2406e2..d7c39b2778 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -496,6 +496,61 @@ Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite) { return Status::OK(); } +Status HloEvaluator::HandleReal(HloInstruction* real) { + auto operand = real->operand(0); + switch (operand->shape().element_type()) { + case BF16: { + auto result_or = ElementWiseUnaryOpImpl<bfloat16, bfloat16>( + real, [](bfloat16 elem_operand) { return elem_operand; }, + GetEvaluatedLiteralFor(operand)); + TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or)); + break; + } + case C64: { + auto result_or = ElementWiseUnaryOpImpl<float, complex64>( + real, [](complex64 elem_operand) { return std::real(elem_operand); }, + GetEvaluatedLiteralFor(operand)); + TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or)); + break; + } + case F16: { + auto result_or = ElementWiseUnaryOpImpl<Eigen::half, Eigen::half>( + real, [](Eigen::half elem_operand) { return elem_operand; }, + GetEvaluatedLiteralFor(operand)); + TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or)); + break; + } + case F32: { + auto result_or = ElementWiseUnaryOpImpl<float, float>( + real, [](float elem_operand) { return elem_operand; }, + GetEvaluatedLiteralFor(operand)); + TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or)); + break; + } + case F64: { + auto result_or = ElementWiseUnaryOpImpl<double, double>( + real, [](double elem_operand) { return elem_operand; }, + GetEvaluatedLiteralFor(operand)); + TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or)); + break; + } + default: + LOG(FATAL) << "HandleReal: unknown/unhandled primitive type: " + << PrimitiveType_Name(operand->shape().element_type()); + } + + return Status::OK(); +} + +Status HloEvaluator::HandleImag(HloInstruction* imag) { + auto result_or = ElementWiseUnaryOpImpl<float, complex64>( + imag, [](complex64 elem_operand) { return std::imag(elem_operand); }, + GetEvaluatedLiteralFor(imag->operand(0))); + + TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or)); + return Status::OK(); +} + Status HloEvaluator::HandleCompare(HloInstruction* compare) { HloOpcode opcode = compare->opcode(); auto lhs = compare->operand(0); |