aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Kay Zhu <kayzhu@google.com>2018-02-26 16:24:54 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-26 16:31:22 -0800
commit29bc0d92967d8853c872ba7f736462f1ea2fbd81 (patch)
treefadc11e47141abb1e09203ac21c5e16831d8c885
parent511cf67f2327e9186124a92c9469dc60fd64a6a2 (diff)
[XLA] In HloEvaluator, fix an issue for HandleAbs to handle complex numbers
more correctly: - abs([complex numbers]) would yield floats. However since the specilization for HandleAbs is based on the return type (float), we'd CHECK fail due to float != complex when accessing the elements of the operand (complex). - enable unary_op_test for interpreter. PiperOrigin-RevId: 187099576
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.cc32
-rw-r--r--tensorflow/compiler/xla/tests/BUILD1
2 files changed, 30 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index fd06b19144..cf8b35908f 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -57,6 +57,12 @@ struct is_complex_t : public std::false_type {};
template <>
struct is_complex_t<complex64> : public std::true_type {};
+template <typename T>
+struct is_complex64_t : public std::false_type {};
+
+template <>
+struct is_complex64_t<complex64> : public std::true_type {};
+
template <typename OperandT>
StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
const Literal& lhs_literal,
@@ -248,17 +254,37 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
template <
typename NativeT,
- typename std::enable_if<std::is_signed<NativeT>::value ||
- is_complex_t<NativeT>::value>::type* = nullptr>
+ typename std::enable_if<std::is_signed<NativeT>::value>::type* = nullptr>
Status HandleAbs(HloInstruction* abs) {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs],
- ElementWiseUnaryOp(abs, [](ElementwiseT elem_operand) {
+ ElementWiseUnaryOp(abs, [](NativeT elem_operand) {
return std::abs(elem_operand);
}));
return Status::OK();
}
+ template <
+ typename NativeT,
+ typename std::enable_if<is_complex64_t<NativeT>::value>::type* = nullptr>
+ Status HandleAbs(HloInstruction* abs) {
+ const Literal& operand_literal =
+ parent_->GetEvaluatedLiteralFor(abs->operand(0));
+ TF_ASSIGN_OR_RETURN(
+ parent_->evaluated_[abs],
+ (ElementWiseUnaryOpImpl<float, NativeT>(
+ abs, [](NativeT elem_operand) { return std::abs(elem_operand); },
+ operand_literal)));
+
+ return Status::OK();
+ }
+
Status HandleAbs(HloInstruction* abs) override {
+ // If the operand is of C64 type, the return type of abs will be F32.
+ // However, ElementwiseT would still be the return type, F32, and thus
+ // specifying the ElementwiseT explicitly as C64 is needed below.
+ if (abs->operand(0)->shape().element_type() == C64) {
+ return HandleAbs<complex64>(abs);
+ }
return HandleAbs<ElementwiseT>(abs);
}
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 33fde9737d..f3ecfc1604 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -494,6 +494,7 @@ xla_test(
xla_test(
name = "unary_op_test",
srcs = ["unary_op_test.cc"],
+ tags = ["enable_for_xla_interpreter"],
deps = [
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:computation_builder",