diff options
-rw-r--r-- | tensorflow/compiler/tests/random_ops_test.py | 16 | ||||
-rw-r--r-- | tensorflow/compiler/tests/stateless_random_ops_test.py | 7 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_cpu_backend.cc | 11 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/elemental_ir_emitter.cc | 169 |
4 files changed, 136 insertions, 67 deletions
diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py index 4932819585..c423fa5004 100644 --- a/tensorflow/compiler/tests/random_ops_test.py +++ b/tensorflow/compiler/tests/random_ops_test.py @@ -69,16 +69,14 @@ class RandomOpsTest(xla_test.XLATestCase): def rng(dtype): return random_ops.random_normal(shape=[2], dtype=dtype) - # TODO(b/34339814): implement inverse erf support for non-F32 types. - dtype = dtypes.float32 - self._testRngIsNotConstant(rng, dtype) + for dtype in self._random_types() & self.float_types: + self._testRngIsNotConstant(rng, dtype) def testRandomUniformIsInRange(self): for dtype in self._random_types(): # TODO (b/112272078): enable bfloat16 for CPU and GPU when the bug is # fixed. - if (self.device in ["XLA_GPU", "XLA_CPU" - ]) and (dtype in [dtypes.bfloat16, dtypes.half]): + if (self.device in ["XLA_GPU", "XLA_CPU"]) and (dtype == dtypes.bfloat16): continue with self.cached_session() as sess: with self.test_scope(): @@ -93,13 +91,13 @@ class RandomOpsTest(xla_test.XLATestCase): def rng(dtype): return random_ops.truncated_normal(shape=[2], dtype=dtype) - # TODO(b/34339814): implement inverse erf support for non-F32 types. - self._testRngIsNotConstant(rng, dtypes.float32) + for dtype in self._random_types() & self.float_types: + self._testRngIsNotConstant(rng, dtype) def testTruncatedNormalIsInRange(self): count = 10000000 - # TODO(b/34339814): implement inverse erf support for non-F32 types. - for dtype in [dtypes.float32]: + # TODO(b/34339814): make this test work with 16 bit float types. + for dtype in self._random_types() & {dtypes.float32, dtypes.float64}: with self.cached_session() as sess: with self.test_scope(): x = random_ops.truncated_normal(shape=[count], dtype=dtype) diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py index 1bea7d9355..f3861043b2 100644 --- a/tensorflow/compiler/tests/stateless_random_ops_test.py +++ b/tensorflow/compiler/tests/stateless_random_ops_test.py @@ -34,7 +34,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): """Test cases for stateless random-number generator operators.""" def _random_types(self): - return [dtypes.float32] + return self.float_types & {dtypes.float32, dtypes.float64} def testDeterminism(self): # Stateless values should be equal iff the seeds are equal (roughly) @@ -124,8 +124,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): self.assertTrue(self._anderson_darling(y) < 2.492) def testTruncatedNormalIsInRange(self): - # TODO(b/34339814): implement inverse erf support for non-F32 types. - for dtype in [dtypes.float32]: + for dtype in self._random_types(): with self.cached_session() as sess, self.test_scope(): seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) n = 10000000 @@ -159,7 +158,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): # Department of Scientific Computing website. Florida State University. expected_mean = mu + (normal_pdf(alpha) - normal_pdf(beta)) / z * sigma actual_mean = np.mean(y) - self.assertAllClose(actual_mean, expected_mean, atol=2e-4) + self.assertAllClose(actual_mean, expected_mean, atol=5e-4) expected_median = mu + probit( (normal_cdf(alpha) + normal_cdf(beta)) / 2.) * sigma diff --git a/tensorflow/compiler/tf2xla/xla_cpu_backend.cc b/tensorflow/compiler/tf2xla/xla_cpu_backend.cc index ead229aacc..bc44301d40 100644 --- a/tensorflow/compiler/tf2xla/xla_cpu_backend.cc +++ b/tensorflow/compiler/tf2xla/xla_cpu_backend.cc @@ -20,17 +20,6 @@ limitations under the License. namespace tensorflow { bool CpuOpFilter(KernelDef* kdef) { - // TODO(b/34339814): implement inverse erf for double types and remove this - // workaround. - if (kdef->op() == "RandomStandardNormal") { - kdef->clear_constraint(); - // Change the type constraint to permit only DTD_FLOAT. - KernelDef::AttrConstraint* attr_constraint = kdef->add_constraint(); - attr_constraint->set_name("dtype"); - attr_constraint->mutable_allowed_values()->mutable_list()->add_type( - DT_FLOAT); - return true; - } if (kdef->op() == "Const") { AddDtypeToKernalDefConstraint("dtype", DT_STRING, kdef); } diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 4bb1e071d8..515267edd7 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -847,29 +847,34 @@ llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value, StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type, llvm::Value* x) { - if (prim_type != F32) { - // TODO(b/34339814): Implement inverse erf for F64. + if (prim_type != F16 && prim_type != F32 && prim_type != F64) { return Unimplemented( "Inverse erf is only implemented for element " - "type F32."); + "types F16, F32 and F64."); } - auto getFloat = [&](const float f) { - return llvm::ConstantFP::get(b_->getFloatTy(), f); + + // Upcast half to float. + if (prim_type == F16) { + x = b_->CreateFPExt(x, b_->getFloatTy()); + } + + auto get_float = [&](const double f) { + return llvm::ConstantFP::get(x->getType(), f); }; - auto multiply_add = [&](absl::Span<const float> coefficients, + auto multiply_add = [&](absl::Span<const double> coefficients, llvm::Value* w) { - llvm::Value* p = getFloat(coefficients.front()); + llvm::Value* p = get_float(coefficients.front()); coefficients.remove_prefix(1); for (float coefficient : coefficients) { - p = FAdd(FMul(p, w), getFloat(coefficient)); + p = FAdd(FMul(p, w), get_float(coefficient)); } return p; }; // Approximation for inverse error function from // Giles, M., "Approximating the erfinv function". - // The approximation has the form: - // w = log((1-x)*(1+x)) + // The approximation has the form (float version): + // w = -log((1-x)*(1+x)) // if ( w < 5 ) { // w = w - 2.5 // p = sum_{i=1}^n lq[i]*w^i @@ -879,46 +884,124 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type, // } // return p*x llvm::Function* logf_fn = llvm::Intrinsic::getDeclaration( - module_, llvm::Intrinsic::log, {b_->getFloatTy()}); + module_, llvm::Intrinsic::log, {x->getType()}); - llvm::Value* w = FNeg( - Call(logf_fn, {FMul(FSub(getFloat(1.0f), x), FAdd(getFloat(1.0f), x))})); + llvm::Value* w = FNeg(Call( + logf_fn, {FMul(FSub(get_float(1.0f), x), FAdd(get_float(1.0f), x))})); llvm::Value* p_addr = - llvm_ir::EmitAllocaAtFunctionEntry(b_->getFloatTy(), "p.addr", b_); + llvm_ir::EmitAllocaAtFunctionEntry(x->getType(), "p.addr", b_); + + if (prim_type == F16 || prim_type == F32) { + llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( + FCmpOLT(w, get_float(5.0f)), "w_less_than_five", b_); + // Handle true BB. + SetToFirstInsertPoint(if_data.true_block, b_); + { + llvm::Value* lw = FSub(w, get_float(2.5f)); + absl::Span<const double> lq{ + 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, + -4.39150654e-06f, 0.00021858087f, -0.00125372503f, + -0.00417768164f, 0.246640727f, 1.50140941f}; + llvm::Value* p = multiply_add(lq, lw); + Store(p, p_addr); + } - llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - FCmpOLT(w, getFloat(5.0f)), "w_less_than_five", b_); - // Handle true BB. - SetToFirstInsertPoint(if_data.true_block, b_); - { - llvm::Value* lw = FSub(w, getFloat(2.5f)); - absl::Span<const float> lq{ - 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, - -4.39150654e-06f, 0.00021858087f, -0.00125372503f, - -0.00417768164f, 0.246640727f, 1.50140941f}; - llvm::Value* p = multiply_add(lq, lw); - Store(p, p_addr); - } + // Handle false BB. + SetToFirstInsertPoint(if_data.false_block, b_); + { + llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration( + module_, llvm::Intrinsic::sqrt, {b_->getFloatTy()}); + + llvm::Value* gw = FSub(Call(sqrtf_fn, w), get_float(3.0f)); + absl::Span<const double> gq{ + -0.000200214257f, 0.000100950558f, 0.00134934322f, + -0.00367342844f, 0.00573950773f, -0.0076224613f, + 0.00943887047f, 1.00167406f, 2.83297682f}; + llvm::Value* p = multiply_add(gq, gw); + Store(p, p_addr); + } - // Handle false BB. - SetToFirstInsertPoint(if_data.false_block, b_); - { - llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration( - module_, llvm::Intrinsic::sqrt, {b_->getFloatTy()}); - - llvm::Value* gw = FSub(Call(sqrtf_fn, w), getFloat(3.0f)); - absl::Span<const float> gq{ - -0.000200214257f, 0.000100950558f, 0.00134934322f, - -0.00367342844f, 0.00573950773f, -0.0076224613f, - 0.00943887047f, 1.00167406f, 2.83297682f}; - llvm::Value* p = multiply_add(gq, gw); - Store(p, p_addr); - } + SetToFirstInsertPoint(if_data.after_block, b_); + } else { + DCHECK(prim_type == F64); + + llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( + FCmpOLT(w, get_float(6.25)), "w_less_than_6.25", b_); + + SetToFirstInsertPoint(if_data.true_block, b_); + { + llvm::Value* lw = FSub(w, get_float(3.125)); + absl::Span<const double> c{ + -3.6444120640178196996e-21, -1.685059138182016589e-19, + 1.2858480715256400167e-18, 1.115787767802518096e-17, + -1.333171662854620906e-16, 2.0972767875968561637e-17, + 6.6376381343583238325e-15, -4.0545662729752068639e-14, + -8.1519341976054721522e-14, 2.6335093153082322977e-12, + -1.2975133253453532498e-11, -5.4154120542946279317e-11, + 1.051212273321532285e-09, -4.1126339803469836976e-09, + -2.9070369957882005086e-08, 4.2347877827932403518e-07, + -1.3654692000834678645e-06, -1.3882523362786468719e-05, + 0.0001867342080340571352, -0.00074070253416626697512, + -0.0060336708714301490533, 0.24015818242558961693, + 1.6536545626831027356}; + llvm::Value* p = multiply_add(c, lw); + Store(p, p_addr); + } - SetToFirstInsertPoint(if_data.after_block, b_); + SetToFirstInsertPoint(if_data.false_block, b_); + llvm_ir::LlvmIfData if_data_second = llvm_ir::EmitIfThenElse( + FCmpOLT(w, get_float(16.0)), "w_less_than_16", b_); + SetToFirstInsertPoint(if_data_second.true_block, b_); + { + llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration( + module_, llvm::Intrinsic::sqrt, {b_->getDoubleTy()}); + + llvm::Value* gw = FSub(Call(sqrtf_fn, w), get_float(3.25)); + absl::Span<const double> t1{ + 2.2137376921775787049e-09, 9.0756561938885390979e-08, + -2.7517406297064545428e-07, 1.8239629214389227755e-08, + 1.5027403968909827627e-06, -4.013867526981545969e-06, + 2.9234449089955446044e-06, 1.2475304481671778723e-05, + -4.7318229009055733981e-05, 6.8284851459573175448e-05, + 2.4031110387097893999e-05, -0.0003550375203628474796, + 0.00095328937973738049703, -0.0016882755560235047313, + 0.0024914420961078508066, -0.0037512085075692412107, + 0.005370914553590063617, 1.0052589676941592334, + 3.0838856104922207635}; + llvm::Value* p = multiply_add(t1, gw); + Store(p, p_addr); + } + + SetToFirstInsertPoint(if_data_second.false_block, b_); + { + llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration( + module_, llvm::Intrinsic::sqrt, {b_->getDoubleTy()}); + + llvm::Value* gw = FSub(Call(sqrtf_fn, w), get_float(5.0)); + absl::Span<const double> t2{ + -2.7109920616438573243e-11, -2.5556418169965252055e-10, + 1.5076572693500548083e-09, -3.7894654401267369937e-09, + 7.6157012080783393804e-09, -1.4960026627149240478e-08, + 2.9147953450901080826e-08, -6.7711997758452339498e-08, + 2.2900482228026654717e-07, -9.9298272942317002539e-07, + 4.5260625972231537039e-06, -1.9681778105531670567e-05, + 7.5995277030017761139e-05, -0.00021503011930044477347, + -0.00013871931833623122026, 1.0103004648645343977, + 4.8499064014085844221}; + llvm::Value* p = multiply_add(t2, gw); + Store(p, p_addr); + } + + SetToFirstInsertPoint(if_data.after_block, b_); + } llvm::Value* p = Load(p_addr); - return FMul(p, x); + x = FMul(p, x); + // Trunc back to half if needed. + if (prim_type == F16) { + x = b_->CreateFPTrunc(x, b_->getHalfTy()); + } + return x; } StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfcInv(PrimitiveType prim_type, |