aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_evaluator.cc
diff options
context:
space:
mode:
authorGravatar Kay Zhu <kayzhu@google.com>2018-09-25 20:35:05 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-25 20:39:41 -0700
commit6666516f390f125ed70ddbd4e6f89b83d953c408 (patch)
treef63ebb3fd2ce283c8fa13e9a29f38e2a94b367dd /tensorflow/compiler/xla/service/hlo_evaluator.cc
parent7f1d70d97f543d69a9f02cd6df0964f22f9278f3 (diff)
[XLA] In HloEvaluator, fix an issue where the return type and native type are assumed to be the same for HandleImag and HandleReal, when in fact they should be float and complex64 (or float for HandleReal's case), respectively.
PiperOrigin-RevId: 214548051
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_evaluator.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.cc55
1 files changed, 55 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index b91b2406e2..d7c39b2778 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -496,6 +496,61 @@ Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite) {
return Status::OK();
}
+Status HloEvaluator::HandleReal(HloInstruction* real) {
+ auto operand = real->operand(0);
+ switch (operand->shape().element_type()) {
+ case BF16: {
+ auto result_or = ElementWiseUnaryOpImpl<bfloat16, bfloat16>(
+ real, [](bfloat16 elem_operand) { return elem_operand; },
+ GetEvaluatedLiteralFor(operand));
+ TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
+ break;
+ }
+ case C64: {
+ auto result_or = ElementWiseUnaryOpImpl<float, complex64>(
+ real, [](complex64 elem_operand) { return std::real(elem_operand); },
+ GetEvaluatedLiteralFor(operand));
+ TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
+ break;
+ }
+ case F16: {
+ auto result_or = ElementWiseUnaryOpImpl<Eigen::half, Eigen::half>(
+ real, [](Eigen::half elem_operand) { return elem_operand; },
+ GetEvaluatedLiteralFor(operand));
+ TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
+ break;
+ }
+ case F32: {
+ auto result_or = ElementWiseUnaryOpImpl<float, float>(
+ real, [](float elem_operand) { return elem_operand; },
+ GetEvaluatedLiteralFor(operand));
+ TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
+ break;
+ }
+ case F64: {
+ auto result_or = ElementWiseUnaryOpImpl<double, double>(
+ real, [](double elem_operand) { return elem_operand; },
+ GetEvaluatedLiteralFor(operand));
+ TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
+ break;
+ }
+ default:
+ LOG(FATAL) << "HandleReal: unknown/unhandled primitive type: "
+ << PrimitiveType_Name(operand->shape().element_type());
+ }
+
+ return Status::OK();
+}
+
+Status HloEvaluator::HandleImag(HloInstruction* imag) {
+ auto result_or = ElementWiseUnaryOpImpl<float, complex64>(
+ imag, [](complex64 elem_operand) { return std::imag(elem_operand); },
+ GetEvaluatedLiteralFor(imag->operand(0)));
+
+ TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or));
+ return Status::OK();
+}
+
Status HloEvaluator::HandleCompare(HloInstruction* compare) {
HloOpcode opcode = compare->opcode();
auto lhs = compare->operand(0);