aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Peter Hawkins <phawkins@google.com>2018-07-02 07:35:19 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-02 07:38:15 -0700
commitf52346cb9b72a4109bd2ff9198b4d7588758758c (patch)
tree235f3d4995425b3194d8bcdf196ceafb53571638
parente0ebc3dc4f64d84c6acd5f2ff3574e7f2b3a8fbf (diff)
[XLA] Rename {SqrtF32, SquareF32, ReciprocalF32} to {Sqrt, Square, Reciprocal} and move them to a new client library xla/client/lib/math.h. Remove the F32 type constraint.
Add an xla::Rqsrt function. Move {Erf, Erfc, ErfInv, EvaluatePolynomial} to the same library. [TF:XLA] Update many places in the bridge to use the new functions. Rewrite many of the training ops in operator notation. PiperOrigin-RevId: 202948474
-rw-r--r--tensorflow/compiler/tf2xla/kernels/BUILD1
-rw-r--r--tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/training_ops.cc72
-rw-r--r--tensorflow/compiler/tf2xla/kernels/unary_ops.cc121
-rw-r--r--tensorflow/compiler/tf2xla/lib/BUILD2
-rw-r--r--tensorflow/compiler/tf2xla/lib/random.cc2
-rw-r--r--tensorflow/compiler/xla/client/lib/BUILD27
-rw-r--r--tensorflow/compiler/xla/client/lib/arithmetic.cc118
-rw-r--r--tensorflow/compiler/xla/client/lib/arithmetic.h14
-rw-r--r--tensorflow/compiler/xla/client/lib/math.cc152
-rw-r--r--tensorflow/compiler/xla/client/lib/math.h51
-rw-r--r--tensorflow/compiler/xla/client/lib/math_test.cc85
-rw-r--r--tensorflow/compiler/xla/client/xla_client/xla_builder.cc27
-rw-r--r--tensorflow/compiler/xla/client/xla_client/xla_builder.h39
-rw-r--r--tensorflow/compiler/xla/python/BUILD2
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.cc7
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.h6
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.i6
-rw-r--r--tensorflow/compiler/xla/python/xla_client.py6
-rw-r--r--tensorflow/compiler/xla/tests/BUILD1
-rw-r--r--tensorflow/compiler/xla/tests/batch_normalization_test.cc13
-rw-r--r--tensorflow/compiler/xla/tests/scalar_computations_test.cc12
-rw-r--r--tensorflow/compiler/xla/tests/vector_ops_simple_test.cc40
23 files changed, 411 insertions, 395 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 93d0e22d4a..a8eb7d942d 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -122,6 +122,7 @@ tf_kernel_library(
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/compiler/xla/client/lib:constants",
+ "//tensorflow/compiler/xla/client/lib:math",
"//tensorflow/compiler/xla/client/lib:numeric",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/core:framework",
diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc
index f1059856c8..a6f5769e7b 100644
--- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc
@@ -20,8 +20,8 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/numeric.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc
index 2e5d61e111..f3e112c7b3 100644
--- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc
@@ -16,6 +16,8 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
@@ -47,7 +49,7 @@ class ResourceApplyGradientDescent : public XlaOpKernel {
var_shape.DebugString(), " vs ",
delta_shape.DebugString()));
- handle = xla::Sub(handle, xla::Mul(ctx->Input(1), ctx->Input(2)));
+ handle = handle - ctx->Input(1) * ctx->Input(2);
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle));
}
};
@@ -94,14 +96,13 @@ class ResourceApplyMomentum : public XlaOpKernel {
xla::XlaOp grad = ctx->Input(3);
xla::XlaOp momentum = ctx->Input(4);
- accum = xla::Add(xla::Mul(accum, momentum), grad);
+ accum = accum * momentum + grad;
if (use_nesterov_) {
// See https://github.com/tensorflow/tensorflow/pull/2798 for an
// explanation of the reparameterization used here.
- var = xla::Sub(var, xla::Add(xla::Mul(grad, lr),
- xla::Mul(xla::Mul(accum, momentum), lr)));
+ var = var - (grad * lr + accum * momentum * lr);
} else {
- var = xla::Sub(var, xla::Mul(accum, lr));
+ var = var - accum * lr;
}
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var));
OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum));
@@ -118,8 +119,6 @@ class ResourceApplyAdagrad : public XlaOpKernel {
explicit ResourceApplyAdagrad(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
- xla::XlaBuilder* b = ctx->builder();
-
DataType type = ctx->input_type(2);
TensorShape var_shape, accum_shape;
@@ -146,12 +145,8 @@ class ResourceApplyAdagrad : public XlaOpKernel {
xla::XlaOp lr = ctx->Input(2);
xla::XlaOp grad = ctx->Input(3);
- accum =
- xla::Add(accum, xla::Pow(grad, XlaHelpers::FloatLiteral(b, type, 2.0)));
- var = xla::Sub(
- var,
- xla::Mul(xla::Mul(grad, lr),
- xla::Pow(accum, XlaHelpers::FloatLiteral(b, type, -0.5))));
+ accum = accum + xla::Square(grad);
+ var = var - grad * lr * xla::Rsqrt(accum);
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var));
OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum));
}
@@ -226,18 +221,12 @@ class ResourceApplyAdam : public XlaOpKernel {
// variable <- variable - alpha * m_t / (sqrt(v_t) + epsilon)
xla::XlaBuilder* b = ctx->builder();
- xla::XlaOp half = XlaHelpers::FloatLiteral(b, dtype_, 0.5);
xla::XlaOp one = XlaHelpers::FloatLiteral(b, dtype_, 1.0);
- xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype_, 2.0);
- xla::XlaOp alpha =
- xla::Div(xla::Mul(lr, xla::Pow(xla::Sub(one, beta2_power), half)),
- xla::Sub(one, beta1_power));
- m = xla::Add(m, xla::Mul(xla::Sub(grad, m), xla::Sub(one, beta1)));
- v = xla::Add(
- v, xla::Mul(xla::Sub(xla::Pow(grad, two), v), xla::Sub(one, beta2)));
- var = xla::Sub(var, xla::Div(xla::Mul(m, alpha),
- xla::Add(xla::Pow(v, half), epsilon)));
+ xla::XlaOp alpha = lr * xla::Sqrt(one - beta2_power) / (one - beta1_power);
+ m = m + (grad - m) * (one - beta1);
+ v = v + (xla::Square(grad) - v) * (one - beta2);
+ var = var - m * alpha / (xla::Sqrt(v) + epsilon);
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, m));
@@ -255,8 +244,6 @@ class ResourceApplyRMSProp : public XlaOpKernel {
explicit ResourceApplyRMSProp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
- xla::XlaBuilder* b = ctx->builder();
-
DataType type = ctx->input_type(3);
TensorShape var_shape, ms_shape, mom_shape;
@@ -320,17 +307,11 @@ class ResourceApplyRMSProp : public XlaOpKernel {
// ms <- grad**2 (1 - rho) + ms * rho
//
// Which is the equation listed above.
- xla::XlaOp new_ms = xla::Add(
- ms, xla::Mul(
- xla::Sub(xla::Pow(grad, XlaHelpers::FloatLiteral(b, type, 2.0)),
- ms),
- xla::Sub(XlaHelpers::FloatLiteral(b, type, 1.0), rho)));
+ xla::XlaOp new_ms =
+ ms + (xla::Square(grad) - ms) * (xla::ScalarLike(ms, 1.0) - rho);
xla::XlaOp new_mom =
- xla::Add(xla::Mul(mom, momentum),
- xla::Mul(xla::Mul(grad, lr),
- xla::Pow(xla::Add(new_ms, epsilon),
- XlaHelpers::FloatLiteral(b, type, -0.5))));
- xla::XlaOp new_var = xla::Sub(var, new_mom);
+ mom * momentum + grad * lr * xla::Rsqrt(new_ms + epsilon);
+ xla::XlaOp new_var = var - new_mom;
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, new_var));
OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, new_ms));
@@ -425,23 +406,18 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype,
xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype, 2.0);
xla::XlaOp grad_to_use;
if (has_l2_shrinkage) {
- grad_to_use = xla::Add(grad, xla::Mul(two, xla::Mul(l2_shrinkage, var)));
+ grad_to_use = grad + two * l2_shrinkage * var;
} else {
grad_to_use = grad;
}
- xla::XlaOp new_accum = xla::Add(accum, xla::Pow(grad_to_use, two));
- xla::XlaOp new_accum_lr_pow = xla::Pow(new_accum, xla::Neg(lr_power));
- xla::XlaOp accum_lr_pow = xla::Pow(accum, xla::Neg(lr_power));
- linear = xla::Add(
- linear,
- xla::Sub(grad_to_use,
- xla::Mul(xla::Div(xla::Sub(new_accum_lr_pow, accum_lr_pow), lr),
- var)));
- xla::XlaOp linear_clipped = xla::Clamp(xla::Neg(l1), linear, l1);
- xla::XlaOp quadratic =
- xla::Add(xla::Div(new_accum_lr_pow, lr), xla::Mul(two, l2));
- var = xla::Div(xla::Sub(linear_clipped, linear), quadratic);
+ xla::XlaOp new_accum = accum + xla::Square(grad_to_use);
+ xla::XlaOp new_accum_lr_pow = xla::Pow(new_accum, -lr_power);
+ xla::XlaOp accum_lr_pow = xla::Pow(accum, -lr_power);
+ linear = linear + grad_to_use - (new_accum_lr_pow - accum_lr_pow) / lr * var;
+ xla::XlaOp linear_clipped = xla::Clamp(-l1, linear, l1);
+ xla::XlaOp quadratic = new_accum_lr_pow / lr + two * l2;
+ var = (linear_clipped - linear) / quadratic;
accum = new_accum;
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype, var));
diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
index ce894f1faa..116a020437 100644
--- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
@@ -52,56 +53,36 @@ XLAJIT_MAKE_UNARY(Conj, xla::Conj(x));
XLAJIT_MAKE_UNARY(Abs, xla::Abs(x));
// acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x))
-XLAJIT_MAKE_UNARY(
- Acos,
- xla::Mul(XlaHelpers::FloatLiteral(b, input_type(0), 2.0),
- xla::Atan2(xla::Pow(xla::Sub(XlaHelpers::One(b, input_type(0)),
- xla::Mul(x, x)),
- XlaHelpers::FloatLiteral(b, input_type(0),
- 0.5)),
- xla::Add(XlaHelpers::One(b, input_type(0)), x))));
+XLAJIT_MAKE_UNARY(Acos,
+ xla::ScalarLike(x, 2.0) *
+ xla::Atan2(xla::Sqrt(xla::ScalarLike(x, 1.0) - x * x),
+ xla::ScalarLike(x, 1.0) + x));
// acosh(x) = log(x + sqrt(x^2 - 1))
// = log(x + sqrt((x+1)*(x-1)))
-XLAJIT_MAKE_UNARY(
- Acosh,
- xla::Log(xla::Add(
- x, xla::Pow(xla::Mul(xla::Add(x, XlaHelpers::One(b, input_type(0))),
- xla::Sub(x, XlaHelpers::One(b, input_type(0)))),
- XlaHelpers::FloatLiteral(b, input_type(0), 0.5)))));
+XLAJIT_MAKE_UNARY(Acosh,
+ xla::Log(x + xla::Sqrt((x + xla::ScalarLike(x, 1.0)) *
+ (x - xla::ScalarLike(x, 1.0)))));
// asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2)))
XLAJIT_MAKE_UNARY(
- Asin,
- xla::Mul(
- XlaHelpers::FloatLiteral(b, input_type(0), 2.0),
- xla::Atan2(x,
- xla::Add(XlaHelpers::One(b, input_type(0)),
- xla::Pow(xla::Sub(XlaHelpers::One(b, input_type(0)),
- xla::Mul(x, x)),
- XlaHelpers::FloatLiteral(b, input_type(0),
- 0.5))))));
+ Asin, xla::ScalarLike(x, 2.0) *
+ xla::Atan2(x, xla::ScalarLike(x, 1.0) +
+ xla::Sqrt(xla::ScalarLike(x, 1.0) - x * x)));
// asinh(x) = log(x + sqrt(x^2 + 1))
-XLAJIT_MAKE_UNARY(
- Asinh,
- xla::Log(xla::Add(
- x, xla::Pow(xla::Add(xla::Mul(x, x), XlaHelpers::One(b, input_type(0))),
- XlaHelpers::FloatLiteral(b, input_type(0), 0.5)))));
+XLAJIT_MAKE_UNARY(Asinh,
+ xla::Log(x + xla::Sqrt(x * x + xla::ScalarLike(x, 1.0))));
-XLAJIT_MAKE_UNARY(Atan, xla::Atan2(x, XlaHelpers::One(b, input_type(0))));
+XLAJIT_MAKE_UNARY(Atan, xla::Atan2(x, xla::ScalarLike(x, 1.0)));
// atanh(x) = 0.5 * log((1 + x) / (1 - x))
-XLAJIT_MAKE_UNARY(
- Atanh,
- xla::Mul(xla::Log(xla::Div(xla::Add(XlaHelpers::One(b, input_type(0)), x),
- xla::Sub(XlaHelpers::One(b, input_type(0)), x))),
- XlaHelpers::FloatLiteral(b, input_type(0), 0.5)));
+XLAJIT_MAKE_UNARY(Atanh, xla::Log((xla::ScalarLike(x, 1.0) + x) /
+ (xla::ScalarLike(x, 1.0) - x)) *
+ xla::ScalarLike(x, 0.5));
XLAJIT_MAKE_UNARY(Ceil, xla::Ceil(x));
XLAJIT_MAKE_UNARY(Cos, xla::Cos(x));
-XLAJIT_MAKE_UNARY(Cosh,
- xla::Mul(xla::Add(xla::Exp(x), xla::Exp(xla::Neg(x))),
- XlaHelpers::FloatLiteral(b, input_type(0), 0.5)));
+XLAJIT_MAKE_UNARY(Cosh, (xla::Exp(x) + xla::Exp(-x)) * xla::ScalarLike(x, 0.5));
XLAJIT_MAKE_UNARY(Sin, xla::Sin(x));
XLAJIT_MAKE_UNARY(Exp, xla::Exp(x));
@@ -109,59 +90,53 @@ XLAJIT_MAKE_UNARY(Expm1, xla::Expm1(x));
XLAJIT_MAKE_UNARY(Floor, xla::Floor(x));
XLAJIT_MAKE_UNARY(IsFinite, xla::IsFinite(x));
-XLAJIT_MAKE_UNARY(IsInf, xla::Eq(xla::Abs(x),
- XlaHelpers::FloatLiteral(
- b, input_type(0),
- std::numeric_limits<double>::infinity())));
+XLAJIT_MAKE_UNARY(
+ IsInf,
+ xla::Eq(xla::Abs(x),
+ xla::ScalarLike(x, std::numeric_limits<double>::infinity())));
XLAJIT_MAKE_UNARY(IsNan, xla::Ne(x, x));
// Return 1/x
-XLAJIT_MAKE_UNARY(Inv, xla::Div(XlaHelpers::One(b, input_type(0)), x));
-XLAJIT_MAKE_UNARY(Reciprocal, xla::Div(XlaHelpers::One(b, input_type(0)), x));
+XLAJIT_MAKE_UNARY(Inv, xla::ScalarLike(x, 1.0) / x);
+XLAJIT_MAKE_UNARY(Reciprocal, xla::ScalarLike(x, 1.0) / x);
XLAJIT_MAKE_UNARY(Log, xla::Log(x));
XLAJIT_MAKE_UNARY(Log1p, xla::Log1p(x));
XLAJIT_MAKE_UNARY(Invert, xla::Not(x));
XLAJIT_MAKE_UNARY(LogicalNot, xla::Not(x));
-XLAJIT_MAKE_UNARY(Neg, xla::Neg(x));
+XLAJIT_MAKE_UNARY(Neg, -x);
// Implements Banker's rounding: numbers that are equidistant between two
// integers are rounded towards even.
-static xla::XlaOp Round(xla::XlaBuilder* b, DataType dtype,
- const xla::XlaOp& x) {
- auto half = XlaHelpers::FloatLiteral(b, dtype, 0.5);
- auto one = XlaHelpers::FloatLiteral(b, dtype, 1.0);
- auto two = XlaHelpers::FloatLiteral(b, dtype, 2.0);
+xla::XlaOp RoundToEven(xla::XlaOp x) {
+ auto half = xla::ScalarLike(x, 0.5);
+ auto one = xla::ScalarLike(x, 1.0);
+ auto two = xla::ScalarLike(x, 2.0);
auto round_val = xla::Floor(x);
- auto fraction = xla::Sub(x, round_val);
- auto nearest_even_int =
- xla::Sub(round_val, xla::Mul(two, xla::Floor(xla::Mul(half, x))));
+ auto fraction = x - round_val;
+ auto nearest_even_int = round_val - two * xla::Floor(half * x);
auto is_odd = xla::Eq(nearest_even_int, one);
return xla::Select(xla::Or(xla::Gt(fraction, half),
xla::And(xla::Eq(fraction, half), is_odd)),
- xla::Add(round_val, one), round_val);
+ round_val + one, round_val);
}
-XLAJIT_MAKE_UNARY(Rint, Round(b, input_type(0), x));
-XLAJIT_MAKE_UNARY(Round, Round(b, input_type(0), x));
+XLAJIT_MAKE_UNARY(Rint, RoundToEven(x));
+XLAJIT_MAKE_UNARY(Round, RoundToEven(x));
-XLAJIT_MAKE_UNARY(Rsqrt, xla::Pow(x, XlaHelpers::FloatLiteral(b, input_type(0),
- -0.5)));
+XLAJIT_MAKE_UNARY(Rsqrt, xla::Rsqrt(x));
// Expresses sigmoid as a rescaled tanh: sigmoid(x) == (tanh(x/2) + 1) / 2.
-static xla::XlaOp Sigmoid(xla::XlaBuilder* b, DataType dtype,
- const xla::XlaOp& x) {
- auto half = XlaHelpers::FloatLiteral(b, dtype, 0.5);
- return xla::Add(half, xla::Mul(half, xla::Tanh(xla::Mul(half, x))));
+xla::XlaOp Sigmoid(xla::XlaOp x) {
+ auto half = xla::ScalarLike(x, 0.5);
+ return half + half * xla::Tanh(half * x);
}
-XLAJIT_MAKE_UNARY(Sigmoid, Sigmoid(b, input_type(0), x));
+XLAJIT_MAKE_UNARY(Sigmoid, Sigmoid(x));
// Returns 0 if x is 0, -1 if x < 0 and 1 if x > 0.
XLAJIT_MAKE_UNARY(Sign, xla::Sign(x));
-XLAJIT_MAKE_UNARY(Sinh,
- xla::Mul(xla::Sub(xla::Exp(x), xla::Exp(xla::Neg(x))),
- XlaHelpers::FloatLiteral(b, input_type(0), 0.5)));
+XLAJIT_MAKE_UNARY(Sinh, (xla::Exp(x) - xla::Exp(-x)) * xla::ScalarLike(x, 0.5));
// softplus(x) = log(1 + exp(x))
//
@@ -171,18 +146,14 @@ XLAJIT_MAKE_UNARY(Sinh,
//
// This is equivalent to:
// max(x, 0) + log1p(exp(-abs(x)))
-XLAJIT_MAKE_UNARY(Softplus,
- xla::Add(xla::Max(x, XlaHelpers::Zero(b, input_type(0))),
- xla::Log1p(xla::Exp(xla::Neg(xla::Abs(x))))));
+XLAJIT_MAKE_UNARY(Softplus, xla::Max(x, xla::ScalarLike(x, 0.0)) +
+ xla::Log1p(xla::Exp(-xla::Abs(x))));
// softsign(x) = x / (abs(x) + 1)
-XLAJIT_MAKE_UNARY(Softsign,
- xla::Div(x, xla::Add(xla::Abs(x),
- XlaHelpers::One(b, input_type(0)))));
-XLAJIT_MAKE_UNARY(Sqrt,
- xla::Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), 0.5)));
-XLAJIT_MAKE_UNARY(Square, xla::Mul(x, x));
-XLAJIT_MAKE_UNARY(Tan, xla::Div(xla::Sin(x), xla::Cos(x)));
+XLAJIT_MAKE_UNARY(Softsign, x / (xla::Abs(x) + xla::ScalarLike(x, 1.0)));
+XLAJIT_MAKE_UNARY(Sqrt, xla::Sqrt(x));
+XLAJIT_MAKE_UNARY(Square, x* x);
+XLAJIT_MAKE_UNARY(Tan, xla::Sin(x) / xla::Cos(x));
XLAJIT_MAKE_UNARY(Tanh, xla::Tanh(x));
XLAJIT_MAKE_UNARY(Real, xla::Real(x));
diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD
index cd35a54aa9..dfa3c0595a 100644
--- a/tensorflow/compiler/tf2xla/lib/BUILD
+++ b/tensorflow/compiler/tf2xla/lib/BUILD
@@ -59,8 +59,8 @@ cc_library(
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
- "//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/compiler/xla/client/lib:constants",
+ "//tensorflow/compiler/xla/client/lib:math",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/core:protos_all_cc",
],
diff --git a/tensorflow/compiler/tf2xla/lib/random.cc b/tensorflow/compiler/tf2xla/lib/random.cc
index 8e6a0d0b19..8ff10fbd3f 100644
--- a/tensorflow/compiler/tf2xla/lib/random.cc
+++ b/tensorflow/compiler/tf2xla/lib/random.cc
@@ -19,8 +19,8 @@ limitations under the License.
#include <limits>
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
-#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/status_macros.h"
diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD
index 1b4706ed20..a6b9b47253 100644
--- a/tensorflow/compiler/xla/client/lib/BUILD
+++ b/tensorflow/compiler/xla/client/lib/BUILD
@@ -65,6 +65,33 @@ xla_test(
)
cc_library(
+ name = "math",
+ srcs = ["math.cc"],
+ hdrs = ["math.h"],
+ deps = [
+ ":constants",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ ],
+)
+
+xla_test(
+ name = "math_test",
+ srcs = ["math_test.cc"],
+ tags = ["enable_for_xla_interpreter"],
+ deps = [
+ ":math",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ ],
+)
+
+cc_library(
name = "numeric",
srcs = ["numeric.cc"],
hdrs = ["numeric.h"],
diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc
index 86f35c0c0d..978fc40f34 100644
--- a/tensorflow/compiler/xla/client/lib/arithmetic.cc
+++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc
@@ -121,122 +121,4 @@ XlaOp Any(XlaOp predicates) {
});
}
-namespace {
-
-// Polynomials for computing erf/erfc. Originally from cephes.
-// Note we use float for compatibility across devices, at the cost of some
-// precision for 64 bit computations.
-//
-// Coefficients are in descending order.
-std::array<float, 9> kErfcPCoefficient = {
- 2.46196981473530512524E-10, 5.64189564831068821977E-1,
- 7.46321056442269912687E0, 4.86371970985681366614E1,
- 1.96520832956077098242E2, 5.26445194995477358631E2,
- 9.34528527171957607540E2, 1.02755188689515710272E3,
- 5.57535335369399327526E2};
-std::array<float, 9> kErfcQCoefficient = {
- 1.00000000000000000000E0, 1.32281951154744992508E1,
- 8.67072140885989742329E1, 3.54937778887819891062E2,
- 9.75708501743205489753E2, 1.82390916687909736289E3,
- 2.24633760818710981792E3, 1.65666309194161350182E3,
- 5.57535340817727675546E2};
-std::array<float, 6> kErfcRCoefficient = {
- 5.64189583547755073984E-1, 1.27536670759978104416E0,
- 5.01905042251180477414E0, 6.16021097993053585195E0,
- 7.40974269950448939160E0, 2.97886665372100240670E0};
-std::array<float, 7> kErfcSCoefficient = {
- 1.00000000000000000000E0, 2.26052863220117276590E0,
- 9.39603524938001434673E0, 1.20489539808096656605E1,
- 1.70814450747565897222E1, 9.60896809063285878198E0,
- 3.36907645100081516050E0};
-std::array<float, 5> kErfTCoefficient = {
- 9.60497373987051638749E0, 9.00260197203842689217E1,
- 2.23200534594684319226E3, 7.00332514112805075473E3,
- 5.55923013010394962768E4};
-std::array<float, 6> kErfUCoefficient = {
- 1.00000000000000000000E0, 3.35617141647503099647E1,
- 5.21357949780152679795E2, 4.59432382970980127987E3,
- 2.26290000613890934246E4, 4.92673942608635921086E4};
-} // namespace
-
-// Evaluate the polynomial given coefficients and `x`.
-// N.B. Coefficients should be supplied in decreasing order.
-XlaOp EvaluatePolynomial(XlaOp x,
- tensorflow::gtl::ArraySlice<float> coefficients) {
- XlaOp poly = ScalarLike(x, 0.0);
- for (float c : coefficients) {
- poly = poly * x + ScalarLike(x, c);
- }
- return poly;
-}
-
-// Compute an approximation of the error function complement (1 - erf(x)).
-XlaOp Erfc(XlaOp x) {
- XlaOp abs_x = Abs(x);
- XlaOp z = Exp(-x * x);
-
- XlaOp pp = EvaluatePolynomial(abs_x, kErfcPCoefficient);
- XlaOp pq = EvaluatePolynomial(abs_x, kErfcQCoefficient);
- XlaOp pr = EvaluatePolynomial(abs_x, kErfcRCoefficient);
- XlaOp ps = EvaluatePolynomial(abs_x, kErfcSCoefficient);
-
- XlaOp y = Select(Lt(abs_x, ScalarLike(x, 8.0)), z * pp / pq, z * pr / ps);
-
- return Select(Lt(x, ScalarLike(x, 0.0)), Sub(ScalarLike(x, 2.0), y), y);
-}
-
-// Compute a polynomial approximation of the error function.
-XlaOp Erf(XlaOp x) {
- XlaOp z = x * x;
- XlaOp pt = EvaluatePolynomial(z, kErfTCoefficient);
- XlaOp pu = EvaluatePolynomial(z, kErfUCoefficient);
- return x * pt / pu;
-}
-
-// Approximation for the inverse error function from
-// Giles, M., "Approximating the erfinv function".
-// The approximation has the form:
-// w = -log((1 - x) * (1 + x))
-// if ( w < 5 ) {
-// w = w - 2.5
-// p = sum_{i=1}^n lq[i]*w^i
-// } else {
-// w = sqrt(w) - 3
-// p = sum_{i=1}^n gq[i]*w^i
-// }
-// return p*x
-XlaOp ErfInv(XlaOp x) {
- XlaBuilder* b = x.builder();
- return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
- TF_ASSIGN_OR_RETURN(Shape shape, b->GetShape(x));
- constexpr int kDegree = 9;
- constexpr std::array<float, 9> w_less_than_5_constants = {
- 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f,
- -4.39150654e-06f, 0.00021858087f, -0.00125372503f,
- -0.00417768164f, 0.246640727f, 1.50140941f};
- constexpr std::array<float, 9> w_greater_than_5_constants = {
- -0.000200214257f, 0.000100950558f, 0.00134934322f,
- -0.00367342844f, 0.00573950773f, -0.0076224613f,
- 0.00943887047f, 1.00167406f, 2.83297682f};
-
- auto one = ScalarLike(x, 1.0);
- auto w = -Log((one - x) * (one + x));
-
- auto lt = Lt(w, ScalarLike(x, 5.0));
- auto coefficient = [&](int i) {
- return Select(lt,
- Broadcast(ScalarLike(x, w_less_than_5_constants[i]),
- AsInt64Slice(shape.dimensions())),
- Broadcast(ScalarLike(x, w_greater_than_5_constants[i]),
- AsInt64Slice(shape.dimensions())));
- };
- w = Select(lt, w - ScalarLike(x, 2.5), SqrtF32(w) - ScalarLike(x, 3.0));
- auto p = coefficient(0);
- for (int i = 1; i < kDegree; ++i) {
- p = coefficient(i) + p * w;
- }
- return p * x;
- });
-}
-
} // namespace xla
diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.h b/tensorflow/compiler/xla/client/lib/arithmetic.h
index 4bd3b615df..d0b916e8c8 100644
--- a/tensorflow/compiler/xla/client/lib/arithmetic.h
+++ b/tensorflow/compiler/xla/client/lib/arithmetic.h
@@ -55,20 +55,6 @@ XlaComputation CreateScalarOrComputation(XlaBuilder* builder);
// Note: if predicates is zero-sized, Any() vacuously returns false.
XlaOp Any(XlaOp predicates);
-// Evaluate the polynomial given coefficients and `x`.
-// N.B. Coefficients should be supplied in decreasing order.
-XlaOp EvaluatePolynomial(XlaOp x,
- tensorflow::gtl::ArraySlice<float> coefficients);
-
-// Compute an approximation of the error function complement (1 - erf(x)).
-XlaOp Erfc(XlaOp x);
-
-// Compute an approximation of the error function.
-XlaOp Erf(XlaOp x);
-
-// Compute an approximation of the inverse of the error function.
-XlaOp ErfInv(XlaOp x);
-
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_ARITHMETIC_H_
diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc
new file mode 100644
index 0000000000..5587559040
--- /dev/null
+++ b/tensorflow/compiler/xla/client/lib/math.cc
@@ -0,0 +1,152 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/client/lib/math.h"
+
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+
+namespace xla {
+
+XlaOp Sqrt(XlaOp operand) { return Pow(operand, ScalarLike(operand, 0.5)); }
+
+XlaOp Rsqrt(XlaOp operand) { return Pow(operand, ScalarLike(operand, -0.5)); }
+
+XlaOp Square(XlaOp operand) { return Pow(operand, ScalarLike(operand, 2.0)); }
+
+XlaOp Reciprocal(XlaOp operand) {
+ return Pow(operand, ScalarLike(operand, -1.0));
+}
+
+namespace {
+
+// Polynomials for computing erf/erfc. Originally from cephes.
+// Note we use float for compatibility across devices, at the cost of some
+// precision for 64 bit computations.
+//
+// Coefficients are in descending order.
+std::array<float, 9> kErfcPCoefficient = {
+ 2.46196981473530512524E-10, 5.64189564831068821977E-1,
+ 7.46321056442269912687E0, 4.86371970985681366614E1,
+ 1.96520832956077098242E2, 5.26445194995477358631E2,
+ 9.34528527171957607540E2, 1.02755188689515710272E3,
+ 5.57535335369399327526E2};
+std::array<float, 9> kErfcQCoefficient = {
+ 1.00000000000000000000E0, 1.32281951154744992508E1,
+ 8.67072140885989742329E1, 3.54937778887819891062E2,
+ 9.75708501743205489753E2, 1.82390916687909736289E3,
+ 2.24633760818710981792E3, 1.65666309194161350182E3,
+ 5.57535340817727675546E2};
+std::array<float, 6> kErfcRCoefficient = {
+ 5.64189583547755073984E-1, 1.27536670759978104416E0,
+ 5.01905042251180477414E0, 6.16021097993053585195E0,
+ 7.40974269950448939160E0, 2.97886665372100240670E0};
+std::array<float, 7> kErfcSCoefficient = {
+ 1.00000000000000000000E0, 2.26052863220117276590E0,
+ 9.39603524938001434673E0, 1.20489539808096656605E1,
+ 1.70814450747565897222E1, 9.60896809063285878198E0,
+ 3.36907645100081516050E0};
+std::array<float, 5> kErfTCoefficient = {
+ 9.60497373987051638749E0, 9.00260197203842689217E1,
+ 2.23200534594684319226E3, 7.00332514112805075473E3,
+ 5.55923013010394962768E4};
+std::array<float, 6> kErfUCoefficient = {
+ 1.00000000000000000000E0, 3.35617141647503099647E1,
+ 5.21357949780152679795E2, 4.59432382970980127987E3,
+ 2.26290000613890934246E4, 4.92673942608635921086E4};
+} // namespace
+
+// Evaluate the polynomial given coefficients and `x`.
+// N.B. Coefficients should be supplied in decreasing order.
+XlaOp EvaluatePolynomial(XlaOp x,
+ tensorflow::gtl::ArraySlice<float> coefficients) {
+ XlaOp poly = ScalarLike(x, 0.0);
+ for (float c : coefficients) {
+ poly = poly * x + ScalarLike(x, c);
+ }
+ return poly;
+}
+
+// Compute an approximation of the error function complement (1 - erf(x)).
+XlaOp Erfc(XlaOp x) {
+ XlaOp abs_x = Abs(x);
+ XlaOp z = Exp(-x * x);
+
+ XlaOp pp = EvaluatePolynomial(abs_x, kErfcPCoefficient);
+ XlaOp pq = EvaluatePolynomial(abs_x, kErfcQCoefficient);
+ XlaOp pr = EvaluatePolynomial(abs_x, kErfcRCoefficient);
+ XlaOp ps = EvaluatePolynomial(abs_x, kErfcSCoefficient);
+
+ XlaOp y = Select(Lt(abs_x, ScalarLike(x, 8.0)), z * pp / pq, z * pr / ps);
+
+ return Select(Lt(x, ScalarLike(x, 0.0)), ScalarLike(x, 2.0) - y, y);
+}
+
+// Compute a polynomial approximation of the error function.
+XlaOp Erf(XlaOp x) {
+ XlaOp z = x * x;
+ XlaOp pt = EvaluatePolynomial(z, kErfTCoefficient);
+ XlaOp pu = EvaluatePolynomial(z, kErfUCoefficient);
+ return x * pt / pu;
+}
+
+// Approximation for the inverse error function from
+// Giles, M., "Approximating the erfinv function".
+// The approximation has the form:
+// w = -log((1 - x) * (1 + x))
+// if ( w < 5 ) {
+// w = w - 2.5
+// p = sum_{i=1}^n lq[i]*w^i
+// } else {
+// w = sqrt(w) - 3
+// p = sum_{i=1}^n gq[i]*w^i
+// }
+// return p*x
+XlaOp ErfInv(XlaOp x) {
+ XlaBuilder* b = x.builder();
+ return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ TF_ASSIGN_OR_RETURN(Shape shape, b->GetShape(x));
+ constexpr int kDegree = 9;
+ constexpr std::array<float, 9> w_less_than_5_constants = {
+ 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f,
+ -4.39150654e-06f, 0.00021858087f, -0.00125372503f,
+ -0.00417768164f, 0.246640727f, 1.50140941f};
+ constexpr std::array<float, 9> w_greater_than_5_constants = {
+ -0.000200214257f, 0.000100950558f, 0.00134934322f,
+ -0.00367342844f, 0.00573950773f, -0.0076224613f,
+ 0.00943887047f, 1.00167406f, 2.83297682f};
+
+ auto one = ScalarLike(x, 1.0);
+ auto w = -Log((one - x) * (one + x));
+
+ auto lt = Lt(w, ScalarLike(x, 5.0));
+ auto coefficient = [&](int i) {
+ return Select(lt,
+ Broadcast(ScalarLike(x, w_less_than_5_constants[i]),
+ AsInt64Slice(shape.dimensions())),
+ Broadcast(ScalarLike(x, w_greater_than_5_constants[i]),
+ AsInt64Slice(shape.dimensions())));
+ };
+ w = Select(lt, w - ScalarLike(x, 2.5), Sqrt(w) - ScalarLike(x, 3.0));
+ auto p = coefficient(0);
+ for (int i = 1; i < kDegree; ++i) {
+ p = coefficient(i) + p * w;
+ }
+ return p * x;
+ });
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/client/lib/math.h b/tensorflow/compiler/xla/client/lib/math.h
new file mode 100644
index 0000000000..e7c8b50273
--- /dev/null
+++ b/tensorflow/compiler/xla/client/lib/math.h
@@ -0,0 +1,51 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATH_H_
+#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATH_H_
+
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+
+namespace xla {
+
+// Computes the square root of 'operand'.
+XlaOp Sqrt(XlaOp operand);
+
+// Computes the reciprocal of the square root of 'operand'.
+XlaOp Rsqrt(XlaOp operand);
+
+// Computes the square of 'operand'.
+XlaOp Square(XlaOp operand);
+
+// Computes the reciprocal of 'operand'.
+XlaOp Reciprocal(XlaOp operand);
+
+// Evaluates a polynomial given coefficients and `x`.
+// N.B. Coefficients should be supplied in decreasing order.
+XlaOp EvaluatePolynomial(XlaOp x,
+ tensorflow::gtl::ArraySlice<float> coefficients);
+
+// Computes an approximation of the error function complement (1 - erf(x)).
+XlaOp Erfc(XlaOp x);
+
+// Computes an approximation of the error function.
+XlaOp Erf(XlaOp x);
+
+// Computes an approximation of the inverse of the error function.
+XlaOp ErfInv(XlaOp x);
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATH_H_
diff --git a/tensorflow/compiler/xla/client/lib/math_test.cc b/tensorflow/compiler/xla/client/lib/math_test.cc
new file mode 100644
index 0000000000..1df4e6ea42
--- /dev/null
+++ b/tensorflow/compiler/xla/client/lib/math_test.cc
@@ -0,0 +1,85 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/client/lib/math.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace xla {
+namespace {
+
+class MathTest : public ClientLibraryTestBase {
+ public:
+ ErrorSpec error_spec_{0.0001};
+};
+
+XLA_TEST_F(MathTest, SqrtF32) {
+ XlaBuilder builder(TestName());
+ Literal zero_literal = Literal::Zero(PrimitiveType::F32);
+
+ std::unique_ptr<GlobalData> zero_data =
+ client_->TransferToServer(zero_literal).ConsumeValueOrDie();
+
+ XlaOp zero = Parameter(&builder, 0, zero_literal.shape(), "zero");
+ Sqrt(zero);
+
+ ComputeAndCompareR0<float>(&builder, 0.0f, {zero_data.get()}, error_spec_);
+}
+
+XLA_TEST_F(MathTest, SquareTenValues) {
+ XlaBuilder builder(TestName());
+ auto x = ConstantR1<float>(
+ &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
+ Square(x);
+
+ std::vector<float> expected = {4.41, 6.76, 6.76, 16., 4.41,
+ 5.29, 25., 0.81, 5.76, 2.56};
+ ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
+}
+
+XLA_TEST_F(MathTest, ReciprocalTenValues) {
+ XlaBuilder builder(TestName());
+ auto x = ConstantR1<float>(
+ &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
+ Reciprocal(x);
+
+ std::vector<float> expected = {
+ 0.47619048, -0.38461538, 0.38461538, -0.25, 0.47619048,
+ 0.43478261, -0.2, -1.11111111, -0.41666667, 0.625};
+ ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
+}
+
+XLA_TEST_F(MathTest, SqrtZeroes) {
+ XlaBuilder builder(TestName());
+ auto x = ConstantR1<float>(&builder, {0.0, -0.0});
+ Sqrt(x);
+
+ ComputeAndCompareR1<float>(&builder, {0, 0}, {}, error_spec_);
+}
+
+XLA_TEST_F(MathTest, SqrtSixValues) {
+ XlaBuilder builder(TestName());
+ auto x = ConstantR1<float>(&builder, {16.0, 1.0, 1024.0, 0.16, 0.2, 12345});
+ Sqrt(x);
+
+ std::vector<float> expected = {4, 1, 32, 0.4, 0.4472, 111.1080};
+ ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
+}
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
index 70d5311cbc..95342af6a7 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
@@ -1377,11 +1377,6 @@ XlaOp XlaBuilder::Sort(XlaOp keys, tensorflow::gtl::optional<XlaOp> values) {
});
}
-XlaOp XlaBuilder::SqrtF32(const XlaOp& operand) {
- return BinaryOp(HloOpcode::kPower, operand, ConstantR0<float>(0.5),
- /*broadcast_dimensions=*/{});
-}
-
XlaOp XlaBuilder::Pow(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kPower, lhs, rhs, broadcast_dimensions);
@@ -1412,16 +1407,6 @@ XlaOp XlaBuilder::BitcastConvertType(const XlaOp& operand,
});
}
-XlaOp XlaBuilder::SquareF32(const XlaOp& operand) {
- return BinaryOp(HloOpcode::kPower, operand, ConstantR0<float>(2.0),
- /*broadcast_dimensions=*/{});
-}
-
-XlaOp XlaBuilder::ReciprocalF32(const XlaOp& operand) {
- return BinaryOp(HloOpcode::kPower, operand, ConstantR0<float>(-1.0),
- /*broadcast_dimensions=*/{});
-}
-
XlaOp XlaBuilder::Neg(const XlaOp& operand) {
return UnaryOp(HloOpcode::kNegate, operand);
}
@@ -2512,14 +2497,6 @@ XlaOp Real(const XlaOp& operand) { return operand.builder()->Real(operand); }
XlaOp Imag(const XlaOp& operand) { return operand.builder()->Imag(operand); }
-XlaOp SqrtF32(const XlaOp& operand) {
- return operand.builder()->SqrtF32(operand);
-}
-
-XlaOp SquareF32(const XlaOp& operand) {
- return operand.builder()->SquareF32(operand);
-}
-
XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
return lhs.builder()->Pow(lhs, rhs, broadcast_dimensions);
@@ -2537,10 +2514,6 @@ XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type) {
return operand.builder()->BitcastConvertType(operand, new_element_type);
}
-XlaOp ReciprocalF32(const XlaOp& operand) {
- return operand.builder()->ReciprocalF32(operand);
-}
-
XlaOp Neg(const XlaOp& operand) { return operand.builder()->Neg(operand); }
XlaOp Transpose(const XlaOp& operand,
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_client/xla_builder.h
index 79fcc5f256..274aba8a31 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.h
@@ -751,16 +751,6 @@ class XlaBuilder {
// Enqueues an imaginary-part instruction onto the computation.
XlaOp Imag(const XlaOp& operand);
- // Enqueues a float32 sqrt instruction onto the computation.
- // (float32 is specified as there is an implicit float32 0.5f constant
- // exponent).
- XlaOp SqrtF32(const XlaOp& operand);
-
- // Enqueues a float32 square instruction onto the computation.
- // (float32 is specified as there is an implicit float32 2.0f constant
- // exponent).
- XlaOp SquareF32(const XlaOp& operand);
-
// Enqueues a lhs^rhs computation onto the computation.
XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
@@ -783,14 +773,6 @@ class XlaBuilder {
XlaOp BitcastConvertType(const XlaOp& operand,
PrimitiveType new_element_type);
- // Enqueues a float32 reciprocal instruction onto the computation.
- // (float32 is specified as there is an implicit float32 -1.0f constant
- // exponent).
- //
- // TODO(b/34468990) axe F32 suffix, can be determined by reflecting on the
- // shape of the operand.
- XlaOp ReciprocalF32(const XlaOp& operand);
-
// Enqueues a negate instruction onto the computation.
XlaOp Neg(const XlaOp& operand);
@@ -1235,8 +1217,6 @@ class XlaBuilder {
friend XlaOp Tanh(const XlaOp& operand);
friend XlaOp Real(const XlaOp& operand);
friend XlaOp Imag(const XlaOp& operand);
- friend XlaOp SqrtF32(const XlaOp& operand);
- friend XlaOp SquareF32(const XlaOp& operand);
friend XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
friend XlaOp IsFinite(const XlaOp& operand);
@@ -1244,7 +1224,6 @@ class XlaBuilder {
PrimitiveType new_element_type);
friend XlaOp BitcastConvertType(const XlaOp& operand,
PrimitiveType new_element_type);
- friend XlaOp ReciprocalF32(const XlaOp& operand);
friend XlaOp Neg(const XlaOp& operand);
friend XlaOp Transpose(const XlaOp& operand,
tensorflow::gtl::ArraySlice<int64> permutation);
@@ -1833,16 +1812,6 @@ XlaOp Real(const XlaOp& operand);
// Enqueues an imaginary-part instruction onto the computation.
XlaOp Imag(const XlaOp& operand);
-// Enqueues a float32 sqrt instruction onto the computation.
-// (float32 is specified as there is an implicit float32 0.5f constant
-// exponent).
-XlaOp SqrtF32(const XlaOp& operand);
-
-// Enqueues a float32 square instruction onto the computation.
-// (float32 is specified as there is an implicit float32 2.0f constant
-// exponent).
-XlaOp SquareF32(const XlaOp& operand);
-
// Enqueues a lhs^rhs computation onto the computation.
XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
@@ -1863,14 +1832,6 @@ XlaOp ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type);
// identical.
XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type);
-// Enqueues a float32 reciprocal instruction onto the computation.
-// (float32 is specified as there is an implicit float32 -1.0f constant
-// exponent).
-//
-// TODO(b/34468990) axe F32 suffix, can be determined by reflecting on the
-// shape of the operand.
-XlaOp ReciprocalF32(const XlaOp& operand);
-
// Enqueues a negate instruction onto the computation.
XlaOp Neg(const XlaOp& operand);
diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD
index 83834c1ff6..22cc4e2436 100644
--- a/tensorflow/compiler/xla/python/BUILD
+++ b/tensorflow/compiler/xla/python/BUILD
@@ -52,9 +52,9 @@ cc_library(
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:executable_build_options",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/lib:math",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
- "//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/core:framework_lite",
"//tensorflow/core:lib",
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc
index b5ba4e2d42..be55d50b23 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.cc
+++ b/tensorflow/compiler/xla/python/local_computation_builder.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/python/local_computation_builder.h"
+#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/ptr_util.h"
@@ -626,11 +627,11 @@ _FORWARD_UNOP(Sign)
_FORWARD_UNOP(Cos)
_FORWARD_UNOP(Sin)
_FORWARD_UNOP(Tanh)
-_FORWARD_UNOP(SqrtF32)
-_FORWARD_UNOP(SquareF32)
+_FORWARD_UNOP(Sqrt)
+_FORWARD_UNOP(Square)
_FORWARD_BINOP(Pow)
_FORWARD_UNOP(IsFinite)
-_FORWARD_UNOP(ReciprocalF32)
+_FORWARD_UNOP(Reciprocal)
_FORWARD_UNOP(Neg)
_FORWARD_UNOP(Sort)
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h
index e920f8aecd..690ff277e8 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.h
+++ b/tensorflow/compiler/xla/python/local_computation_builder.h
@@ -346,11 +346,11 @@ class LocalComputationBuilder {
_FORWARD_UNOP(Cos)
_FORWARD_UNOP(Sin)
_FORWARD_UNOP(Tanh)
- _FORWARD_UNOP(SqrtF32)
- _FORWARD_UNOP(SquareF32)
+ _FORWARD_UNOP(Sqrt)
+ _FORWARD_UNOP(Square)
_FORWARD_BINOP(Pow)
_FORWARD_UNOP(IsFinite)
- _FORWARD_UNOP(ReciprocalF32)
+ _FORWARD_UNOP(Reciprocal)
_FORWARD_UNOP(Neg)
_FORWARD_UNOP(Sort)
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i
index 76e9e637cd..c44e69e615 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.i
+++ b/tensorflow/compiler/xla/python/local_computation_builder.i
@@ -1002,11 +1002,11 @@ tensorflow::ImportNumpy();
%unignore xla::swig::LocalComputationBuilder::Cos;
%unignore xla::swig::LocalComputationBuilder::Sin;
%unignore xla::swig::LocalComputationBuilder::Tanh;
-%unignore xla::swig::LocalComputationBuilder::SqrtF32;
-%unignore xla::swig::LocalComputationBuilder::SquareF32;
+%unignore xla::swig::LocalComputationBuilder::Sqrt;
+%unignore xla::swig::LocalComputationBuilder::Square;
%unignore xla::swig::LocalComputationBuilder::Pow;
%unignore xla::swig::LocalComputationBuilder::IsFinite;
-%unignore xla::swig::LocalComputationBuilder::ReciprocalF32;
+%unignore xla::swig::LocalComputationBuilder::Reciprocal;
%unignore xla::swig::LocalComputationBuilder::Neg;
%unignore xla::swig::LocalComputationBuilder::Sort;
%unignore xla::swig::DestructureLocalShapedBufferTuple;
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py
index abb97d0c6f..27aee634ba 100644
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -99,10 +99,10 @@ _UNARY_OPS = [
'Cos',
'Sin',
'Tanh',
- 'SqrtF32',
- 'SquareF32',
+ 'Sqrt',
+ 'Square',
'IsFinite',
- 'ReciprocalF32',
+ 'Reciprocal',
'Neg',
'Sort',
]
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 20b2885e90..77d398e5e2 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -886,6 +886,7 @@ xla_test(
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/client/lib:math",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/service:hlo",
diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
index d9d7ba1362..217673c8cb 100644
--- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc
+++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
@@ -118,7 +119,7 @@ XLA_TEST_P(BatchNormalizationTest, SubtractInZ) {
XLA_TEST_P(BatchNormalizationTest, SquareTesseractElementwise) {
XlaBuilder builder("square_tesseract_elementwise");
auto x = ConstantLiteral(&builder, input_literal_);
- SquareF32(x);
+ Square(x);
using tensorflow::MathUtil;
@@ -150,7 +151,7 @@ XLA_TEST_P(BatchNormalizationTest, SquareAndReduce) {
auto activation_deviations = Sub(input_activations, set_means,
/*broadcast_dimensions=*/{1});
XlaComputation add = CreateScalarAddComputation(F32, &builder);
- auto dev_squares = SquareF32(activation_deviations);
+ auto dev_squares = Square(activation_deviations);
Reduce(dev_squares, ConstantR0<float>(&builder, 0.0f), add, {0, 2, 3});
std::vector<float> expected = {18, 0.06};
@@ -160,7 +161,7 @@ XLA_TEST_P(BatchNormalizationTest, SquareAndReduce) {
XLA_TEST_P(BatchNormalizationTest, VarianceToStddev) {
XlaBuilder builder("variance_to_stddev");
auto variance = ConstantR1<float>(&builder, {6.f, .02f});
- SqrtF32(variance);
+ Sqrt(variance);
std::vector<float> expected = {2.44948974f, 0.14142136f};
ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
@@ -195,20 +196,20 @@ XLA_TEST_P(BatchNormalizationTest, SpecComparisonForward) {
auto epsilon2 = ConstantR1<float>(&builder, {kEpsilon, kEpsilon});
auto activation_deviations = Sub(input_activations, set_means,
/*broadcast_dimensions=*/{1});
- auto dev_squares = SquareF32(activation_deviations);
+ auto dev_squares = Square(activation_deviations);
auto sum_of_squares =
CheckShape(&builder,
Reduce(dev_squares, ConstantR0<float>(&builder, 0.0f), add,
/*dimensions_to_reduce=*/{0, 2, 3}),
TwoElementVectorF32);
auto variance = Div(sum_of_squares, count);
- auto standard_deviation = SqrtF32(variance);
+ auto standard_deviation = Sqrt(variance);
auto standard_deviation_above_epsilon =
CheckShape(&builder, Gt(standard_deviation, epsilon),
ShapeUtil::MakeShape(PRED, {2}));
auto gt_eps =
Select(standard_deviation_above_epsilon, standard_deviation, epsilon2);
- auto normalization_factors = ReciprocalF32(gt_eps);
+ auto normalization_factors = Reciprocal(gt_eps);
auto normalized_input_activations =
Mul(activation_deviations, normalization_factors,
/*broadcast_dimensions=*/{1});
diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc
index bc994315c3..3afd8c8fc8 100644
--- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc
+++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc
@@ -897,18 +897,6 @@ XLA_TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionS32) {
ComputeAndCompareR0<int32>(&b, 10, {});
}
-XLA_TEST_F(ScalarComputationsTest, SqrtF320) {
- XlaBuilder builder(TestName());
- Literal zero_literal = Literal::Zero(PrimitiveType::F32);
-
- std::unique_ptr<GlobalData> zero_data =
- client_->TransferToServer(zero_literal).ConsumeValueOrDie();
-
- XlaOp zero = Parameter(&builder, 0, zero_literal.shape(), "zero");
- SqrtF32(zero);
-
- ComputeAndCompareR0<float>(&builder, 0.0f, {zero_data.get()}, error_spec_);
-}
XLA_TEST_F(ScalarComputationsTest, RoundScalar) {
XlaBuilder builder(TestName());
diff --git a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc
index c11df7cdf5..79bae22dac 100644
--- a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc
+++ b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc
@@ -135,46 +135,6 @@ XLA_TEST_F(VecOpsSimpleTest, NegateUint32Values) {
ComputeAndCompareR1<uint32>(&builder, expected, {});
}
-XLA_TEST_F(VecOpsSimpleTest, SquareTenValues) {
- XlaBuilder builder(TestName());
- auto x = ConstantR1<float>(
- &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
- SquareF32(x);
-
- std::vector<float> expected = {4.41, 6.76, 6.76, 16., 4.41,
- 5.29, 25., 0.81, 5.76, 2.56};
- ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
-}
-
-XLA_TEST_F(VecOpsSimpleTest, ReciprocalTenValues) {
- XlaBuilder builder(TestName());
- auto x = ConstantR1<float>(
- &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
- ReciprocalF32(x);
-
- std::vector<float> expected = {
- 0.47619048, -0.38461538, 0.38461538, -0.25, 0.47619048,
- 0.43478261, -0.2, -1.11111111, -0.41666667, 0.625};
- ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
-}
-
-XLA_TEST_F(VecOpsSimpleTest, SqrtZeroes) {
- XlaBuilder builder(TestName());
- auto x = ConstantR1<float>(&builder, {0.0, -0.0});
- SqrtF32(x);
-
- ComputeAndCompareR1<float>(&builder, {0, 0}, {}, error_spec_);
-}
-
-XLA_TEST_F(VecOpsSimpleTest, SqrtSixValues) {
- XlaBuilder builder(TestName());
- auto x = ConstantR1<float>(&builder, {16.0, 1.0, 1024.0, 0.16, 0.2, 12345});
- SqrtF32(x);
-
- std::vector<float> expected = {4, 1, 32, 0.4, 0.4472, 111.1080};
- ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
-}
-
XLA_TEST_F(VecOpsSimpleTest, InvSqrtSevenValues) {
XlaBuilder builder(TestName());
auto x = ConstantR1<float>(&builder,