aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Benjamin Kramer <kramerb@google.com>2018-09-19 10:20:33 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-19 10:27:57 -0700
commit414ca1cda5aec72b48d5da127f61b0d05fbdc22c (patch)
treeadf0ee8fe6b788c91609022accf92dfb432bb95d
parent0800a645b85fc9d7c18efe45d1006cf35fba93dd (diff)
[XLA:CPU] Add an emitter for erfinv(double) and erfinv(half).
This is used by the random number generator. Same algorithm as for float, just with more precision. fp16 is upcasted to fp32 and then processed with the float algorithm. PiperOrigin-RevId: 213648736
-rw-r--r--tensorflow/compiler/tests/random_ops_test.py16
-rw-r--r--tensorflow/compiler/tests/stateless_random_ops_test.py7
-rw-r--r--tensorflow/compiler/tf2xla/xla_cpu_backend.cc11
-rw-r--r--tensorflow/compiler/xla/service/elemental_ir_emitter.cc169
4 files changed, 136 insertions, 67 deletions
diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py
index 4932819585..c423fa5004 100644
--- a/tensorflow/compiler/tests/random_ops_test.py
+++ b/tensorflow/compiler/tests/random_ops_test.py
@@ -69,16 +69,14 @@ class RandomOpsTest(xla_test.XLATestCase):
def rng(dtype):
return random_ops.random_normal(shape=[2], dtype=dtype)
- # TODO(b/34339814): implement inverse erf support for non-F32 types.
- dtype = dtypes.float32
- self._testRngIsNotConstant(rng, dtype)
+ for dtype in self._random_types() & self.float_types:
+ self._testRngIsNotConstant(rng, dtype)
def testRandomUniformIsInRange(self):
for dtype in self._random_types():
# TODO (b/112272078): enable bfloat16 for CPU and GPU when the bug is
# fixed.
- if (self.device in ["XLA_GPU", "XLA_CPU"
- ]) and (dtype in [dtypes.bfloat16, dtypes.half]):
+ if (self.device in ["XLA_GPU", "XLA_CPU"]) and (dtype == dtypes.bfloat16):
continue
with self.cached_session() as sess:
with self.test_scope():
@@ -93,13 +91,13 @@ class RandomOpsTest(xla_test.XLATestCase):
def rng(dtype):
return random_ops.truncated_normal(shape=[2], dtype=dtype)
- # TODO(b/34339814): implement inverse erf support for non-F32 types.
- self._testRngIsNotConstant(rng, dtypes.float32)
+ for dtype in self._random_types() & self.float_types:
+ self._testRngIsNotConstant(rng, dtype)
def testTruncatedNormalIsInRange(self):
count = 10000000
- # TODO(b/34339814): implement inverse erf support for non-F32 types.
- for dtype in [dtypes.float32]:
+ # TODO(b/34339814): make this test work with 16 bit float types.
+ for dtype in self._random_types() & {dtypes.float32, dtypes.float64}:
with self.cached_session() as sess:
with self.test_scope():
x = random_ops.truncated_normal(shape=[count], dtype=dtype)
diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py
index 1bea7d9355..f3861043b2 100644
--- a/tensorflow/compiler/tests/stateless_random_ops_test.py
+++ b/tensorflow/compiler/tests/stateless_random_ops_test.py
@@ -34,7 +34,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
"""Test cases for stateless random-number generator operators."""
def _random_types(self):
- return [dtypes.float32]
+ return self.float_types & {dtypes.float32, dtypes.float64}
def testDeterminism(self):
# Stateless values should be equal iff the seeds are equal (roughly)
@@ -124,8 +124,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
self.assertTrue(self._anderson_darling(y) < 2.492)
def testTruncatedNormalIsInRange(self):
- # TODO(b/34339814): implement inverse erf support for non-F32 types.
- for dtype in [dtypes.float32]:
+ for dtype in self._random_types():
with self.cached_session() as sess, self.test_scope():
seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
n = 10000000
@@ -159,7 +158,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
# Department of Scientific Computing website. Florida State University.
expected_mean = mu + (normal_pdf(alpha) - normal_pdf(beta)) / z * sigma
actual_mean = np.mean(y)
- self.assertAllClose(actual_mean, expected_mean, atol=2e-4)
+ self.assertAllClose(actual_mean, expected_mean, atol=5e-4)
expected_median = mu + probit(
(normal_cdf(alpha) + normal_cdf(beta)) / 2.) * sigma
diff --git a/tensorflow/compiler/tf2xla/xla_cpu_backend.cc b/tensorflow/compiler/tf2xla/xla_cpu_backend.cc
index ead229aacc..bc44301d40 100644
--- a/tensorflow/compiler/tf2xla/xla_cpu_backend.cc
+++ b/tensorflow/compiler/tf2xla/xla_cpu_backend.cc
@@ -20,17 +20,6 @@ limitations under the License.
namespace tensorflow {
bool CpuOpFilter(KernelDef* kdef) {
- // TODO(b/34339814): implement inverse erf for double types and remove this
- // workaround.
- if (kdef->op() == "RandomStandardNormal") {
- kdef->clear_constraint();
- // Change the type constraint to permit only DTD_FLOAT.
- KernelDef::AttrConstraint* attr_constraint = kdef->add_constraint();
- attr_constraint->set_name("dtype");
- attr_constraint->mutable_allowed_values()->mutable_list()->add_type(
- DT_FLOAT);
- return true;
- }
if (kdef->op() == "Const") {
AddDtypeToKernalDefConstraint("dtype", DT_STRING, kdef);
}
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index 4bb1e071d8..515267edd7 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -847,29 +847,34 @@ llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value,
StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type,
llvm::Value* x) {
- if (prim_type != F32) {
- // TODO(b/34339814): Implement inverse erf for F64.
+ if (prim_type != F16 && prim_type != F32 && prim_type != F64) {
return Unimplemented(
"Inverse erf is only implemented for element "
- "type F32.");
+ "types F16, F32 and F64.");
}
- auto getFloat = [&](const float f) {
- return llvm::ConstantFP::get(b_->getFloatTy(), f);
+
+ // Upcast half to float.
+ if (prim_type == F16) {
+ x = b_->CreateFPExt(x, b_->getFloatTy());
+ }
+
+ auto get_float = [&](const double f) {
+ return llvm::ConstantFP::get(x->getType(), f);
};
- auto multiply_add = [&](absl::Span<const float> coefficients,
+ auto multiply_add = [&](absl::Span<const double> coefficients,
llvm::Value* w) {
- llvm::Value* p = getFloat(coefficients.front());
+ llvm::Value* p = get_float(coefficients.front());
coefficients.remove_prefix(1);
for (float coefficient : coefficients) {
- p = FAdd(FMul(p, w), getFloat(coefficient));
+ p = FAdd(FMul(p, w), get_float(coefficient));
}
return p;
};
// Approximation for inverse error function from
// Giles, M., "Approximating the erfinv function".
- // The approximation has the form:
- // w = log((1-x)*(1+x))
+ // The approximation has the form (float version):
+ // w = -log((1-x)*(1+x))
// if ( w < 5 ) {
// w = w - 2.5
// p = sum_{i=1}^n lq[i]*w^i
@@ -879,46 +884,124 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type,
// }
// return p*x
llvm::Function* logf_fn = llvm::Intrinsic::getDeclaration(
- module_, llvm::Intrinsic::log, {b_->getFloatTy()});
+ module_, llvm::Intrinsic::log, {x->getType()});
- llvm::Value* w = FNeg(
- Call(logf_fn, {FMul(FSub(getFloat(1.0f), x), FAdd(getFloat(1.0f), x))}));
+ llvm::Value* w = FNeg(Call(
+ logf_fn, {FMul(FSub(get_float(1.0f), x), FAdd(get_float(1.0f), x))}));
llvm::Value* p_addr =
- llvm_ir::EmitAllocaAtFunctionEntry(b_->getFloatTy(), "p.addr", b_);
+ llvm_ir::EmitAllocaAtFunctionEntry(x->getType(), "p.addr", b_);
+
+ if (prim_type == F16 || prim_type == F32) {
+ llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
+ FCmpOLT(w, get_float(5.0f)), "w_less_than_five", b_);
+ // Handle true BB.
+ SetToFirstInsertPoint(if_data.true_block, b_);
+ {
+ llvm::Value* lw = FSub(w, get_float(2.5f));
+ absl::Span<const double> lq{
+ 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f,
+ -4.39150654e-06f, 0.00021858087f, -0.00125372503f,
+ -0.00417768164f, 0.246640727f, 1.50140941f};
+ llvm::Value* p = multiply_add(lq, lw);
+ Store(p, p_addr);
+ }
- llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
- FCmpOLT(w, getFloat(5.0f)), "w_less_than_five", b_);
- // Handle true BB.
- SetToFirstInsertPoint(if_data.true_block, b_);
- {
- llvm::Value* lw = FSub(w, getFloat(2.5f));
- absl::Span<const float> lq{
- 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f,
- -4.39150654e-06f, 0.00021858087f, -0.00125372503f,
- -0.00417768164f, 0.246640727f, 1.50140941f};
- llvm::Value* p = multiply_add(lq, lw);
- Store(p, p_addr);
- }
+ // Handle false BB.
+ SetToFirstInsertPoint(if_data.false_block, b_);
+ {
+ llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration(
+ module_, llvm::Intrinsic::sqrt, {b_->getFloatTy()});
+
+ llvm::Value* gw = FSub(Call(sqrtf_fn, w), get_float(3.0f));
+ absl::Span<const double> gq{
+ -0.000200214257f, 0.000100950558f, 0.00134934322f,
+ -0.00367342844f, 0.00573950773f, -0.0076224613f,
+ 0.00943887047f, 1.00167406f, 2.83297682f};
+ llvm::Value* p = multiply_add(gq, gw);
+ Store(p, p_addr);
+ }
- // Handle false BB.
- SetToFirstInsertPoint(if_data.false_block, b_);
- {
- llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration(
- module_, llvm::Intrinsic::sqrt, {b_->getFloatTy()});
-
- llvm::Value* gw = FSub(Call(sqrtf_fn, w), getFloat(3.0f));
- absl::Span<const float> gq{
- -0.000200214257f, 0.000100950558f, 0.00134934322f,
- -0.00367342844f, 0.00573950773f, -0.0076224613f,
- 0.00943887047f, 1.00167406f, 2.83297682f};
- llvm::Value* p = multiply_add(gq, gw);
- Store(p, p_addr);
- }
+ SetToFirstInsertPoint(if_data.after_block, b_);
+ } else {
+ DCHECK(prim_type == F64);
+
+ llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
+ FCmpOLT(w, get_float(6.25)), "w_less_than_6.25", b_);
+
+ SetToFirstInsertPoint(if_data.true_block, b_);
+ {
+ llvm::Value* lw = FSub(w, get_float(3.125));
+ absl::Span<const double> c{
+ -3.6444120640178196996e-21, -1.685059138182016589e-19,
+ 1.2858480715256400167e-18, 1.115787767802518096e-17,
+ -1.333171662854620906e-16, 2.0972767875968561637e-17,
+ 6.6376381343583238325e-15, -4.0545662729752068639e-14,
+ -8.1519341976054721522e-14, 2.6335093153082322977e-12,
+ -1.2975133253453532498e-11, -5.4154120542946279317e-11,
+ 1.051212273321532285e-09, -4.1126339803469836976e-09,
+ -2.9070369957882005086e-08, 4.2347877827932403518e-07,
+ -1.3654692000834678645e-06, -1.3882523362786468719e-05,
+ 0.0001867342080340571352, -0.00074070253416626697512,
+ -0.0060336708714301490533, 0.24015818242558961693,
+ 1.6536545626831027356};
+ llvm::Value* p = multiply_add(c, lw);
+ Store(p, p_addr);
+ }
- SetToFirstInsertPoint(if_data.after_block, b_);
+ SetToFirstInsertPoint(if_data.false_block, b_);
+ llvm_ir::LlvmIfData if_data_second = llvm_ir::EmitIfThenElse(
+ FCmpOLT(w, get_float(16.0)), "w_less_than_16", b_);
+ SetToFirstInsertPoint(if_data_second.true_block, b_);
+ {
+ llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration(
+ module_, llvm::Intrinsic::sqrt, {b_->getDoubleTy()});
+
+ llvm::Value* gw = FSub(Call(sqrtf_fn, w), get_float(3.25));
+ absl::Span<const double> t1{
+ 2.2137376921775787049e-09, 9.0756561938885390979e-08,
+ -2.7517406297064545428e-07, 1.8239629214389227755e-08,
+ 1.5027403968909827627e-06, -4.013867526981545969e-06,
+ 2.9234449089955446044e-06, 1.2475304481671778723e-05,
+ -4.7318229009055733981e-05, 6.8284851459573175448e-05,
+ 2.4031110387097893999e-05, -0.0003550375203628474796,
+ 0.00095328937973738049703, -0.0016882755560235047313,
+ 0.0024914420961078508066, -0.0037512085075692412107,
+ 0.005370914553590063617, 1.0052589676941592334,
+ 3.0838856104922207635};
+ llvm::Value* p = multiply_add(t1, gw);
+ Store(p, p_addr);
+ }
+
+ SetToFirstInsertPoint(if_data_second.false_block, b_);
+ {
+ llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration(
+ module_, llvm::Intrinsic::sqrt, {b_->getDoubleTy()});
+
+ llvm::Value* gw = FSub(Call(sqrtf_fn, w), get_float(5.0));
+ absl::Span<const double> t2{
+ -2.7109920616438573243e-11, -2.5556418169965252055e-10,
+ 1.5076572693500548083e-09, -3.7894654401267369937e-09,
+ 7.6157012080783393804e-09, -1.4960026627149240478e-08,
+ 2.9147953450901080826e-08, -6.7711997758452339498e-08,
+ 2.2900482228026654717e-07, -9.9298272942317002539e-07,
+ 4.5260625972231537039e-06, -1.9681778105531670567e-05,
+ 7.5995277030017761139e-05, -0.00021503011930044477347,
+ -0.00013871931833623122026, 1.0103004648645343977,
+ 4.8499064014085844221};
+ llvm::Value* p = multiply_add(t2, gw);
+ Store(p, p_addr);
+ }
+
+ SetToFirstInsertPoint(if_data.after_block, b_);
+ }
llvm::Value* p = Load(p_addr);
- return FMul(p, x);
+ x = FMul(p, x);
+ // Trunc back to half if needed.
+ if (prim_type == F16) {
+ x = b_->CreateFPTrunc(x, b_->getHalfTy());
+ }
+ return x;
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfcInv(PrimitiveType prim_type,