aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/client/lib/arithmetic.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/client/lib/arithmetic.cc')
-rw-r--r--tensorflow/compiler/xla/client/lib/arithmetic.cc84
1 files changed, 84 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc
index a1d34796cc..639f85737f 100644
--- a/tensorflow/compiler/xla/client/lib/arithmetic.cc
+++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc
@@ -121,4 +121,88 @@ StatusOr<XlaOp> Any(const XlaOp& predicates, XlaBuilder* builder) {
return builder->Reduce(predicates, f, logical_or, all_dimensions);
}
+namespace {
+xla::XlaOp FloatLiteral(xla::XlaBuilder* b, PrimitiveType data_type,
+ float value) {
+ return b->ConvertElementType(b->ConstantR0(value), data_type);
+}
+
+// Polynomials for computing erf/erfc. Originally from cephes.
+// Note we use float for compatibility across devices, at the cost of some
+// precision for 64 bit computations.
+//
+// Coefficients are in descending order.
+std::array<float, 9> kErfcPCoefficient = {
+ 2.46196981473530512524E-10, 5.64189564831068821977E-1,
+ 7.46321056442269912687E0, 4.86371970985681366614E1,
+ 1.96520832956077098242E2, 5.26445194995477358631E2,
+ 9.34528527171957607540E2, 1.02755188689515710272E3,
+ 5.57535335369399327526E2};
+std::array<float, 9> kErfcQCoefficient = {
+ 1.00000000000000000000E0, 1.32281951154744992508E1,
+ 8.67072140885989742329E1, 3.54937778887819891062E2,
+ 9.75708501743205489753E2, 1.82390916687909736289E3,
+ 2.24633760818710981792E3, 1.65666309194161350182E3,
+ 5.57535340817727675546E2};
+std::array<float, 6> kErfcRCoefficient = {
+ 5.64189583547755073984E-1, 1.27536670759978104416E0,
+ 5.01905042251180477414E0, 6.16021097993053585195E0,
+ 7.40974269950448939160E0, 2.97886665372100240670E0};
+std::array<float, 7> kErfcSCoefficient = {
+ 1.00000000000000000000E0, 2.26052863220117276590E0,
+ 9.39603524938001434673E0, 1.20489539808096656605E1,
+ 1.70814450747565897222E1, 9.60896809063285878198E0,
+ 3.36907645100081516050E0};
+std::array<float, 5> kErfTCoefficient = {
+ 9.60497373987051638749E0, 9.00260197203842689217E1,
+ 2.23200534594684319226E3, 7.00332514112805075473E3,
+ 5.55923013010394962768E4};
+std::array<float, 6> kErfUCoefficient = {
+ 1.00000000000000000000E0, 3.35617141647503099647E1,
+ 5.21357949780152679795E2, 4.59432382970980127987E3,
+ 2.26290000613890934246E4, 4.92673942608635921086E4};
+} // namespace
+
+// Evaluate the polynomial given coefficients and `x`.
+// N.B. Coefficients should be supplied in decreasing order.
+xla::XlaOp EvaluatePolynomial(xla::XlaBuilder* b, const xla::XlaOp& x,
+ tensorflow::gtl::ArraySlice<float> coefficients,
+ PrimitiveType data_type) {
+ xla::XlaOp poly = FloatLiteral(b, data_type, 0.0);
+ for (float c : coefficients) {
+ poly = b->Add(b->Mul(poly, x), FloatLiteral(b, data_type, c));
+ }
+ return poly;
+}
+
+// Compute an approximation of the error function complement (1 - erf(x)).
+xla::XlaOp ComputeErfc(xla::XlaBuilder* b, const xla::XlaOp& x,
+ PrimitiveType data_type) {
+ xla::XlaOp zero = FloatLiteral(b, data_type, 0.0);
+ xla::XlaOp two = FloatLiteral(b, data_type, 2.0);
+ xla::XlaOp eight = FloatLiteral(b, data_type, 8.0);
+
+ xla::XlaOp abs_x = b->Abs(x);
+ xla::XlaOp z = b->Exp(b->Mul(b->Neg(x), x));
+
+ xla::XlaOp pp = EvaluatePolynomial(b, abs_x, kErfcPCoefficient, data_type);
+ xla::XlaOp pq = EvaluatePolynomial(b, abs_x, kErfcQCoefficient, data_type);
+ xla::XlaOp pr = EvaluatePolynomial(b, abs_x, kErfcRCoefficient, data_type);
+ xla::XlaOp ps = EvaluatePolynomial(b, abs_x, kErfcSCoefficient, data_type);
+
+ xla::XlaOp y = b->Select(b->Lt(abs_x, eight), b->Div(b->Mul(z, pp), pq),
+ b->Div(b->Mul(z, pr), ps));
+
+ return b->Select(b->Lt(x, zero), b->Sub(two, y), y);
+}
+
+// Compute a polynomial approximation of the error function.
+xla::XlaOp ComputeErf(xla::XlaBuilder* b, const xla::XlaOp& x,
+ PrimitiveType data_type) {
+ xla::XlaOp z = b->Mul(x, x);
+ xla::XlaOp pt = EvaluatePolynomial(b, z, kErfTCoefficient, data_type);
+ xla::XlaOp pu = EvaluatePolynomial(b, z, kErfUCoefficient, data_type);
+ return b->Div(b->Mul(x, pt), pu);
+}
+
} // namespace xla