aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/client/lib/arithmetic.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/client/lib/arithmetic.cc')
-rw-r--r--tensorflow/compiler/xla/client/lib/arithmetic.cc145
1 files changed, 9 insertions, 136 deletions
diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc
index 8c314fa61b..1872925aba 100644
--- a/tensorflow/compiler/xla/client/lib/arithmetic.cc
+++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc
@@ -17,8 +17,9 @@ limitations under the License.
#include <string>
+#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
@@ -93,16 +94,18 @@ XlaComputation CreateScalarMinComputation(PrimitiveType type,
});
}
-XlaComputation CreateScalarAndComputation(XlaBuilder* builder) {
+XlaComputation CreateScalarAndComputation(PrimitiveType type,
+ XlaBuilder* builder) {
return CreateScalarComputation(
- "and", PRED, builder,
+ "and", type, builder,
[](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) {
return And(lhs, rhs);
});
}
-XlaComputation CreateScalarOrComputation(XlaBuilder* builder) {
- return CreateScalarComputation("or", PRED, builder,
+XlaComputation CreateScalarOrComputation(PrimitiveType type,
+ XlaBuilder* builder) {
+ return CreateScalarComputation("or", type, builder,
[](XlaBuilder* b, const XlaOp& lhs,
const XlaOp& rhs) { return Or(lhs, rhs); });
}
@@ -111,7 +114,7 @@ XlaOp Any(XlaOp predicates) {
XlaBuilder* builder = predicates.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
auto f = ConstantR0<bool>(builder, false);
- XlaComputation logical_or = CreateScalarOrComputation(builder);
+ XlaComputation logical_or = CreateScalarOrComputation(PRED, builder);
TF_ASSIGN_OR_RETURN(const Shape& predicates_shape,
builder->GetShape(predicates));
std::vector<int64> all_dimensions(ShapeUtil::Rank(predicates_shape));
@@ -120,134 +123,4 @@ XlaOp Any(XlaOp predicates) {
});
}
-namespace {
-XlaOp FloatLiteral(XlaBuilder* b, PrimitiveType data_type, float value) {
- return ConvertElementType(ConstantR0(b, value), data_type);
-}
-
-// Polynomials for computing erf/erfc. Originally from cephes.
-// Note we use float for compatibility across devices, at the cost of some
-// precision for 64 bit computations.
-//
-// Coefficients are in descending order.
-std::array<float, 9> kErfcPCoefficient = {
- 2.46196981473530512524E-10, 5.64189564831068821977E-1,
- 7.46321056442269912687E0, 4.86371970985681366614E1,
- 1.96520832956077098242E2, 5.26445194995477358631E2,
- 9.34528527171957607540E2, 1.02755188689515710272E3,
- 5.57535335369399327526E2};
-std::array<float, 9> kErfcQCoefficient = {
- 1.00000000000000000000E0, 1.32281951154744992508E1,
- 8.67072140885989742329E1, 3.54937778887819891062E2,
- 9.75708501743205489753E2, 1.82390916687909736289E3,
- 2.24633760818710981792E3, 1.65666309194161350182E3,
- 5.57535340817727675546E2};
-std::array<float, 6> kErfcRCoefficient = {
- 5.64189583547755073984E-1, 1.27536670759978104416E0,
- 5.01905042251180477414E0, 6.16021097993053585195E0,
- 7.40974269950448939160E0, 2.97886665372100240670E0};
-std::array<float, 7> kErfcSCoefficient = {
- 1.00000000000000000000E0, 2.26052863220117276590E0,
- 9.39603524938001434673E0, 1.20489539808096656605E1,
- 1.70814450747565897222E1, 9.60896809063285878198E0,
- 3.36907645100081516050E0};
-std::array<float, 5> kErfTCoefficient = {
- 9.60497373987051638749E0, 9.00260197203842689217E1,
- 2.23200534594684319226E3, 7.00332514112805075473E3,
- 5.55923013010394962768E4};
-std::array<float, 6> kErfUCoefficient = {
- 1.00000000000000000000E0, 3.35617141647503099647E1,
- 5.21357949780152679795E2, 4.59432382970980127987E3,
- 2.26290000613890934246E4, 4.92673942608635921086E4};
-} // namespace
-
-// Evaluate the polynomial given coefficients and `x`.
-// N.B. Coefficients should be supplied in decreasing order.
-XlaOp EvaluatePolynomial(XlaOp x,
- tensorflow::gtl::ArraySlice<float> coefficients,
- PrimitiveType data_type) {
- XlaBuilder* b = x.builder();
- XlaOp poly = FloatLiteral(b, data_type, 0.0);
- for (float c : coefficients) {
- poly = Add(Mul(poly, x), FloatLiteral(b, data_type, c));
- }
- return poly;
-}
-
-// Compute an approximation of the error function complement (1 - erf(x)).
-XlaOp Erfc(XlaOp x, PrimitiveType data_type) {
- XlaBuilder* b = x.builder();
- XlaOp zero = FloatLiteral(b, data_type, 0.0);
- XlaOp two = FloatLiteral(b, data_type, 2.0);
- XlaOp eight = FloatLiteral(b, data_type, 8.0);
-
- XlaOp abs_x = Abs(x);
- XlaOp z = Exp(Mul(Neg(x), x));
-
- XlaOp pp = EvaluatePolynomial(abs_x, kErfcPCoefficient, data_type);
- XlaOp pq = EvaluatePolynomial(abs_x, kErfcQCoefficient, data_type);
- XlaOp pr = EvaluatePolynomial(abs_x, kErfcRCoefficient, data_type);
- XlaOp ps = EvaluatePolynomial(abs_x, kErfcSCoefficient, data_type);
-
- XlaOp y = Select(Lt(abs_x, eight), Div(Mul(z, pp), pq), Div(Mul(z, pr), ps));
-
- return Select(Lt(x, zero), Sub(two, y), y);
-}
-
-// Compute a polynomial approximation of the error function.
-XlaOp Erf(XlaOp x, PrimitiveType data_type) {
- XlaOp z = Mul(x, x);
- XlaOp pt = EvaluatePolynomial(z, kErfTCoefficient, data_type);
- XlaOp pu = EvaluatePolynomial(z, kErfUCoefficient, data_type);
- return Div(Mul(x, pt), pu);
-}
-
-// Approximation for the inverse error function from
-// Giles, M., "Approximating the erfinv function".
-// The approximation has the form:
-// w = -log((1 - x) * (1 + x))
-// if ( w < 5 ) {
-// w = w - 2.5
-// p = sum_{i=1}^n lq[i]*w^i
-// } else {
-// w = sqrt(w) - 3
-// p = sum_{i=1}^n gq[i]*w^i
-// }
-// return p*x
-XlaOp ErfInv(XlaOp x) {
- XlaBuilder* b = x.builder();
- return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
- TF_ASSIGN_OR_RETURN(Shape shape, b->GetShape(x));
- constexpr int kDegree = 9;
- constexpr std::array<float, 9> w_less_than_5_constants = {
- 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f,
- -4.39150654e-06f, 0.00021858087f, -0.00125372503f,
- -0.00417768164f, 0.246640727f, 1.50140941f};
- constexpr std::array<float, 9> w_greater_than_5_constants = {
- -0.000200214257f, 0.000100950558f, 0.00134934322f,
- -0.00367342844f, 0.00573950773f, -0.0076224613f,
- 0.00943887047f, 1.00167406f, 2.83297682f};
-
- auto one = ConstantR0<float>(b, 1.0);
- auto w = Neg(Log(Mul(Sub(one, x), Add(one, x))));
-
- auto lt = Lt(w, ConstantR0<float>(b, 5.0));
- auto coefficient = [&](int i) {
- return Select(
- lt,
- Broadcast(ConstantR0<float>(b, w_less_than_5_constants[i]),
- AsInt64Slice(shape.dimensions())),
- Broadcast(ConstantR0<float>(b, w_greater_than_5_constants[i]),
- AsInt64Slice(shape.dimensions())));
- };
- w = Select(lt, Sub(w, ConstantR0<float>(b, 2.5f)),
- Sub(SqrtF32(w), ConstantR0<float>(b, 3.0f)));
- auto p = coefficient(0);
- for (int i = 1; i < kDegree; ++i) {
- p = Add(coefficient(i), Mul(p, w));
- }
- return Mul(p, x);
- });
-}
-
} // namespace xla