diff options
author | Peter Hawkins <phawkins@google.com> | 2017-02-24 11:50:30 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-02-24 12:08:47 -0800 |
commit | 8120e2a270c28e0a62b9f522164b196a90f113b7 (patch) | |
tree | 3d95683920536d58e00d5e22b7118900d49cf56a | |
parent | 11ba79eee215f7f5831d60101df24770acea3b5f (diff) |
[XLA] Add an IsFinite operation that tests elementwise whether values are finite (i.e., not NaN or Inf).
Change: 148485205
14 files changed, 106 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc index 73f450e1f2..c4c91b7ea8 100644 --- a/tensorflow/compiler/xla/client/computation_builder.cc +++ b/tensorflow/compiler/xla/client/computation_builder.cc @@ -950,6 +950,11 @@ ComputationDataHandle ComputationBuilder::Tanh( return UnaryOp(UNOP_TANH, operand); } +ComputationDataHandle ComputationBuilder::IsFinite( + const ComputationDataHandle& operand) { + return UnaryOp(UNOP_IS_FINITE, operand); +} + ComputationDataHandle ComputationBuilder::Transpose( const ComputationDataHandle& operand, tensorflow::gtl::ArraySlice<int64> permutation) { diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h index 67ca9c6cf7..b1a68e3687 100644 --- a/tensorflow/compiler/xla/client/computation_builder.h +++ b/tensorflow/compiler/xla/client/computation_builder.h @@ -507,6 +507,12 @@ class ComputationBuilder { ComputationDataHandle Pow(const ComputationDataHandle& lhs, const ComputationDataHandle& rhs); + // Enqueues an operator that tests if the operand's values are finite, i.e., + // not Inf or NaN. Defined only for floating-point types. Returns an array of + // booleans with the same shape where entries are true iff the corresponding + // entry was NaN. + ComputationDataHandle IsFinite(const ComputationDataHandle& operand); + // Enqueues a convert instruction onto the computation that changes the // element type of the operand array to primitive_type. ComputationDataHandle ConvertElementType(const ComputationDataHandle& operand, diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index f9c9bbe2cd..351efa82dd 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -150,6 +150,10 @@ class DfsHloVisitor { virtual Status HandleTanh(HloInstruction* tanh, HloInstruction* operand) { return HandleElementwiseUnary(tanh, HloOpcode::kTanh, operand); } + virtual Status HandleIsFinite(HloInstruction* is_finite, + HloInstruction* operand) { + return HandleElementwiseUnary(is_finite, HloOpcode::kIsFinite, operand); + } virtual Status HandleLogicalAnd(HloInstruction* logical_and, HloInstruction* lhs, HloInstruction* rhs) { return HandleElementwiseBinary(logical_and, HloOpcode::kLogicalAnd, lhs, diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index a4b50836d7..96342451fd 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -195,6 +195,19 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp( ir_builder_->CreateSelect(olt, llvm::ConstantFP::get(type, -1.0), llvm::ConstantFP::get(type, 1.0))); } + case HloOpcode::kIsFinite: { + // (x == x) && abs(x) != inf + auto type = operand_value->getType(); + auto equal_self = + ir_builder_->CreateFCmpOEQ(operand_value, operand_value); + auto abs_value = llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::fabs, {operand_value}, {type}, ir_builder_); + auto infinity = llvm::ConstantFP::getInfinity(type); + auto not_infinite = ir_builder_->CreateFCmpONE(abs_value, infinity); + auto result_i1 = ir_builder_->CreateAnd(equal_self, not_infinite); + return ir_builder_->CreateZExt( + result_i1, llvm_ir::PrimitiveTypeToIrType(PRED, ir_builder_)); + } case HloOpcode::kNegate: return ir_builder_->CreateFNeg(operand_value); default: @@ -632,6 +645,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kCopy: case HloOpcode::kExp: case HloOpcode::kFloor: + case HloOpcode::kIsFinite: case HloOpcode::kLog: case HloOpcode::kNegate: case HloOpcode::kSign: diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 631e784755..d7d8722ccc 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -202,6 +202,7 @@ string InstructionSequenceGraph( case HloOpcode::kGe: case HloOpcode::kGt: case HloOpcode::kIndex: + case HloOpcode::kIsFinite: case HloOpcode::kLe: case HloOpcode::kLog: case HloOpcode::kLogicalAnd: diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 1d31027cff..9b49ac1a60 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -117,6 +117,7 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, case HloOpcode::kCopy: case HloOpcode::kExp: case HloOpcode::kFloor: + case HloOpcode::kIsFinite: case HloOpcode::kLog: case HloOpcode::kLogicalNot: case HloOpcode::kNegate: @@ -733,6 +734,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( case HloOpcode::kCeil: case HloOpcode::kCopy: case HloOpcode::kExp: + case HloOpcode::kIsFinite: case HloOpcode::kFloor: case HloOpcode::kLog: case HloOpcode::kLogicalNot: @@ -1033,6 +1035,7 @@ bool HloInstruction::Identical( case HloOpcode::kFloor: case HloOpcode::kGe: case HloOpcode::kGt: + case HloOpcode::kIsFinite: case HloOpcode::kLe: case HloOpcode::kLog: case HloOpcode::kLogicalAnd: @@ -1673,6 +1676,8 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { return visitor->HandleLog(this, operands_[0]); case HloOpcode::kTanh: return visitor->HandleTanh(this, operands_[0]); + case HloOpcode::kIsFinite: + return visitor->HandleIsFinite(this, operands_[0]); case HloOpcode::kLogicalNot: return visitor->HandleLogicalNot(this, operands_[0]); case HloOpcode::kBitcast: @@ -1876,6 +1881,7 @@ bool HloInstruction::IsElementwise() const { case HloOpcode::kCopy: case HloOpcode::kExp: case HloOpcode::kFloor: + case HloOpcode::kIsFinite: case HloOpcode::kLog: case HloOpcode::kLogicalNot: case HloOpcode::kNegate: diff --git a/tensorflow/compiler/xla/service/hlo_opcode.cc b/tensorflow/compiler/xla/service/hlo_opcode.cc index 5f7243b0fe..616b239a93 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode.cc @@ -74,6 +74,8 @@ string HloOpcodeString(HloOpcode opcode) { return "index"; case HloOpcode::kInfeed: return "infeed"; + case HloOpcode::kIsFinite: + return "is-finite"; case HloOpcode::kLe: return "less-than-or-equal-to"; case HloOpcode::kLog: diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index 5d60a77e14..978ed5e79b 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -55,6 +55,7 @@ enum class HloOpcode { kGt, kIndex, kInfeed, + kIsFinite, kLe, kLog, kLogicalAnd, diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 42e33d5396..34a6bb8a52 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -50,7 +50,7 @@ bool IsExpensive(const HloInstruction& instruction) { case HloOpcode::kGetTupleElement: case HloOpcode::kGt: case HloOpcode::kInfeed: - case HloOpcode::kOutfeed: + case HloOpcode::kIsFinite: case HloOpcode::kLe: case HloOpcode::kLogicalAnd: case HloOpcode::kLogicalNot: @@ -61,6 +61,7 @@ bool IsExpensive(const HloInstruction& instruction) { case HloOpcode::kMultiply: case HloOpcode::kNe: case HloOpcode::kNegate: + case HloOpcode::kOutfeed: case HloOpcode::kPad: case HloOpcode::kReshape: case HloOpcode::kReverse: diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index fbab2dfd4a..c05cf8c37d 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -208,6 +208,16 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape, PrimitiveType_Name(arg.element_type()).c_str()); } return arg; + + case UNOP_IS_FINITE: + if (!ShapeUtil::ElementIsFloating(arg)) { + return InvalidArgument( + "expected element type in shape to be floating point for IsFinite " + "operation; got %s", + PrimitiveType_Name(arg.element_type()).c_str()); + } + return ShapeUtil::ChangeElementType(arg, PRED); + default: return InvalidArgument("unknown operation %s", UnaryOperation_Name(operation).c_str()); diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc index 0ff5d9ffc3..7fde1945a5 100644 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -50,6 +50,8 @@ HloOpcode UnaryOperationToHloOpcode(UnaryOperation unop) { return HloOpcode::kExp; case UNOP_FLOOR: return HloOpcode::kFloor; + case UNOP_IS_FINITE: + return HloOpcode::kIsFinite; case UNOP_LOG: return HloOpcode::kLog; case UNOP_LOGICAL_NOT: diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index 23579088c9..d18511a6b4 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -82,6 +83,50 @@ TEST_F(ArrayElementwiseOpTest, NegConstantS32) { {}); } +XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteZeroElementF32s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1<float>({}); + auto result = builder.IsFinite(a); + + ComputeAndCompareR1<bool>(&builder, {}, {}); +} + +// A non-canonical quiet NaN value. +static const float kNonCanonicalNaN = tensorflow::bit_cast<float>(0x7FD01234); + +XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteScalarF32) { + ComputationBuilder builder(client_, TestName()); + auto result = builder.IsFinite(builder.ConstantR0<float>(NAN)); + ComputeAndCompareR0<bool>(&builder, false, {}); + + EXPECT_TRUE(std::isnan(kNonCanonicalNaN)); + auto result_non_canonical = + builder.IsFinite(builder.ConstantR0<float>(kNonCanonicalNaN)); + ComputeAndCompareR0<bool>(&builder, false, {}); + + const float inf = std::numeric_limits<float>::infinity(); + auto result_inf = builder.IsFinite(builder.ConstantR0<float>(inf)); + ComputeAndCompareR0<bool>(&builder, false, {}); + + auto result_neg_inf = builder.IsFinite(builder.ConstantR0<float>(-inf)); + ComputeAndCompareR0<bool>(&builder, false, {}); + + auto result_zero = builder.IsFinite(builder.ConstantR0<float>(0.0f)); + ComputeAndCompareR0<bool>(&builder, true, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteR1F32s) { + ComputationBuilder builder(client_, TestName()); + const float inf = std::numeric_limits<float>::infinity(); + EXPECT_TRUE(std::isnan(kNonCanonicalNaN)); + auto a = builder.ConstantR1<float>( + {{NAN, 7.0f, kNonCanonicalNaN, -1.0f, inf, -inf}}); + auto result = builder.IsFinite(a); + + ComputeAndCompareR1<bool>(&builder, {false, true, false, true, false, false}, + {}); +} + TEST_F(ArrayElementwiseOpTest, AddTwoConstantF32s) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1<float>({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f}); diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 10e14b4344..99a9ba3ee0 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -555,6 +555,9 @@ enum UnaryOperation { // Elementwise, computes the sign of x. UNOP_SIGN = 10; + + // Elementwise, tests if values are finite (not NaN or inf) + UNOP_IS_FINITE = 11; } message UnaryOpRequest { diff --git a/tensorflow/g3doc/experimental/xla/operation_semantics.md b/tensorflow/g3doc/experimental/xla/operation_semantics.md index 4808b919b4..5fdacd42db 100644 --- a/tensorflow/g3doc/experimental/xla/operation_semantics.md +++ b/tensorflow/g3doc/experimental/xla/operation_semantics.md @@ -552,6 +552,11 @@ ComputationBuilder supports these element-wise unary functions: <b>`Floor(operand)`</b> Element-wise floor `x -> ⌊x⌋`. +<b>`IsFinite(operand)`</b> Tests whether each element of `operand` is finite, +i.e., is not positive or negative infinity, and is not `NaN`. Returns an array +of `PRED` values with the same shape as the input, where each element is `true` +if and only if the corresponding input element is finite. + <b>`Log(operand)`</b> Element-wise natural logarithm `x -> ln(x)`. <b>`LogicalNot(operand)`</b> Element-wise logical not `x -> !(x)`. |