aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/compiler/tests/random_ops_test.py16
-rw-r--r--tensorflow/compiler/tests/stateless_random_ops_test.py7
-rw-r--r--tensorflow/compiler/tf2xla/xla_cpu_backend.cc11
-rw-r--r--tensorflow/compiler/xla/service/elemental_ir_emitter.cc169
4 files changed, 136 insertions, 67 deletions
diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py
index 4932819585..c423fa5004 100644
--- a/tensorflow/compiler/tests/random_ops_test.py
+++ b/tensorflow/compiler/tests/random_ops_test.py
@@ -69,16 +69,14 @@ class RandomOpsTest(xla_test.XLATestCase):
def rng(dtype):
return random_ops.random_normal(shape=[2], dtype=dtype)
- # TODO(b/34339814): implement inverse erf support for non-F32 types.
- dtype = dtypes.float32
- self._testRngIsNotConstant(rng, dtype)
+ for dtype in self._random_types() & self.float_types:
+ self._testRngIsNotConstant(rng, dtype)
def testRandomUniformIsInRange(self):
for dtype in self._random_types():
# TODO (b/112272078): enable bfloat16 for CPU and GPU when the bug is
# fixed.
- if (self.device in ["XLA_GPU", "XLA_CPU"
- ]) and (dtype in [dtypes.bfloat16, dtypes.half]):
+ if (self.device in ["XLA_GPU", "XLA_CPU"]) and (dtype == dtypes.bfloat16):
continue
with self.cached_session() as sess:
with self.test_scope():
@@ -93,13 +91,13 @@ class RandomOpsTest(xla_test.XLATestCase):
def rng(dtype):
return random_ops.truncated_normal(shape=[2], dtype=dtype)
- # TODO(b/34339814): implement inverse erf support for non-F32 types.
- self._testRngIsNotConstant(rng, dtypes.float32)
+ for dtype in self._random_types() & self.float_types:
+ self._testRngIsNotConstant(rng, dtype)
def testTruncatedNormalIsInRange(self):
count = 10000000
- # TODO(b/34339814): implement inverse erf support for non-F32 types.
- for dtype in [dtypes.float32]:
+ # TODO(b/34339814): make this test work with 16 bit float types.
+ for dtype in self._random_types() & {dtypes.float32, dtypes.float64}:
with self.cached_session() as sess:
with self.test_scope():
x = random_ops.truncated_normal(shape=[count], dtype=dtype)
diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py
index 1bea7d9355..f3861043b2 100644
--- a/tensorflow/compiler/tests/stateless_random_ops_test.py
+++ b/tensorflow/compiler/tests/stateless_random_ops_test.py
@@ -34,7 +34,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
"""Test cases for stateless random-number generator operators."""
def _random_types(self):
- return [dtypes.float32]
+ return self.float_types & {dtypes.float32, dtypes.float64}
def testDeterminism(self):
# Stateless values should be equal iff the seeds are equal (roughly)
@@ -124,8 +124,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
self.assertTrue(self._anderson_darling(y) < 2.492)
def testTruncatedNormalIsInRange(self):
- # TODO(b/34339814): implement inverse erf support for non-F32 types.
- for dtype in [dtypes.float32]:
+ for dtype in self._random_types():
with self.cached_session() as sess, self.test_scope():
seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
n = 10000000
@@ -159,7 +158,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
# Department of Scientific Computing website. Florida State University.
expected_mean = mu + (normal_pdf(alpha) - normal_pdf(beta)) / z * sigma
actual_mean = np.mean(y)
- self.assertAllClose(actual_mean, expected_mean, atol=2e-4)
+ self.assertAllClose(actual_mean, expected_mean, atol=5e-4)
expected_median = mu + probit(
(normal_cdf(alpha) + normal_cdf(beta)) / 2.) * sigma
diff --git a/tensorflow/compiler/tf2xla/xla_cpu_backend.cc b/tensorflow/compiler/tf2xla/xla_cpu_backend.cc
index ead229aacc..bc44301d40 100644
--- a/tensorflow/compiler/tf2xla/xla_cpu_backend.cc
+++ b/tensorflow/compiler/tf2xla/xla_cpu_backend.cc
@@ -20,17 +20,6 @@ limitations under the License.
namespace tensorflow {
bool CpuOpFilter(KernelDef* kdef) {
- // TODO(b/34339814): implement inverse erf for double types and remove this
- // workaround.
- if (kdef->op() == "RandomStandardNormal") {
- kdef->clear_constraint();
- // Change the type constraint to permit only DTD_FLOAT.
- KernelDef::AttrConstraint* attr_constraint = kdef->add_constraint();
- attr_constraint->set_name("dtype");
- attr_constraint->mutable_allowed_values()->mutable_list()->add_type(
- DT_FLOAT);
- return true;
- }
if (kdef->op() == "Const") {
AddDtypeToKernalDefConstraint("dtype", DT_STRING, kdef);
}
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index 4bb1e071d8..515267edd7 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -847,29 +847,34 @@ llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value,
StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type,
llvm::Value* x) {
- if (prim_type != F32) {
- // TODO(b/34339814): Implement inverse erf for F64.
+ if (prim_type != F16 && prim_type != F32 && prim_type != F64) {
return Unimplemented(
"Inverse erf is only implemented for element "
- "type F32.");
+ "types F16, F32 and F64.");
}
- auto getFloat = [&](const float f) {
- return llvm::ConstantFP::get(b_->getFloatTy(), f);
+
+ // Upcast half to float.
+ if (prim_type == F16) {
+ x = b_->CreateFPExt(x, b_->getFloatTy());
+ }
+
+ auto get_float = [&](const double f) {
+ return llvm::ConstantFP::get(x->getType(), f);
};
- auto multiply_add = [&](absl::Span<const float> coefficients,
+ auto multiply_add = [&](absl::Span<const double> coefficients,
llvm::Value* w) {
- llvm::Value* p = getFloat(coefficients.front());
+ llvm::Value* p = get_float(coefficients.front());
coefficients.remove_prefix(1);
for (float coefficient : coefficients) {
- p = FAdd(FMul(p, w), getFloat(coefficient));
+ p = FAdd(FMul(p, w), get_float(coefficient));
}
return p;
};
// Approximation for inverse error function from
// Giles, M., "Approximating the erfinv function".
- // The approximation has the form:
- // w = log((1-x)*(1+x))
+ // The approximation has the form (float version):
+ // w = -log((1-x)*(1+x))
// if ( w < 5 ) {
// w = w - 2.5
// p = sum_{i=1}^n lq[i]*w^i
@@ -879,46 +884,124 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type,
// }
// return p*x
llvm::Function* logf_fn = llvm::Intrinsic::getDeclaration(
- module_, llvm::Intrinsic::log, {b_->getFloatTy()});
+ module_, llvm::Intrinsic::log, {x->getType()});
- llvm::Value* w = FNeg(
- Call(logf_fn, {FMul(FSub(getFloat(1.0f), x), FAdd(getFloat(1.0f), x))}));
+ llvm::Value* w = FNeg(Call(
+ logf_fn, {FMul(FSub(get_float(1.0f), x), FAdd(get_float(1.0f), x))}));
llvm::Value* p_addr =
- llvm_ir::EmitAllocaAtFunctionEntry(b_->getFloatTy(), "p.addr", b_);
+ llvm_ir::EmitAllocaAtFunctionEntry(x->getType(), "p.addr", b_);
+
+ if (prim_type == F16 || prim_type == F32) {
+ llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
+ FCmpOLT(w, get_float(5.0f)), "w_less_than_five", b_);
+ // Handle true BB.
+ SetToFirstInsertPoint(if_data.true_block, b_);
+ {
+ llvm::Value* lw = FSub(w, get_float(2.5f));
+ absl::Span<const double> lq{
+ 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f,
+ -4.39150654e-06f, 0.00021858087f, -0.00125372503f,
+ -0.00417768164f, 0.246640727f, 1.50140941f};
+ llvm::Value* p = multiply_add(lq, lw);
+ Store(p, p_addr);
+ }
- llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
- FCmpOLT(w, getFloat(5.0f)), "w_less_than_five", b_);
- // Handle true BB.
- SetToFirstInsertPoint(if_data.true_block, b_);
- {
- llvm::Value* lw = FSub(w, getFloat(2.5f));
- absl::Span<const float> lq{
- 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f,
- -4.39150654e-06f, 0.00021858087f, -0.00125372503f,
- -0.00417768164f, 0.246640727f, 1.50140941f};
- llvm::Value* p = multiply_add(lq, lw);
- Store(p, p_addr);
- }
+ // Handle false BB.
+ SetToFirstInsertPoint(if_data.false_block, b_);
+ {
+ llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration(
+ module_, llvm::Intrinsic::sqrt, {b_->getFloatTy()});
+
+ llvm::Value* gw = FSub(Call(sqrtf_fn, w), get_float(3.0f));
+ absl::Span<const double> gq{
+ -0.000200214257f, 0.000100950558f, 0.00134934322f,
+ -0.00367342844f, 0.00573950773f, -0.0076224613f,
+ 0.00943887047f, 1.00167406f, 2.83297682f};
+ llvm::Value* p = multiply_add(gq, gw);
+ Store(p, p_addr);
+ }
- // Handle false BB.
- SetToFirstInsertPoint(if_data.false_block, b_);
- {
- llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration(
- module_, llvm::Intrinsic::sqrt, {b_->getFloatTy()});
-
- llvm::Value* gw = FSub(Call(sqrtf_fn, w), getFloat(3.0f));
- absl::Span<const float> gq{
- -0.000200214257f, 0.000100950558f, 0.00134934322f,
- -0.00367342844f, 0.00573950773f, -0.0076224613f,
- 0.00943887047f, 1.00167406f, 2.83297682f};
- llvm::Value* p = multiply_add(gq, gw);
- Store(p, p_addr);
- }
+ SetToFirstInsertPoint(if_data.after_block, b_);
+ } else {
+ DCHECK(prim_type == F64);
+
+ llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
+ FCmpOLT(w, get_float(6.25)), "w_less_than_6.25", b_);
+
+ SetToFirstInsertPoint(if_data.true_block, b_);
+ {
+ llvm::Value* lw = FSub(w, get_float(3.125));
+ absl::Span<const double> c{
+ -3.6444120640178196996e-21, -1.685059138182016589e-19,
+ 1.2858480715256400167e-18, 1.115787767802518096e-17,
+ -1.333171662854620906e-16, 2.0972767875968561637e-17,
+ 6.6376381343583238325e-15, -4.0545662729752068639e-14,
+ -8.1519341976054721522e-14, 2.6335093153082322977e-12,
+ -1.2975133253453532498e-11, -5.4154120542946279317e-11,
+ 1.051212273321532285e-09, -4.1126339803469836976e-09,
+ -2.9070369957882005086e-08, 4.2347877827932403518e-07,
+ -1.3654692000834678645e-06, -1.3882523362786468719e-05,
+ 0.0001867342080340571352, -0.00074070253416626697512,
+ -0.0060336708714301490533, 0.24015818242558961693,
+ 1.6536545626831027356};
+ llvm::Value* p = multiply_add(c, lw);
+ Store(p, p_addr);
+ }
- SetToFirstInsertPoint(if_data.after_block, b_);
+ SetToFirstInsertPoint(if_data.false_block, b_);
+ llvm_ir::LlvmIfData if_data_second = llvm_ir::EmitIfThenElse(
+ FCmpOLT(w, get_float(16.0)), "w_less_than_16", b_);
+ SetToFirstInsertPoint(if_data_second.true_block, b_);
+ {
+ llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration(
+ module_, llvm::Intrinsic::sqrt, {b_->getDoubleTy()});
+
+ llvm::Value* gw = FSub(Call(sqrtf_fn, w), get_float(3.25));
+ absl::Span<const double> t1{
+ 2.2137376921775787049e-09, 9.0756561938885390979e-08,
+ -2.7517406297064545428e-07, 1.8239629214389227755e-08,
+ 1.5027403968909827627e-06, -4.013867526981545969e-06,
+ 2.9234449089955446044e-06, 1.2475304481671778723e-05,
+ -4.7318229009055733981e-05, 6.8284851459573175448e-05,
+ 2.4031110387097893999e-05, -0.0003550375203628474796,
+ 0.00095328937973738049703, -0.0016882755560235047313,
+ 0.0024914420961078508066, -0.0037512085075692412107,
+ 0.005370914553590063617, 1.0052589676941592334,
+ 3.0838856104922207635};
+ llvm::Value* p = multiply_add(t1, gw);
+ Store(p, p_addr);
+ }
+
+ SetToFirstInsertPoint(if_data_second.false_block, b_);
+ {
+ llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration(
+ module_, llvm::Intrinsic::sqrt, {b_->getDoubleTy()});
+
+ llvm::Value* gw = FSub(Call(sqrtf_fn, w), get_float(5.0));
+ absl::Span<const double> t2{
+ -2.7109920616438573243e-11, -2.5556418169965252055e-10,
+ 1.5076572693500548083e-09, -3.7894654401267369937e-09,
+ 7.6157012080783393804e-09, -1.4960026627149240478e-08,
+ 2.9147953450901080826e-08, -6.7711997758452339498e-08,
+ 2.2900482228026654717e-07, -9.9298272942317002539e-07,
+ 4.5260625972231537039e-06, -1.9681778105531670567e-05,
+ 7.5995277030017761139e-05, -0.00021503011930044477347,
+ -0.00013871931833623122026, 1.0103004648645343977,
+ 4.8499064014085844221};
+ llvm::Value* p = multiply_add(t2, gw);
+ Store(p, p_addr);
+ }
+
+ SetToFirstInsertPoint(if_data.after_block, b_);
+ }
llvm::Value* p = Load(p_addr);
- return FMul(p, x);
+ x = FMul(p, x);
+ // Trunc back to half if needed.
+ if (prim_type == F16) {
+ x = b_->CreateFPTrunc(x, b_->getHalfTy());
+ }
+ return x;
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfcInv(PrimitiveType prim_type,