diff options
author | 2018-07-02 07:35:19 -0700 | |
---|---|---|
committer | 2018-07-02 07:38:15 -0700 | |
commit | f52346cb9b72a4109bd2ff9198b4d7588758758c (patch) | |
tree | 235f3d4995425b3194d8bcdf196ceafb53571638 | |
parent | e0ebc3dc4f64d84c6acd5f2ff3574e7f2b3a8fbf (diff) |
[XLA] Rename {SqrtF32, SquareF32, ReciprocalF32} to {Sqrt, Square, Reciprocal} and move them to a new client library xla/client/lib/math.h. Remove the F32 type constraint.
Add an xla::Rqsrt function.
Move {Erf, Erfc, ErfInv, EvaluatePolynomial} to the same library.
[TF:XLA] Update many places in the bridge to use the new functions. Rewrite many of the training ops in operator notation.
PiperOrigin-RevId: 202948474
23 files changed, 411 insertions, 395 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 93d0e22d4a..a8eb7d942d 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -122,6 +122,7 @@ tf_kernel_library( "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:math", "//tensorflow/compiler/xla/client/lib:numeric", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/core:framework", diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc index f1059856c8..a6f5769e7b 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc @@ -20,8 +20,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc index 2e5d61e111..f3e112c7b3 100644 --- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc @@ -16,6 +16,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/kernel_def_builder.h" @@ -47,7 +49,7 @@ class ResourceApplyGradientDescent : public XlaOpKernel { var_shape.DebugString(), " vs ", delta_shape.DebugString())); - handle = xla::Sub(handle, xla::Mul(ctx->Input(1), ctx->Input(2))); + handle = handle - ctx->Input(1) * ctx->Input(2); OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle)); } }; @@ -94,14 +96,13 @@ class ResourceApplyMomentum : public XlaOpKernel { xla::XlaOp grad = ctx->Input(3); xla::XlaOp momentum = ctx->Input(4); - accum = xla::Add(xla::Mul(accum, momentum), grad); + accum = accum * momentum + grad; if (use_nesterov_) { // See https://github.com/tensorflow/tensorflow/pull/2798 for an // explanation of the reparameterization used here. - var = xla::Sub(var, xla::Add(xla::Mul(grad, lr), - xla::Mul(xla::Mul(accum, momentum), lr))); + var = var - (grad * lr + accum * momentum * lr); } else { - var = xla::Sub(var, xla::Mul(accum, lr)); + var = var - accum * lr; } OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var)); OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum)); @@ -118,8 +119,6 @@ class ResourceApplyAdagrad : public XlaOpKernel { explicit ResourceApplyAdagrad(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::XlaBuilder* b = ctx->builder(); - DataType type = ctx->input_type(2); TensorShape var_shape, accum_shape; @@ -146,12 +145,8 @@ class ResourceApplyAdagrad : public XlaOpKernel { xla::XlaOp lr = ctx->Input(2); xla::XlaOp grad = ctx->Input(3); - accum = - xla::Add(accum, xla::Pow(grad, XlaHelpers::FloatLiteral(b, type, 2.0))); - var = xla::Sub( - var, - xla::Mul(xla::Mul(grad, lr), - xla::Pow(accum, XlaHelpers::FloatLiteral(b, type, -0.5)))); + accum = accum + xla::Square(grad); + var = var - grad * lr * xla::Rsqrt(accum); OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var)); OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum)); } @@ -226,18 +221,12 @@ class ResourceApplyAdam : public XlaOpKernel { // variable <- variable - alpha * m_t / (sqrt(v_t) + epsilon) xla::XlaBuilder* b = ctx->builder(); - xla::XlaOp half = XlaHelpers::FloatLiteral(b, dtype_, 0.5); xla::XlaOp one = XlaHelpers::FloatLiteral(b, dtype_, 1.0); - xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype_, 2.0); - xla::XlaOp alpha = - xla::Div(xla::Mul(lr, xla::Pow(xla::Sub(one, beta2_power), half)), - xla::Sub(one, beta1_power)); - m = xla::Add(m, xla::Mul(xla::Sub(grad, m), xla::Sub(one, beta1))); - v = xla::Add( - v, xla::Mul(xla::Sub(xla::Pow(grad, two), v), xla::Sub(one, beta2))); - var = xla::Sub(var, xla::Div(xla::Mul(m, alpha), - xla::Add(xla::Pow(v, half), epsilon))); + xla::XlaOp alpha = lr * xla::Sqrt(one - beta2_power) / (one - beta1_power); + m = m + (grad - m) * (one - beta1); + v = v + (xla::Square(grad) - v) * (one - beta2); + var = var - m * alpha / (xla::Sqrt(v) + epsilon); OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var)); OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, m)); @@ -255,8 +244,6 @@ class ResourceApplyRMSProp : public XlaOpKernel { explicit ResourceApplyRMSProp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::XlaBuilder* b = ctx->builder(); - DataType type = ctx->input_type(3); TensorShape var_shape, ms_shape, mom_shape; @@ -320,17 +307,11 @@ class ResourceApplyRMSProp : public XlaOpKernel { // ms <- grad**2 (1 - rho) + ms * rho // // Which is the equation listed above. - xla::XlaOp new_ms = xla::Add( - ms, xla::Mul( - xla::Sub(xla::Pow(grad, XlaHelpers::FloatLiteral(b, type, 2.0)), - ms), - xla::Sub(XlaHelpers::FloatLiteral(b, type, 1.0), rho))); + xla::XlaOp new_ms = + ms + (xla::Square(grad) - ms) * (xla::ScalarLike(ms, 1.0) - rho); xla::XlaOp new_mom = - xla::Add(xla::Mul(mom, momentum), - xla::Mul(xla::Mul(grad, lr), - xla::Pow(xla::Add(new_ms, epsilon), - XlaHelpers::FloatLiteral(b, type, -0.5)))); - xla::XlaOp new_var = xla::Sub(var, new_mom); + mom * momentum + grad * lr * xla::Rsqrt(new_ms + epsilon); + xla::XlaOp new_var = var - new_mom; OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, new_var)); OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, new_ms)); @@ -425,23 +406,18 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype, xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype, 2.0); xla::XlaOp grad_to_use; if (has_l2_shrinkage) { - grad_to_use = xla::Add(grad, xla::Mul(two, xla::Mul(l2_shrinkage, var))); + grad_to_use = grad + two * l2_shrinkage * var; } else { grad_to_use = grad; } - xla::XlaOp new_accum = xla::Add(accum, xla::Pow(grad_to_use, two)); - xla::XlaOp new_accum_lr_pow = xla::Pow(new_accum, xla::Neg(lr_power)); - xla::XlaOp accum_lr_pow = xla::Pow(accum, xla::Neg(lr_power)); - linear = xla::Add( - linear, - xla::Sub(grad_to_use, - xla::Mul(xla::Div(xla::Sub(new_accum_lr_pow, accum_lr_pow), lr), - var))); - xla::XlaOp linear_clipped = xla::Clamp(xla::Neg(l1), linear, l1); - xla::XlaOp quadratic = - xla::Add(xla::Div(new_accum_lr_pow, lr), xla::Mul(two, l2)); - var = xla::Div(xla::Sub(linear_clipped, linear), quadratic); + xla::XlaOp new_accum = accum + xla::Square(grad_to_use); + xla::XlaOp new_accum_lr_pow = xla::Pow(new_accum, -lr_power); + xla::XlaOp accum_lr_pow = xla::Pow(accum, -lr_power); + linear = linear + grad_to_use - (new_accum_lr_pow - accum_lr_pow) / lr * var; + xla::XlaOp linear_clipped = xla::Clamp(-l1, linear, l1); + xla::XlaOp quadratic = new_accum_lr_pow / lr + two * l2; + var = (linear_clipped - linear) / quadratic; accum = new_accum; OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype, var)); diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index ce894f1faa..116a020437 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" @@ -52,56 +53,36 @@ XLAJIT_MAKE_UNARY(Conj, xla::Conj(x)); XLAJIT_MAKE_UNARY(Abs, xla::Abs(x)); // acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) -XLAJIT_MAKE_UNARY( - Acos, - xla::Mul(XlaHelpers::FloatLiteral(b, input_type(0), 2.0), - xla::Atan2(xla::Pow(xla::Sub(XlaHelpers::One(b, input_type(0)), - xla::Mul(x, x)), - XlaHelpers::FloatLiteral(b, input_type(0), - 0.5)), - xla::Add(XlaHelpers::One(b, input_type(0)), x)))); +XLAJIT_MAKE_UNARY(Acos, + xla::ScalarLike(x, 2.0) * + xla::Atan2(xla::Sqrt(xla::ScalarLike(x, 1.0) - x * x), + xla::ScalarLike(x, 1.0) + x)); // acosh(x) = log(x + sqrt(x^2 - 1)) // = log(x + sqrt((x+1)*(x-1))) -XLAJIT_MAKE_UNARY( - Acosh, - xla::Log(xla::Add( - x, xla::Pow(xla::Mul(xla::Add(x, XlaHelpers::One(b, input_type(0))), - xla::Sub(x, XlaHelpers::One(b, input_type(0)))), - XlaHelpers::FloatLiteral(b, input_type(0), 0.5))))); +XLAJIT_MAKE_UNARY(Acosh, + xla::Log(x + xla::Sqrt((x + xla::ScalarLike(x, 1.0)) * + (x - xla::ScalarLike(x, 1.0))))); // asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2))) XLAJIT_MAKE_UNARY( - Asin, - xla::Mul( - XlaHelpers::FloatLiteral(b, input_type(0), 2.0), - xla::Atan2(x, - xla::Add(XlaHelpers::One(b, input_type(0)), - xla::Pow(xla::Sub(XlaHelpers::One(b, input_type(0)), - xla::Mul(x, x)), - XlaHelpers::FloatLiteral(b, input_type(0), - 0.5)))))); + Asin, xla::ScalarLike(x, 2.0) * + xla::Atan2(x, xla::ScalarLike(x, 1.0) + + xla::Sqrt(xla::ScalarLike(x, 1.0) - x * x))); // asinh(x) = log(x + sqrt(x^2 + 1)) -XLAJIT_MAKE_UNARY( - Asinh, - xla::Log(xla::Add( - x, xla::Pow(xla::Add(xla::Mul(x, x), XlaHelpers::One(b, input_type(0))), - XlaHelpers::FloatLiteral(b, input_type(0), 0.5))))); +XLAJIT_MAKE_UNARY(Asinh, + xla::Log(x + xla::Sqrt(x * x + xla::ScalarLike(x, 1.0)))); -XLAJIT_MAKE_UNARY(Atan, xla::Atan2(x, XlaHelpers::One(b, input_type(0)))); +XLAJIT_MAKE_UNARY(Atan, xla::Atan2(x, xla::ScalarLike(x, 1.0))); // atanh(x) = 0.5 * log((1 + x) / (1 - x)) -XLAJIT_MAKE_UNARY( - Atanh, - xla::Mul(xla::Log(xla::Div(xla::Add(XlaHelpers::One(b, input_type(0)), x), - xla::Sub(XlaHelpers::One(b, input_type(0)), x))), - XlaHelpers::FloatLiteral(b, input_type(0), 0.5))); +XLAJIT_MAKE_UNARY(Atanh, xla::Log((xla::ScalarLike(x, 1.0) + x) / + (xla::ScalarLike(x, 1.0) - x)) * + xla::ScalarLike(x, 0.5)); XLAJIT_MAKE_UNARY(Ceil, xla::Ceil(x)); XLAJIT_MAKE_UNARY(Cos, xla::Cos(x)); -XLAJIT_MAKE_UNARY(Cosh, - xla::Mul(xla::Add(xla::Exp(x), xla::Exp(xla::Neg(x))), - XlaHelpers::FloatLiteral(b, input_type(0), 0.5))); +XLAJIT_MAKE_UNARY(Cosh, (xla::Exp(x) + xla::Exp(-x)) * xla::ScalarLike(x, 0.5)); XLAJIT_MAKE_UNARY(Sin, xla::Sin(x)); XLAJIT_MAKE_UNARY(Exp, xla::Exp(x)); @@ -109,59 +90,53 @@ XLAJIT_MAKE_UNARY(Expm1, xla::Expm1(x)); XLAJIT_MAKE_UNARY(Floor, xla::Floor(x)); XLAJIT_MAKE_UNARY(IsFinite, xla::IsFinite(x)); -XLAJIT_MAKE_UNARY(IsInf, xla::Eq(xla::Abs(x), - XlaHelpers::FloatLiteral( - b, input_type(0), - std::numeric_limits<double>::infinity()))); +XLAJIT_MAKE_UNARY( + IsInf, + xla::Eq(xla::Abs(x), + xla::ScalarLike(x, std::numeric_limits<double>::infinity()))); XLAJIT_MAKE_UNARY(IsNan, xla::Ne(x, x)); // Return 1/x -XLAJIT_MAKE_UNARY(Inv, xla::Div(XlaHelpers::One(b, input_type(0)), x)); -XLAJIT_MAKE_UNARY(Reciprocal, xla::Div(XlaHelpers::One(b, input_type(0)), x)); +XLAJIT_MAKE_UNARY(Inv, xla::ScalarLike(x, 1.0) / x); +XLAJIT_MAKE_UNARY(Reciprocal, xla::ScalarLike(x, 1.0) / x); XLAJIT_MAKE_UNARY(Log, xla::Log(x)); XLAJIT_MAKE_UNARY(Log1p, xla::Log1p(x)); XLAJIT_MAKE_UNARY(Invert, xla::Not(x)); XLAJIT_MAKE_UNARY(LogicalNot, xla::Not(x)); -XLAJIT_MAKE_UNARY(Neg, xla::Neg(x)); +XLAJIT_MAKE_UNARY(Neg, -x); // Implements Banker's rounding: numbers that are equidistant between two // integers are rounded towards even. -static xla::XlaOp Round(xla::XlaBuilder* b, DataType dtype, - const xla::XlaOp& x) { - auto half = XlaHelpers::FloatLiteral(b, dtype, 0.5); - auto one = XlaHelpers::FloatLiteral(b, dtype, 1.0); - auto two = XlaHelpers::FloatLiteral(b, dtype, 2.0); +xla::XlaOp RoundToEven(xla::XlaOp x) { + auto half = xla::ScalarLike(x, 0.5); + auto one = xla::ScalarLike(x, 1.0); + auto two = xla::ScalarLike(x, 2.0); auto round_val = xla::Floor(x); - auto fraction = xla::Sub(x, round_val); - auto nearest_even_int = - xla::Sub(round_val, xla::Mul(two, xla::Floor(xla::Mul(half, x)))); + auto fraction = x - round_val; + auto nearest_even_int = round_val - two * xla::Floor(half * x); auto is_odd = xla::Eq(nearest_even_int, one); return xla::Select(xla::Or(xla::Gt(fraction, half), xla::And(xla::Eq(fraction, half), is_odd)), - xla::Add(round_val, one), round_val); + round_val + one, round_val); } -XLAJIT_MAKE_UNARY(Rint, Round(b, input_type(0), x)); -XLAJIT_MAKE_UNARY(Round, Round(b, input_type(0), x)); +XLAJIT_MAKE_UNARY(Rint, RoundToEven(x)); +XLAJIT_MAKE_UNARY(Round, RoundToEven(x)); -XLAJIT_MAKE_UNARY(Rsqrt, xla::Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), - -0.5))); +XLAJIT_MAKE_UNARY(Rsqrt, xla::Rsqrt(x)); // Expresses sigmoid as a rescaled tanh: sigmoid(x) == (tanh(x/2) + 1) / 2. -static xla::XlaOp Sigmoid(xla::XlaBuilder* b, DataType dtype, - const xla::XlaOp& x) { - auto half = XlaHelpers::FloatLiteral(b, dtype, 0.5); - return xla::Add(half, xla::Mul(half, xla::Tanh(xla::Mul(half, x)))); +xla::XlaOp Sigmoid(xla::XlaOp x) { + auto half = xla::ScalarLike(x, 0.5); + return half + half * xla::Tanh(half * x); } -XLAJIT_MAKE_UNARY(Sigmoid, Sigmoid(b, input_type(0), x)); +XLAJIT_MAKE_UNARY(Sigmoid, Sigmoid(x)); // Returns 0 if x is 0, -1 if x < 0 and 1 if x > 0. XLAJIT_MAKE_UNARY(Sign, xla::Sign(x)); -XLAJIT_MAKE_UNARY(Sinh, - xla::Mul(xla::Sub(xla::Exp(x), xla::Exp(xla::Neg(x))), - XlaHelpers::FloatLiteral(b, input_type(0), 0.5))); +XLAJIT_MAKE_UNARY(Sinh, (xla::Exp(x) - xla::Exp(-x)) * xla::ScalarLike(x, 0.5)); // softplus(x) = log(1 + exp(x)) // @@ -171,18 +146,14 @@ XLAJIT_MAKE_UNARY(Sinh, // // This is equivalent to: // max(x, 0) + log1p(exp(-abs(x))) -XLAJIT_MAKE_UNARY(Softplus, - xla::Add(xla::Max(x, XlaHelpers::Zero(b, input_type(0))), - xla::Log1p(xla::Exp(xla::Neg(xla::Abs(x)))))); +XLAJIT_MAKE_UNARY(Softplus, xla::Max(x, xla::ScalarLike(x, 0.0)) + + xla::Log1p(xla::Exp(-xla::Abs(x)))); // softsign(x) = x / (abs(x) + 1) -XLAJIT_MAKE_UNARY(Softsign, - xla::Div(x, xla::Add(xla::Abs(x), - XlaHelpers::One(b, input_type(0))))); -XLAJIT_MAKE_UNARY(Sqrt, - xla::Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), 0.5))); -XLAJIT_MAKE_UNARY(Square, xla::Mul(x, x)); -XLAJIT_MAKE_UNARY(Tan, xla::Div(xla::Sin(x), xla::Cos(x))); +XLAJIT_MAKE_UNARY(Softsign, x / (xla::Abs(x) + xla::ScalarLike(x, 1.0))); +XLAJIT_MAKE_UNARY(Sqrt, xla::Sqrt(x)); +XLAJIT_MAKE_UNARY(Square, x* x); +XLAJIT_MAKE_UNARY(Tan, xla::Sin(x) / xla::Cos(x)); XLAJIT_MAKE_UNARY(Tanh, xla::Tanh(x)); XLAJIT_MAKE_UNARY(Real, xla::Real(x)); diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD index cd35a54aa9..dfa3c0595a 100644 --- a/tensorflow/compiler/tf2xla/lib/BUILD +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -59,8 +59,8 @@ cc_library( "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:math", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/core:protos_all_cc", ], diff --git a/tensorflow/compiler/tf2xla/lib/random.cc b/tensorflow/compiler/tf2xla/lib/random.cc index 8e6a0d0b19..8ff10fbd3f 100644 --- a/tensorflow/compiler/tf2xla/lib/random.cc +++ b/tensorflow/compiler/tf2xla/lib/random.cc @@ -19,8 +19,8 @@ limitations under the License. #include <limits> #include "tensorflow/compiler/tf2xla/xla_helpers.h" -#include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/status_macros.h" diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index 1b4706ed20..a6b9b47253 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -65,6 +65,33 @@ xla_test( ) cc_library( + name = "math", + srcs = ["math.cc"], + hdrs = ["math.h"], + deps = [ + ":constants", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + ], +) + +xla_test( + name = "math_test", + srcs = ["math_test.cc"], + tags = ["enable_for_xla_interpreter"], + deps = [ + ":math", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + +cc_library( name = "numeric", srcs = ["numeric.cc"], hdrs = ["numeric.h"], diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc index 86f35c0c0d..978fc40f34 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.cc +++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc @@ -121,122 +121,4 @@ XlaOp Any(XlaOp predicates) { }); } -namespace { - -// Polynomials for computing erf/erfc. Originally from cephes. -// Note we use float for compatibility across devices, at the cost of some -// precision for 64 bit computations. -// -// Coefficients are in descending order. -std::array<float, 9> kErfcPCoefficient = { - 2.46196981473530512524E-10, 5.64189564831068821977E-1, - 7.46321056442269912687E0, 4.86371970985681366614E1, - 1.96520832956077098242E2, 5.26445194995477358631E2, - 9.34528527171957607540E2, 1.02755188689515710272E3, - 5.57535335369399327526E2}; -std::array<float, 9> kErfcQCoefficient = { - 1.00000000000000000000E0, 1.32281951154744992508E1, - 8.67072140885989742329E1, 3.54937778887819891062E2, - 9.75708501743205489753E2, 1.82390916687909736289E3, - 2.24633760818710981792E3, 1.65666309194161350182E3, - 5.57535340817727675546E2}; -std::array<float, 6> kErfcRCoefficient = { - 5.64189583547755073984E-1, 1.27536670759978104416E0, - 5.01905042251180477414E0, 6.16021097993053585195E0, - 7.40974269950448939160E0, 2.97886665372100240670E0}; -std::array<float, 7> kErfcSCoefficient = { - 1.00000000000000000000E0, 2.26052863220117276590E0, - 9.39603524938001434673E0, 1.20489539808096656605E1, - 1.70814450747565897222E1, 9.60896809063285878198E0, - 3.36907645100081516050E0}; -std::array<float, 5> kErfTCoefficient = { - 9.60497373987051638749E0, 9.00260197203842689217E1, - 2.23200534594684319226E3, 7.00332514112805075473E3, - 5.55923013010394962768E4}; -std::array<float, 6> kErfUCoefficient = { - 1.00000000000000000000E0, 3.35617141647503099647E1, - 5.21357949780152679795E2, 4.59432382970980127987E3, - 2.26290000613890934246E4, 4.92673942608635921086E4}; -} // namespace - -// Evaluate the polynomial given coefficients and `x`. -// N.B. Coefficients should be supplied in decreasing order. -XlaOp EvaluatePolynomial(XlaOp x, - tensorflow::gtl::ArraySlice<float> coefficients) { - XlaOp poly = ScalarLike(x, 0.0); - for (float c : coefficients) { - poly = poly * x + ScalarLike(x, c); - } - return poly; -} - -// Compute an approximation of the error function complement (1 - erf(x)). -XlaOp Erfc(XlaOp x) { - XlaOp abs_x = Abs(x); - XlaOp z = Exp(-x * x); - - XlaOp pp = EvaluatePolynomial(abs_x, kErfcPCoefficient); - XlaOp pq = EvaluatePolynomial(abs_x, kErfcQCoefficient); - XlaOp pr = EvaluatePolynomial(abs_x, kErfcRCoefficient); - XlaOp ps = EvaluatePolynomial(abs_x, kErfcSCoefficient); - - XlaOp y = Select(Lt(abs_x, ScalarLike(x, 8.0)), z * pp / pq, z * pr / ps); - - return Select(Lt(x, ScalarLike(x, 0.0)), Sub(ScalarLike(x, 2.0), y), y); -} - -// Compute a polynomial approximation of the error function. -XlaOp Erf(XlaOp x) { - XlaOp z = x * x; - XlaOp pt = EvaluatePolynomial(z, kErfTCoefficient); - XlaOp pu = EvaluatePolynomial(z, kErfUCoefficient); - return x * pt / pu; -} - -// Approximation for the inverse error function from -// Giles, M., "Approximating the erfinv function". -// The approximation has the form: -// w = -log((1 - x) * (1 + x)) -// if ( w < 5 ) { -// w = w - 2.5 -// p = sum_{i=1}^n lq[i]*w^i -// } else { -// w = sqrt(w) - 3 -// p = sum_{i=1}^n gq[i]*w^i -// } -// return p*x -XlaOp ErfInv(XlaOp x) { - XlaBuilder* b = x.builder(); - return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { - TF_ASSIGN_OR_RETURN(Shape shape, b->GetShape(x)); - constexpr int kDegree = 9; - constexpr std::array<float, 9> w_less_than_5_constants = { - 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, - -4.39150654e-06f, 0.00021858087f, -0.00125372503f, - -0.00417768164f, 0.246640727f, 1.50140941f}; - constexpr std::array<float, 9> w_greater_than_5_constants = { - -0.000200214257f, 0.000100950558f, 0.00134934322f, - -0.00367342844f, 0.00573950773f, -0.0076224613f, - 0.00943887047f, 1.00167406f, 2.83297682f}; - - auto one = ScalarLike(x, 1.0); - auto w = -Log((one - x) * (one + x)); - - auto lt = Lt(w, ScalarLike(x, 5.0)); - auto coefficient = [&](int i) { - return Select(lt, - Broadcast(ScalarLike(x, w_less_than_5_constants[i]), - AsInt64Slice(shape.dimensions())), - Broadcast(ScalarLike(x, w_greater_than_5_constants[i]), - AsInt64Slice(shape.dimensions()))); - }; - w = Select(lt, w - ScalarLike(x, 2.5), SqrtF32(w) - ScalarLike(x, 3.0)); - auto p = coefficient(0); - for (int i = 1; i < kDegree; ++i) { - p = coefficient(i) + p * w; - } - return p * x; - }); -} - } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.h b/tensorflow/compiler/xla/client/lib/arithmetic.h index 4bd3b615df..d0b916e8c8 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.h +++ b/tensorflow/compiler/xla/client/lib/arithmetic.h @@ -55,20 +55,6 @@ XlaComputation CreateScalarOrComputation(XlaBuilder* builder); // Note: if predicates is zero-sized, Any() vacuously returns false. XlaOp Any(XlaOp predicates); -// Evaluate the polynomial given coefficients and `x`. -// N.B. Coefficients should be supplied in decreasing order. -XlaOp EvaluatePolynomial(XlaOp x, - tensorflow::gtl::ArraySlice<float> coefficients); - -// Compute an approximation of the error function complement (1 - erf(x)). -XlaOp Erfc(XlaOp x); - -// Compute an approximation of the error function. -XlaOp Erf(XlaOp x); - -// Compute an approximation of the inverse of the error function. -XlaOp ErfInv(XlaOp x); - } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_ARITHMETIC_H_ diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc new file mode 100644 index 0000000000..5587559040 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/math.cc @@ -0,0 +1,152 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/math.h" + +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" + +namespace xla { + +XlaOp Sqrt(XlaOp operand) { return Pow(operand, ScalarLike(operand, 0.5)); } + +XlaOp Rsqrt(XlaOp operand) { return Pow(operand, ScalarLike(operand, -0.5)); } + +XlaOp Square(XlaOp operand) { return Pow(operand, ScalarLike(operand, 2.0)); } + +XlaOp Reciprocal(XlaOp operand) { + return Pow(operand, ScalarLike(operand, -1.0)); +} + +namespace { + +// Polynomials for computing erf/erfc. Originally from cephes. +// Note we use float for compatibility across devices, at the cost of some +// precision for 64 bit computations. +// +// Coefficients are in descending order. +std::array<float, 9> kErfcPCoefficient = { + 2.46196981473530512524E-10, 5.64189564831068821977E-1, + 7.46321056442269912687E0, 4.86371970985681366614E1, + 1.96520832956077098242E2, 5.26445194995477358631E2, + 9.34528527171957607540E2, 1.02755188689515710272E3, + 5.57535335369399327526E2}; +std::array<float, 9> kErfcQCoefficient = { + 1.00000000000000000000E0, 1.32281951154744992508E1, + 8.67072140885989742329E1, 3.54937778887819891062E2, + 9.75708501743205489753E2, 1.82390916687909736289E3, + 2.24633760818710981792E3, 1.65666309194161350182E3, + 5.57535340817727675546E2}; +std::array<float, 6> kErfcRCoefficient = { + 5.64189583547755073984E-1, 1.27536670759978104416E0, + 5.01905042251180477414E0, 6.16021097993053585195E0, + 7.40974269950448939160E0, 2.97886665372100240670E0}; +std::array<float, 7> kErfcSCoefficient = { + 1.00000000000000000000E0, 2.26052863220117276590E0, + 9.39603524938001434673E0, 1.20489539808096656605E1, + 1.70814450747565897222E1, 9.60896809063285878198E0, + 3.36907645100081516050E0}; +std::array<float, 5> kErfTCoefficient = { + 9.60497373987051638749E0, 9.00260197203842689217E1, + 2.23200534594684319226E3, 7.00332514112805075473E3, + 5.55923013010394962768E4}; +std::array<float, 6> kErfUCoefficient = { + 1.00000000000000000000E0, 3.35617141647503099647E1, + 5.21357949780152679795E2, 4.59432382970980127987E3, + 2.26290000613890934246E4, 4.92673942608635921086E4}; +} // namespace + +// Evaluate the polynomial given coefficients and `x`. +// N.B. Coefficients should be supplied in decreasing order. +XlaOp EvaluatePolynomial(XlaOp x, + tensorflow::gtl::ArraySlice<float> coefficients) { + XlaOp poly = ScalarLike(x, 0.0); + for (float c : coefficients) { + poly = poly * x + ScalarLike(x, c); + } + return poly; +} + +// Compute an approximation of the error function complement (1 - erf(x)). +XlaOp Erfc(XlaOp x) { + XlaOp abs_x = Abs(x); + XlaOp z = Exp(-x * x); + + XlaOp pp = EvaluatePolynomial(abs_x, kErfcPCoefficient); + XlaOp pq = EvaluatePolynomial(abs_x, kErfcQCoefficient); + XlaOp pr = EvaluatePolynomial(abs_x, kErfcRCoefficient); + XlaOp ps = EvaluatePolynomial(abs_x, kErfcSCoefficient); + + XlaOp y = Select(Lt(abs_x, ScalarLike(x, 8.0)), z * pp / pq, z * pr / ps); + + return Select(Lt(x, ScalarLike(x, 0.0)), ScalarLike(x, 2.0) - y, y); +} + +// Compute a polynomial approximation of the error function. +XlaOp Erf(XlaOp x) { + XlaOp z = x * x; + XlaOp pt = EvaluatePolynomial(z, kErfTCoefficient); + XlaOp pu = EvaluatePolynomial(z, kErfUCoefficient); + return x * pt / pu; +} + +// Approximation for the inverse error function from +// Giles, M., "Approximating the erfinv function". +// The approximation has the form: +// w = -log((1 - x) * (1 + x)) +// if ( w < 5 ) { +// w = w - 2.5 +// p = sum_{i=1}^n lq[i]*w^i +// } else { +// w = sqrt(w) - 3 +// p = sum_{i=1}^n gq[i]*w^i +// } +// return p*x +XlaOp ErfInv(XlaOp x) { + XlaBuilder* b = x.builder(); + return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { + TF_ASSIGN_OR_RETURN(Shape shape, b->GetShape(x)); + constexpr int kDegree = 9; + constexpr std::array<float, 9> w_less_than_5_constants = { + 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, + -4.39150654e-06f, 0.00021858087f, -0.00125372503f, + -0.00417768164f, 0.246640727f, 1.50140941f}; + constexpr std::array<float, 9> w_greater_than_5_constants = { + -0.000200214257f, 0.000100950558f, 0.00134934322f, + -0.00367342844f, 0.00573950773f, -0.0076224613f, + 0.00943887047f, 1.00167406f, 2.83297682f}; + + auto one = ScalarLike(x, 1.0); + auto w = -Log((one - x) * (one + x)); + + auto lt = Lt(w, ScalarLike(x, 5.0)); + auto coefficient = [&](int i) { + return Select(lt, + Broadcast(ScalarLike(x, w_less_than_5_constants[i]), + AsInt64Slice(shape.dimensions())), + Broadcast(ScalarLike(x, w_greater_than_5_constants[i]), + AsInt64Slice(shape.dimensions()))); + }; + w = Select(lt, w - ScalarLike(x, 2.5), Sqrt(w) - ScalarLike(x, 3.0)); + auto p = coefficient(0); + for (int i = 1; i < kDegree; ++i) { + p = coefficient(i) + p * w; + } + return p * x; + }); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/math.h b/tensorflow/compiler/xla/client/lib/math.h new file mode 100644 index 0000000000..e7c8b50273 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/math.h @@ -0,0 +1,51 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATH_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATH_H_ + +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" + +namespace xla { + +// Computes the square root of 'operand'. +XlaOp Sqrt(XlaOp operand); + +// Computes the reciprocal of the square root of 'operand'. +XlaOp Rsqrt(XlaOp operand); + +// Computes the square of 'operand'. +XlaOp Square(XlaOp operand); + +// Computes the reciprocal of 'operand'. +XlaOp Reciprocal(XlaOp operand); + +// Evaluates a polynomial given coefficients and `x`. +// N.B. Coefficients should be supplied in decreasing order. +XlaOp EvaluatePolynomial(XlaOp x, + tensorflow::gtl::ArraySlice<float> coefficients); + +// Computes an approximation of the error function complement (1 - erf(x)). +XlaOp Erfc(XlaOp x); + +// Computes an approximation of the error function. +XlaOp Erf(XlaOp x); + +// Computes an approximation of the inverse of the error function. +XlaOp ErfInv(XlaOp x); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATH_H_ diff --git a/tensorflow/compiler/xla/client/lib/math_test.cc b/tensorflow/compiler/xla/client/lib/math_test.cc new file mode 100644 index 0000000000..1df4e6ea42 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/math_test.cc @@ -0,0 +1,85 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace { + +class MathTest : public ClientLibraryTestBase { + public: + ErrorSpec error_spec_{0.0001}; +}; + +XLA_TEST_F(MathTest, SqrtF32) { + XlaBuilder builder(TestName()); + Literal zero_literal = Literal::Zero(PrimitiveType::F32); + + std::unique_ptr<GlobalData> zero_data = + client_->TransferToServer(zero_literal).ConsumeValueOrDie(); + + XlaOp zero = Parameter(&builder, 0, zero_literal.shape(), "zero"); + Sqrt(zero); + + ComputeAndCompareR0<float>(&builder, 0.0f, {zero_data.get()}, error_spec_); +} + +XLA_TEST_F(MathTest, SquareTenValues) { + XlaBuilder builder(TestName()); + auto x = ConstantR1<float>( + &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + Square(x); + + std::vector<float> expected = {4.41, 6.76, 6.76, 16., 4.41, + 5.29, 25., 0.81, 5.76, 2.56}; + ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_); +} + +XLA_TEST_F(MathTest, ReciprocalTenValues) { + XlaBuilder builder(TestName()); + auto x = ConstantR1<float>( + &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + Reciprocal(x); + + std::vector<float> expected = { + 0.47619048, -0.38461538, 0.38461538, -0.25, 0.47619048, + 0.43478261, -0.2, -1.11111111, -0.41666667, 0.625}; + ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_); +} + +XLA_TEST_F(MathTest, SqrtZeroes) { + XlaBuilder builder(TestName()); + auto x = ConstantR1<float>(&builder, {0.0, -0.0}); + Sqrt(x); + + ComputeAndCompareR1<float>(&builder, {0, 0}, {}, error_spec_); +} + +XLA_TEST_F(MathTest, SqrtSixValues) { + XlaBuilder builder(TestName()); + auto x = ConstantR1<float>(&builder, {16.0, 1.0, 1024.0, 0.16, 0.2, 12345}); + Sqrt(x); + + std::vector<float> expected = {4, 1, 32, 0.4, 0.4472, 111.1080}; + ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_); +} +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc index 70d5311cbc..95342af6a7 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc @@ -1377,11 +1377,6 @@ XlaOp XlaBuilder::Sort(XlaOp keys, tensorflow::gtl::optional<XlaOp> values) { }); } -XlaOp XlaBuilder::SqrtF32(const XlaOp& operand) { - return BinaryOp(HloOpcode::kPower, operand, ConstantR0<float>(0.5), - /*broadcast_dimensions=*/{}); -} - XlaOp XlaBuilder::Pow(const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { return BinaryOp(HloOpcode::kPower, lhs, rhs, broadcast_dimensions); @@ -1412,16 +1407,6 @@ XlaOp XlaBuilder::BitcastConvertType(const XlaOp& operand, }); } -XlaOp XlaBuilder::SquareF32(const XlaOp& operand) { - return BinaryOp(HloOpcode::kPower, operand, ConstantR0<float>(2.0), - /*broadcast_dimensions=*/{}); -} - -XlaOp XlaBuilder::ReciprocalF32(const XlaOp& operand) { - return BinaryOp(HloOpcode::kPower, operand, ConstantR0<float>(-1.0), - /*broadcast_dimensions=*/{}); -} - XlaOp XlaBuilder::Neg(const XlaOp& operand) { return UnaryOp(HloOpcode::kNegate, operand); } @@ -2512,14 +2497,6 @@ XlaOp Real(const XlaOp& operand) { return operand.builder()->Real(operand); } XlaOp Imag(const XlaOp& operand) { return operand.builder()->Imag(operand); } -XlaOp SqrtF32(const XlaOp& operand) { - return operand.builder()->SqrtF32(operand); -} - -XlaOp SquareF32(const XlaOp& operand) { - return operand.builder()->SquareF32(operand); -} - XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { return lhs.builder()->Pow(lhs, rhs, broadcast_dimensions); @@ -2537,10 +2514,6 @@ XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type) { return operand.builder()->BitcastConvertType(operand, new_element_type); } -XlaOp ReciprocalF32(const XlaOp& operand) { - return operand.builder()->ReciprocalF32(operand); -} - XlaOp Neg(const XlaOp& operand) { return operand.builder()->Neg(operand); } XlaOp Transpose(const XlaOp& operand, diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_client/xla_builder.h index 79fcc5f256..274aba8a31 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.h @@ -751,16 +751,6 @@ class XlaBuilder { // Enqueues an imaginary-part instruction onto the computation. XlaOp Imag(const XlaOp& operand); - // Enqueues a float32 sqrt instruction onto the computation. - // (float32 is specified as there is an implicit float32 0.5f constant - // exponent). - XlaOp SqrtF32(const XlaOp& operand); - - // Enqueues a float32 square instruction onto the computation. - // (float32 is specified as there is an implicit float32 2.0f constant - // exponent). - XlaOp SquareF32(const XlaOp& operand); - // Enqueues a lhs^rhs computation onto the computation. XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); @@ -783,14 +773,6 @@ class XlaBuilder { XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type); - // Enqueues a float32 reciprocal instruction onto the computation. - // (float32 is specified as there is an implicit float32 -1.0f constant - // exponent). - // - // TODO(b/34468990) axe F32 suffix, can be determined by reflecting on the - // shape of the operand. - XlaOp ReciprocalF32(const XlaOp& operand); - // Enqueues a negate instruction onto the computation. XlaOp Neg(const XlaOp& operand); @@ -1235,8 +1217,6 @@ class XlaBuilder { friend XlaOp Tanh(const XlaOp& operand); friend XlaOp Real(const XlaOp& operand); friend XlaOp Imag(const XlaOp& operand); - friend XlaOp SqrtF32(const XlaOp& operand); - friend XlaOp SquareF32(const XlaOp& operand); friend XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); friend XlaOp IsFinite(const XlaOp& operand); @@ -1244,7 +1224,6 @@ class XlaBuilder { PrimitiveType new_element_type); friend XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type); - friend XlaOp ReciprocalF32(const XlaOp& operand); friend XlaOp Neg(const XlaOp& operand); friend XlaOp Transpose(const XlaOp& operand, tensorflow::gtl::ArraySlice<int64> permutation); @@ -1833,16 +1812,6 @@ XlaOp Real(const XlaOp& operand); // Enqueues an imaginary-part instruction onto the computation. XlaOp Imag(const XlaOp& operand); -// Enqueues a float32 sqrt instruction onto the computation. -// (float32 is specified as there is an implicit float32 0.5f constant -// exponent). -XlaOp SqrtF32(const XlaOp& operand); - -// Enqueues a float32 square instruction onto the computation. -// (float32 is specified as there is an implicit float32 2.0f constant -// exponent). -XlaOp SquareF32(const XlaOp& operand); - // Enqueues a lhs^rhs computation onto the computation. XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); @@ -1863,14 +1832,6 @@ XlaOp ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type); // identical. XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type); -// Enqueues a float32 reciprocal instruction onto the computation. -// (float32 is specified as there is an implicit float32 -1.0f constant -// exponent). -// -// TODO(b/34468990) axe F32 suffix, can be determined by reflecting on the -// shape of the operand. -XlaOp ReciprocalF32(const XlaOp& operand); - // Enqueues a negate instruction onto the computation. XlaOp Neg(const XlaOp& operand); diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index 83834c1ff6..22cc4e2436 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -52,9 +52,9 @@ cc_library( "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:executable_build_options", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/lib:math", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", - "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/core:framework_lite", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index b5ba4e2d42..be55d50b23 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/python/local_computation_builder.h" +#include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/ptr_util.h" @@ -626,11 +627,11 @@ _FORWARD_UNOP(Sign) _FORWARD_UNOP(Cos) _FORWARD_UNOP(Sin) _FORWARD_UNOP(Tanh) -_FORWARD_UNOP(SqrtF32) -_FORWARD_UNOP(SquareF32) +_FORWARD_UNOP(Sqrt) +_FORWARD_UNOP(Square) _FORWARD_BINOP(Pow) _FORWARD_UNOP(IsFinite) -_FORWARD_UNOP(ReciprocalF32) +_FORWARD_UNOP(Reciprocal) _FORWARD_UNOP(Neg) _FORWARD_UNOP(Sort) diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index e920f8aecd..690ff277e8 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -346,11 +346,11 @@ class LocalComputationBuilder { _FORWARD_UNOP(Cos) _FORWARD_UNOP(Sin) _FORWARD_UNOP(Tanh) - _FORWARD_UNOP(SqrtF32) - _FORWARD_UNOP(SquareF32) + _FORWARD_UNOP(Sqrt) + _FORWARD_UNOP(Square) _FORWARD_BINOP(Pow) _FORWARD_UNOP(IsFinite) - _FORWARD_UNOP(ReciprocalF32) + _FORWARD_UNOP(Reciprocal) _FORWARD_UNOP(Neg) _FORWARD_UNOP(Sort) diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index 76e9e637cd..c44e69e615 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -1002,11 +1002,11 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::Cos; %unignore xla::swig::LocalComputationBuilder::Sin; %unignore xla::swig::LocalComputationBuilder::Tanh; -%unignore xla::swig::LocalComputationBuilder::SqrtF32; -%unignore xla::swig::LocalComputationBuilder::SquareF32; +%unignore xla::swig::LocalComputationBuilder::Sqrt; +%unignore xla::swig::LocalComputationBuilder::Square; %unignore xla::swig::LocalComputationBuilder::Pow; %unignore xla::swig::LocalComputationBuilder::IsFinite; -%unignore xla::swig::LocalComputationBuilder::ReciprocalF32; +%unignore xla::swig::LocalComputationBuilder::Reciprocal; %unignore xla::swig::LocalComputationBuilder::Neg; %unignore xla::swig::LocalComputationBuilder::Sort; %unignore xla::swig::DestructureLocalShapedBufferTuple; diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index abb97d0c6f..27aee634ba 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -99,10 +99,10 @@ _UNARY_OPS = [ 'Cos', 'Sin', 'Tanh', - 'SqrtF32', - 'SquareF32', + 'Sqrt', + 'Square', 'IsFinite', - 'ReciprocalF32', + 'Reciprocal', 'Neg', 'Sort', ] diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 20b2885e90..77d398e5e2 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -886,6 +886,7 @@ xla_test( "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/lib:math", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:hlo", diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc index d9d7ba1362..217673c8cb 100644 --- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc +++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" @@ -118,7 +119,7 @@ XLA_TEST_P(BatchNormalizationTest, SubtractInZ) { XLA_TEST_P(BatchNormalizationTest, SquareTesseractElementwise) { XlaBuilder builder("square_tesseract_elementwise"); auto x = ConstantLiteral(&builder, input_literal_); - SquareF32(x); + Square(x); using tensorflow::MathUtil; @@ -150,7 +151,7 @@ XLA_TEST_P(BatchNormalizationTest, SquareAndReduce) { auto activation_deviations = Sub(input_activations, set_means, /*broadcast_dimensions=*/{1}); XlaComputation add = CreateScalarAddComputation(F32, &builder); - auto dev_squares = SquareF32(activation_deviations); + auto dev_squares = Square(activation_deviations); Reduce(dev_squares, ConstantR0<float>(&builder, 0.0f), add, {0, 2, 3}); std::vector<float> expected = {18, 0.06}; @@ -160,7 +161,7 @@ XLA_TEST_P(BatchNormalizationTest, SquareAndReduce) { XLA_TEST_P(BatchNormalizationTest, VarianceToStddev) { XlaBuilder builder("variance_to_stddev"); auto variance = ConstantR1<float>(&builder, {6.f, .02f}); - SqrtF32(variance); + Sqrt(variance); std::vector<float> expected = {2.44948974f, 0.14142136f}; ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_); @@ -195,20 +196,20 @@ XLA_TEST_P(BatchNormalizationTest, SpecComparisonForward) { auto epsilon2 = ConstantR1<float>(&builder, {kEpsilon, kEpsilon}); auto activation_deviations = Sub(input_activations, set_means, /*broadcast_dimensions=*/{1}); - auto dev_squares = SquareF32(activation_deviations); + auto dev_squares = Square(activation_deviations); auto sum_of_squares = CheckShape(&builder, Reduce(dev_squares, ConstantR0<float>(&builder, 0.0f), add, /*dimensions_to_reduce=*/{0, 2, 3}), TwoElementVectorF32); auto variance = Div(sum_of_squares, count); - auto standard_deviation = SqrtF32(variance); + auto standard_deviation = Sqrt(variance); auto standard_deviation_above_epsilon = CheckShape(&builder, Gt(standard_deviation, epsilon), ShapeUtil::MakeShape(PRED, {2})); auto gt_eps = Select(standard_deviation_above_epsilon, standard_deviation, epsilon2); - auto normalization_factors = ReciprocalF32(gt_eps); + auto normalization_factors = Reciprocal(gt_eps); auto normalized_input_activations = Mul(activation_deviations, normalization_factors, /*broadcast_dimensions=*/{1}); diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc index bc994315c3..3afd8c8fc8 100644 --- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc +++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc @@ -897,18 +897,6 @@ XLA_TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionS32) { ComputeAndCompareR0<int32>(&b, 10, {}); } -XLA_TEST_F(ScalarComputationsTest, SqrtF320) { - XlaBuilder builder(TestName()); - Literal zero_literal = Literal::Zero(PrimitiveType::F32); - - std::unique_ptr<GlobalData> zero_data = - client_->TransferToServer(zero_literal).ConsumeValueOrDie(); - - XlaOp zero = Parameter(&builder, 0, zero_literal.shape(), "zero"); - SqrtF32(zero); - - ComputeAndCompareR0<float>(&builder, 0.0f, {zero_data.get()}, error_spec_); -} XLA_TEST_F(ScalarComputationsTest, RoundScalar) { XlaBuilder builder(TestName()); diff --git a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc index c11df7cdf5..79bae22dac 100644 --- a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc @@ -135,46 +135,6 @@ XLA_TEST_F(VecOpsSimpleTest, NegateUint32Values) { ComputeAndCompareR1<uint32>(&builder, expected, {}); } -XLA_TEST_F(VecOpsSimpleTest, SquareTenValues) { - XlaBuilder builder(TestName()); - auto x = ConstantR1<float>( - &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); - SquareF32(x); - - std::vector<float> expected = {4.41, 6.76, 6.76, 16., 4.41, - 5.29, 25., 0.81, 5.76, 2.56}; - ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_); -} - -XLA_TEST_F(VecOpsSimpleTest, ReciprocalTenValues) { - XlaBuilder builder(TestName()); - auto x = ConstantR1<float>( - &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); - ReciprocalF32(x); - - std::vector<float> expected = { - 0.47619048, -0.38461538, 0.38461538, -0.25, 0.47619048, - 0.43478261, -0.2, -1.11111111, -0.41666667, 0.625}; - ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_); -} - -XLA_TEST_F(VecOpsSimpleTest, SqrtZeroes) { - XlaBuilder builder(TestName()); - auto x = ConstantR1<float>(&builder, {0.0, -0.0}); - SqrtF32(x); - - ComputeAndCompareR1<float>(&builder, {0, 0}, {}, error_spec_); -} - -XLA_TEST_F(VecOpsSimpleTest, SqrtSixValues) { - XlaBuilder builder(TestName()); - auto x = ConstantR1<float>(&builder, {16.0, 1.0, 1024.0, 0.16, 0.2, 12345}); - SqrtF32(x); - - std::vector<float> expected = {4, 1, 32, 0.4, 0.4472, 111.1080}; - ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_); -} - XLA_TEST_F(VecOpsSimpleTest, InvSqrtSevenValues) { XlaBuilder builder(TestName()); auto x = ConstantR1<float>(&builder, |