aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Peter Hawkins <phawkins@google.com>2017-02-24 11:50:30 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-02-24 12:08:47 -0800
commit8120e2a270c28e0a62b9f522164b196a90f113b7 (patch)
tree3d95683920536d58e00d5e22b7118900d49cf56a
parent11ba79eee215f7f5831d60101df24770acea3b5f (diff)
[XLA] Add an IsFinite operation that tests elementwise whether values are finite (i.e., not NaN or Inf).
Change: 148485205
-rw-r--r--tensorflow/compiler/xla/client/computation_builder.cc5
-rw-r--r--tensorflow/compiler/xla/client/computation_builder.h6
-rw-r--r--tensorflow/compiler/xla/service/dfs_hlo_visitor.h4
-rw-r--r--tensorflow/compiler/xla/service/elemental_ir_emitter.cc14
-rw-r--r--tensorflow/compiler/xla/service/hlo_graph_dumper.cc1
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc6
-rw-r--r--tensorflow/compiler/xla/service/hlo_opcode.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_opcode.h1
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion.cc3
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc10
-rw-r--r--tensorflow/compiler/xla/service/user_computation.cc2
-rw-r--r--tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc45
-rw-r--r--tensorflow/compiler/xla/xla_data.proto3
-rw-r--r--tensorflow/g3doc/experimental/xla/operation_semantics.md5
14 files changed, 106 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc
index 73f450e1f2..c4c91b7ea8 100644
--- a/tensorflow/compiler/xla/client/computation_builder.cc
+++ b/tensorflow/compiler/xla/client/computation_builder.cc
@@ -950,6 +950,11 @@ ComputationDataHandle ComputationBuilder::Tanh(
return UnaryOp(UNOP_TANH, operand);
}
+ComputationDataHandle ComputationBuilder::IsFinite(
+ const ComputationDataHandle& operand) {
+ return UnaryOp(UNOP_IS_FINITE, operand);
+}
+
ComputationDataHandle ComputationBuilder::Transpose(
const ComputationDataHandle& operand,
tensorflow::gtl::ArraySlice<int64> permutation) {
diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h
index 67ca9c6cf7..b1a68e3687 100644
--- a/tensorflow/compiler/xla/client/computation_builder.h
+++ b/tensorflow/compiler/xla/client/computation_builder.h
@@ -507,6 +507,12 @@ class ComputationBuilder {
ComputationDataHandle Pow(const ComputationDataHandle& lhs,
const ComputationDataHandle& rhs);
+ // Enqueues an operator that tests if the operand's values are finite, i.e.,
+ // not Inf or NaN. Defined only for floating-point types. Returns an array of
+ // booleans with the same shape where entries are true iff the corresponding
+ // entry was NaN.
+ ComputationDataHandle IsFinite(const ComputationDataHandle& operand);
+
// Enqueues a convert instruction onto the computation that changes the
// element type of the operand array to primitive_type.
ComputationDataHandle ConvertElementType(const ComputationDataHandle& operand,
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
index f9c9bbe2cd..351efa82dd 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
@@ -150,6 +150,10 @@ class DfsHloVisitor {
virtual Status HandleTanh(HloInstruction* tanh, HloInstruction* operand) {
return HandleElementwiseUnary(tanh, HloOpcode::kTanh, operand);
}
+ virtual Status HandleIsFinite(HloInstruction* is_finite,
+ HloInstruction* operand) {
+ return HandleElementwiseUnary(is_finite, HloOpcode::kIsFinite, operand);
+ }
virtual Status HandleLogicalAnd(HloInstruction* logical_and,
HloInstruction* lhs, HloInstruction* rhs) {
return HandleElementwiseBinary(logical_and, HloOpcode::kLogicalAnd, lhs,
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index a4b50836d7..96342451fd 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -195,6 +195,19 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
ir_builder_->CreateSelect(olt, llvm::ConstantFP::get(type, -1.0),
llvm::ConstantFP::get(type, 1.0)));
}
+ case HloOpcode::kIsFinite: {
+ // (x == x) && abs(x) != inf
+ auto type = operand_value->getType();
+ auto equal_self =
+ ir_builder_->CreateFCmpOEQ(operand_value, operand_value);
+ auto abs_value = llvm_ir::EmitCallToIntrinsic(
+ llvm::Intrinsic::fabs, {operand_value}, {type}, ir_builder_);
+ auto infinity = llvm::ConstantFP::getInfinity(type);
+ auto not_infinite = ir_builder_->CreateFCmpONE(abs_value, infinity);
+ auto result_i1 = ir_builder_->CreateAnd(equal_self, not_infinite);
+ return ir_builder_->CreateZExt(
+ result_i1, llvm_ir::PrimitiveTypeToIrType(PRED, ir_builder_));
+ }
case HloOpcode::kNegate:
return ir_builder_->CreateFNeg(operand_value);
default:
@@ -632,6 +645,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
case HloOpcode::kCopy:
case HloOpcode::kExp:
case HloOpcode::kFloor:
+ case HloOpcode::kIsFinite:
case HloOpcode::kLog:
case HloOpcode::kNegate:
case HloOpcode::kSign:
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index 631e784755..d7d8722ccc 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -202,6 +202,7 @@ string InstructionSequenceGraph(
case HloOpcode::kGe:
case HloOpcode::kGt:
case HloOpcode::kIndex:
+ case HloOpcode::kIsFinite:
case HloOpcode::kLe:
case HloOpcode::kLog:
case HloOpcode::kLogicalAnd:
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 1d31027cff..9b49ac1a60 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -117,6 +117,7 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
case HloOpcode::kCopy:
case HloOpcode::kExp:
case HloOpcode::kFloor:
+ case HloOpcode::kIsFinite:
case HloOpcode::kLog:
case HloOpcode::kLogicalNot:
case HloOpcode::kNegate:
@@ -733,6 +734,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
case HloOpcode::kCeil:
case HloOpcode::kCopy:
case HloOpcode::kExp:
+ case HloOpcode::kIsFinite:
case HloOpcode::kFloor:
case HloOpcode::kLog:
case HloOpcode::kLogicalNot:
@@ -1033,6 +1035,7 @@ bool HloInstruction::Identical(
case HloOpcode::kFloor:
case HloOpcode::kGe:
case HloOpcode::kGt:
+ case HloOpcode::kIsFinite:
case HloOpcode::kLe:
case HloOpcode::kLog:
case HloOpcode::kLogicalAnd:
@@ -1673,6 +1676,8 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) {
return visitor->HandleLog(this, operands_[0]);
case HloOpcode::kTanh:
return visitor->HandleTanh(this, operands_[0]);
+ case HloOpcode::kIsFinite:
+ return visitor->HandleIsFinite(this, operands_[0]);
case HloOpcode::kLogicalNot:
return visitor->HandleLogicalNot(this, operands_[0]);
case HloOpcode::kBitcast:
@@ -1876,6 +1881,7 @@ bool HloInstruction::IsElementwise() const {
case HloOpcode::kCopy:
case HloOpcode::kExp:
case HloOpcode::kFloor:
+ case HloOpcode::kIsFinite:
case HloOpcode::kLog:
case HloOpcode::kLogicalNot:
case HloOpcode::kNegate:
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.cc b/tensorflow/compiler/xla/service/hlo_opcode.cc
index 5f7243b0fe..616b239a93 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.cc
+++ b/tensorflow/compiler/xla/service/hlo_opcode.cc
@@ -74,6 +74,8 @@ string HloOpcodeString(HloOpcode opcode) {
return "index";
case HloOpcode::kInfeed:
return "infeed";
+ case HloOpcode::kIsFinite:
+ return "is-finite";
case HloOpcode::kLe:
return "less-than-or-equal-to";
case HloOpcode::kLog:
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h
index 5d60a77e14..978ed5e79b 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.h
+++ b/tensorflow/compiler/xla/service/hlo_opcode.h
@@ -55,6 +55,7 @@ enum class HloOpcode {
kGt,
kIndex,
kInfeed,
+ kIsFinite,
kLe,
kLog,
kLogicalAnd,
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index 42e33d5396..34a6bb8a52 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -50,7 +50,7 @@ bool IsExpensive(const HloInstruction& instruction) {
case HloOpcode::kGetTupleElement:
case HloOpcode::kGt:
case HloOpcode::kInfeed:
- case HloOpcode::kOutfeed:
+ case HloOpcode::kIsFinite:
case HloOpcode::kLe:
case HloOpcode::kLogicalAnd:
case HloOpcode::kLogicalNot:
@@ -61,6 +61,7 @@ bool IsExpensive(const HloInstruction& instruction) {
case HloOpcode::kMultiply:
case HloOpcode::kNe:
case HloOpcode::kNegate:
+ case HloOpcode::kOutfeed:
case HloOpcode::kPad:
case HloOpcode::kReshape:
case HloOpcode::kReverse:
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index fbab2dfd4a..c05cf8c37d 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -208,6 +208,16 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
PrimitiveType_Name(arg.element_type()).c_str());
}
return arg;
+
+ case UNOP_IS_FINITE:
+ if (!ShapeUtil::ElementIsFloating(arg)) {
+ return InvalidArgument(
+ "expected element type in shape to be floating point for IsFinite "
+ "operation; got %s",
+ PrimitiveType_Name(arg.element_type()).c_str());
+ }
+ return ShapeUtil::ChangeElementType(arg, PRED);
+
default:
return InvalidArgument("unknown operation %s",
UnaryOperation_Name(operation).c_str());
diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc
index 0ff5d9ffc3..7fde1945a5 100644
--- a/tensorflow/compiler/xla/service/user_computation.cc
+++ b/tensorflow/compiler/xla/service/user_computation.cc
@@ -50,6 +50,8 @@ HloOpcode UnaryOperationToHloOpcode(UnaryOperation unop) {
return HloOpcode::kExp;
case UNOP_FLOOR:
return HloOpcode::kFloor;
+ case UNOP_IS_FINITE:
+ return HloOpcode::kIsFinite;
case UNOP_LOG:
return HloOpcode::kLog;
case UNOP_LOGICAL_NOT:
diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
index 23579088c9..d18511a6b4 100644
--- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
@@ -34,6 +34,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/casts.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@@ -82,6 +83,50 @@ TEST_F(ArrayElementwiseOpTest, NegConstantS32) {
{});
}
+XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteZeroElementF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({});
+ auto result = builder.IsFinite(a);
+
+ ComputeAndCompareR1<bool>(&builder, {}, {});
+}
+
+// A non-canonical quiet NaN value.
+static const float kNonCanonicalNaN = tensorflow::bit_cast<float>(0x7FD01234);
+
+XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteScalarF32) {
+ ComputationBuilder builder(client_, TestName());
+ auto result = builder.IsFinite(builder.ConstantR0<float>(NAN));
+ ComputeAndCompareR0<bool>(&builder, false, {});
+
+ EXPECT_TRUE(std::isnan(kNonCanonicalNaN));
+ auto result_non_canonical =
+ builder.IsFinite(builder.ConstantR0<float>(kNonCanonicalNaN));
+ ComputeAndCompareR0<bool>(&builder, false, {});
+
+ const float inf = std::numeric_limits<float>::infinity();
+ auto result_inf = builder.IsFinite(builder.ConstantR0<float>(inf));
+ ComputeAndCompareR0<bool>(&builder, false, {});
+
+ auto result_neg_inf = builder.IsFinite(builder.ConstantR0<float>(-inf));
+ ComputeAndCompareR0<bool>(&builder, false, {});
+
+ auto result_zero = builder.IsFinite(builder.ConstantR0<float>(0.0f));
+ ComputeAndCompareR0<bool>(&builder, true, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteR1F32s) {
+ ComputationBuilder builder(client_, TestName());
+ const float inf = std::numeric_limits<float>::infinity();
+ EXPECT_TRUE(std::isnan(kNonCanonicalNaN));
+ auto a = builder.ConstantR1<float>(
+ {{NAN, 7.0f, kNonCanonicalNaN, -1.0f, inf, -inf}});
+ auto result = builder.IsFinite(a);
+
+ ComputeAndCompareR1<bool>(&builder, {false, true, false, true, false, false},
+ {});
+}
+
TEST_F(ArrayElementwiseOpTest, AddTwoConstantF32s) {
ComputationBuilder builder(client_, TestName());
auto a = builder.ConstantR1<float>({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f});
diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto
index 10e14b4344..99a9ba3ee0 100644
--- a/tensorflow/compiler/xla/xla_data.proto
+++ b/tensorflow/compiler/xla/xla_data.proto
@@ -555,6 +555,9 @@ enum UnaryOperation {
// Elementwise, computes the sign of x.
UNOP_SIGN = 10;
+
+ // Elementwise, tests if values are finite (not NaN or inf)
+ UNOP_IS_FINITE = 11;
}
message UnaryOpRequest {
diff --git a/tensorflow/g3doc/experimental/xla/operation_semantics.md b/tensorflow/g3doc/experimental/xla/operation_semantics.md
index 4808b919b4..5fdacd42db 100644
--- a/tensorflow/g3doc/experimental/xla/operation_semantics.md
+++ b/tensorflow/g3doc/experimental/xla/operation_semantics.md
@@ -552,6 +552,11 @@ ComputationBuilder supports these element-wise unary functions:
<b>`Floor(operand)`</b> Element-wise floor `x -> ⌊x⌋`.
+<b>`IsFinite(operand)`</b> Tests whether each element of `operand` is finite,
+i.e., is not positive or negative infinity, and is not `NaN`. Returns an array
+of `PRED` values with the same shape as the input, where each element is `true`
+if and only if the corresponding input element is finite.
+
<b>`Log(operand)`</b> Element-wise natural logarithm `x -> ln(x)`.
<b>`LogicalNot(operand)`</b> Element-wise logical not `x -> !(x)`.