diff options
author | 2018-02-26 16:24:54 -0800 | |
---|---|---|
committer | 2018-02-26 16:31:22 -0800 | |
commit | 29bc0d92967d8853c872ba7f736462f1ea2fbd81 (patch) | |
tree | fadc11e47141abb1e09203ac21c5e16831d8c885 | |
parent | 511cf67f2327e9186124a92c9469dc60fd64a6a2 (diff) |
[XLA] In HloEvaluator, fix an issue for HandleAbs to handle complex numbers
more correctly:
- abs([complex numbers]) would yield floats. However since the specilization for
HandleAbs is based on the return type (float), we'd CHECK fail due to float !=
complex when accessing the elements of the operand (complex).
- enable unary_op_test for interpreter.
PiperOrigin-RevId: 187099576
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_evaluator.cc | 32 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/BUILD | 1 |
2 files changed, 30 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index fd06b19144..cf8b35908f 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -57,6 +57,12 @@ struct is_complex_t : public std::false_type {}; template <> struct is_complex_t<complex64> : public std::true_type {}; +template <typename T> +struct is_complex64_t : public std::false_type {}; + +template <> +struct is_complex64_t<complex64> : public std::true_type {}; + template <typename OperandT> StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode, const Literal& lhs_literal, @@ -248,17 +254,37 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { template < typename NativeT, - typename std::enable_if<std::is_signed<NativeT>::value || - is_complex_t<NativeT>::value>::type* = nullptr> + typename std::enable_if<std::is_signed<NativeT>::value>::type* = nullptr> Status HandleAbs(HloInstruction* abs) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs], - ElementWiseUnaryOp(abs, [](ElementwiseT elem_operand) { + ElementWiseUnaryOp(abs, [](NativeT elem_operand) { return std::abs(elem_operand); })); return Status::OK(); } + template < + typename NativeT, + typename std::enable_if<is_complex64_t<NativeT>::value>::type* = nullptr> + Status HandleAbs(HloInstruction* abs) { + const Literal& operand_literal = + parent_->GetEvaluatedLiteralFor(abs->operand(0)); + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[abs], + (ElementWiseUnaryOpImpl<float, NativeT>( + abs, [](NativeT elem_operand) { return std::abs(elem_operand); }, + operand_literal))); + + return Status::OK(); + } + Status HandleAbs(HloInstruction* abs) override { + // If the operand is of C64 type, the return type of abs will be F32. + // However, ElementwiseT would still be the return type, F32, and thus + // specifying the ElementwiseT explicitly as C64 is needed below. + if (abs->operand(0)->shape().element_type() == C64) { + return HandleAbs<complex64>(abs); + } return HandleAbs<ElementwiseT>(abs); } diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 33fde9737d..f3ecfc1604 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -494,6 +494,7 @@ xla_test( xla_test( name = "unary_op_test", srcs = ["unary_op_test.cc"], + tags = ["enable_for_xla_interpreter"], deps = [ "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation_builder", |