diff options
author | 2018-06-19 22:07:22 -0700 | |
---|---|---|
committer | 2018-06-19 22:10:17 -0700 | |
commit | 081f30a7bc2a11e2556629a14cdab2c3c313312e (patch) | |
tree | 345ec4824ce6f011ef081e3574943dee2b1cd4e1 | |
parent | 9ab04addfb80cbf9334bb330acee5fca09353d23 (diff) |
[TF2XLA] Optimize TruncatedNormalOp
Re-sampling when encountering a rejected value can be quite slow.
If we directly use the inverse CDF of the normal distribution, the probit
function, we can avoid the need to resample.
PiperOrigin-RevId: 201296864
-rw-r--r-- | tensorflow/compiler/tests/random_ops_test.py | 2 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/random_ops.cc | 77 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc | 49 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/unary_ops.cc | 12 | ||||
-rw-r--r-- | tensorflow/compiler/xla/client/lib/arithmetic.cc | 53 | ||||
-rw-r--r-- | tensorflow/compiler/xla/client/lib/arithmetic.h | 11 |
6 files changed, 101 insertions, 103 deletions
diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py index 8c6366faa6..2e71b00ba6 100644 --- a/tensorflow/compiler/tests/random_ops_test.py +++ b/tensorflow/compiler/tests/random_ops_test.py @@ -124,7 +124,7 @@ class RandomOpsTest(XLATestCase): # Department of Scientific Computing website. Florida State University. expected_mean = mu + (normal_pdf(alpha) - normal_pdf(beta)) / z * sigma actual_mean = np.mean(y) - self.assertAllClose(actual_mean, expected_mean, atol=3e-4) + self.assertAllClose(actual_mean, expected_mean, atol=2e-4) expected_median = mu + probit( (normal_cdf(alpha) + normal_cdf(beta)) / 2.) * sigma diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index a08654b12b..aa4d242a11 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -17,6 +17,8 @@ limitations under the License. // TODO(misard,phawkins): handle random number generator seeds/states correctly. // TODO(misard,phawkins): add tests. +#include <limits> + #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" #include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/lib/while_loop.h" @@ -205,53 +207,44 @@ class TruncatedNormalOp : public XlaOpKernel { xla::XlaBuilder* b = ctx->builder(); - auto out_of_range_mask = [dtype](xla::XlaOp candidate, xla::XlaBuilder* b) { - xla::XlaOp two_sd = XlaHelpers::FloatLiteral(b, dtype, 2.0); - return b->Gt(b->Abs(candidate), two_sd); + auto normal_cdf = [](double x) { + return (1.0 + std::erf(x / std::sqrt(2.0))) / 2.0; }; - // The algorithm we're using is roughly: - // - // while (any(candidate < mean-2*sd || candidate > mean+2*sd)) { - // out_of_range_mask := candidate < mean-2*sd || candidate > mean+2*sd - // candidate = select(out_of_range_mask, rng_normal(), candidate) - // } - std::vector<xla::XlaOp> initial_values = { - // The current candidate. - b->Broadcast(XlaHelpers::Zero(b, dtype), shape.dim_sizes()), - // The to_resample mask, where 'true' identifies a location in the - // current candidate that is out of range and must be regenerated. - b->Broadcast(b->ConstantR0<bool>(true), shape.dim_sizes()), - // Is any element in the mask true? - b->ConstantR0<bool>(true)}; - auto condition = [&](gtl::ArraySlice<xla::XlaOp> values, - xla::XlaBuilder* b) -> xla::StatusOr<xla::XlaOp> { - // Continue while any element in the mask is true. - return values[2]; - }; - auto body = - [&](gtl::ArraySlice<xla::XlaOp> values, - xla::XlaBuilder* b) -> xla::StatusOr<std::vector<xla::XlaOp>> { - xla::XlaOp candidate = values[0]; - xla::XlaOp to_resample = values[1]; - xla::XlaOp mean = XlaHelpers::Zero(b, dtype); - xla::XlaOp stddev = XlaHelpers::One(b, dtype); - candidate = b->Select(to_resample, b->RngNormal(mean, stddev, xla_shape), - candidate); - // Compute a new to_resample mask, and determine whether any value is - // still out of range. - to_resample = out_of_range_mask(candidate, b); - TF_ASSIGN_OR_RETURN(xla::XlaOp done, Any(to_resample, b)); - return std::vector<xla::XlaOp>{candidate, to_resample, done}; - }; - auto result = - XlaWhileLoop(condition, body, initial_values, "truncated_normal", b); - OP_REQUIRES_OK(ctx, result.status()); - ctx->SetOutput(0, result.ValueOrDie()[0]); + const double kA = -2.0; + const double kB = 2.0; + const double kMu = 0.0; + const double kSigma = 1.0; + const double kAlpha = (kA - kMu) / kSigma; + const double kBeta = (kB - kMu) / kSigma; + const double kAlphaNormalCdf = normal_cdf(kAlpha); + const double kBetaNormalCdf = normal_cdf(kBeta); + const double kZ = kBetaNormalCdf - kAlphaNormalCdf; + + xla::XlaOp one = XlaHelpers::FloatLiteral(b, dtype, 1.0); + xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype, 2.0); + xla::XlaOp sqrt_2 = XlaHelpers::FloatLiteral(b, dtype, std::sqrt(2.0)); + xla::XlaOp min_positive = + XlaHelpers::FloatLiteral(b, dtype, std::numeric_limits<float>::min()); + + xla::XlaOp z = XlaHelpers::FloatLiteral(b, dtype, kZ); + xla::XlaOp alpha_normal_cdf = + XlaHelpers::FloatLiteral(b, dtype, kAlphaNormalCdf); + + auto uniform = b->RngUniform(min_positive, one, xla_shape); + // probit(p) = sqrt(2) * erfinv(2*p-1) + auto p = b->Add(alpha_normal_cdf, b->Mul(z, uniform)); + auto erfinv_input = b->Sub(b->Mul(p, two), one); + auto erfinv_or_status = ErfInv(b, erfinv_input); + OP_REQUIRES_OK(ctx, erfinv_or_status.status()); + auto probit = b->Mul(sqrt_2, erfinv_or_status.ValueOrDie()); + ctx->SetOutput(0, probit); } }; -REGISTER_XLA_OP(Name("TruncatedNormal").CompileTimeConstInput("shape"), +REGISTER_XLA_OP(Name("TruncatedNormal") + .CompileTimeConstInput("shape") + .TypeConstraint("dtype", DT_FLOAT), TruncatedNormalOp); } // anonymous namespace diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc index a99d4ddc7c..58c5dc5aa9 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc @@ -163,51 +163,6 @@ xla::XlaOp RandomUniform(xla::XlaBuilder* builder, const xla::XlaOp& seed, return floats; } -// Approximation for the inverse error function from -// Giles, M., "Approximating the erfinv function". -// The approximation has the form: -// w = -log((1 - x) * (1 + x)) -// if ( w < 5 ) { -// w = w - 2.5 -// p = sum_{i=1}^n lq[i]*w^i -// } else { -// w = sqrt(w) - 3 -// p = sum_{i=1}^n gq[i]*w^i -// } -// return p*x -xla::XlaOp ErfInvF32(xla::XlaBuilder* b, const xla::XlaOp& x, - const TensorShape& shape) { - constexpr int kDegree = 9; - constexpr std::array<float, 9> w_less_than_5_constants = { - 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, - -4.39150654e-06f, 0.00021858087f, -0.00125372503f, - -0.00417768164f, 0.246640727f, 1.50140941f}; - constexpr std::array<float, 9> w_greater_than_5_constants = { - -0.000200214257f, 0.000100950558f, 0.00134934322f, - -0.00367342844f, 0.00573950773f, -0.0076224613f, - 0.00943887047f, 1.00167406f, 2.83297682f}; - - auto one = b->ConstantR0<float>(1.0); - auto w = b->Neg(b->Log(b->Mul(b->Sub(one, x), b->Add(one, x)))); - - auto lt = b->Lt(w, b->ConstantR0<float>(5.0)); - auto coefficient = [&](int i) { - return b->Select( - lt, - b->Broadcast(b->ConstantR0<float>(w_less_than_5_constants[i]), - shape.dim_sizes()), - b->Broadcast(b->ConstantR0<float>(w_greater_than_5_constants[i]), - shape.dim_sizes())); - }; - w = b->Select(lt, b->Sub(w, b->ConstantR0<float>(2.5f)), - b->Sub(b->SqrtF32(w), b->ConstantR0<float>(3.0f))); - auto p = coefficient(0); - for (int i = 1; i < kDegree; ++i) { - p = b->Add(coefficient(i), b->Mul(p, w)); - } - return b->Mul(p, x); -} - } // namespace class StatelessRandomUniformOp : public XlaOpKernel { @@ -259,8 +214,10 @@ class StatelessRandomNormalOp : public XlaOpKernel { RandomUniform(builder, seed, shape, std::nextafter(-1.0f, 0.0f), 1.0); // Convert uniform distribution to normal distribution by computing // sqrt(2) * erfinv(x) + auto erfinv_or_status = ErfInv(builder, uniform); + OP_REQUIRES_OK(ctx, erfinv_or_status.status()); auto normal = builder->Mul(builder->ConstantR0<float>(std::sqrt(2.0)), - ErfInvF32(builder, uniform, shape)); + erfinv_or_status.ValueOrDie()); ctx->SetOutput(0, normal); } diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index 2521445e86..1d078de211 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -202,9 +202,9 @@ class ErfOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(0), &primitive_type)); - auto y = b->Select(b->Gt(abs_x, one), - b->Sub(one, ComputeErfc(b, x, primitive_type)), - ComputeErf(b, x, primitive_type)); + auto y = + b->Select(b->Gt(abs_x, one), b->Sub(one, Erfc(b, x, primitive_type)), + Erf(b, x, primitive_type)); ctx->SetOutput(0, y); } }; @@ -223,9 +223,9 @@ class ErfcOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(0), &primitive_type)); - auto y = b->Select(b->Lt(abs_x, one), - b->Sub(one, ComputeErf(b, x, primitive_type)), - ComputeErfc(b, x, primitive_type)); + auto y = + b->Select(b->Lt(abs_x, one), b->Sub(one, Erf(b, x, primitive_type)), + Erfc(b, x, primitive_type)); ctx->SetOutput(0, y); } }; diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc index 639f85737f..f095ec9213 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.cc +++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc @@ -176,8 +176,8 @@ xla::XlaOp EvaluatePolynomial(xla::XlaBuilder* b, const xla::XlaOp& x, } // Compute an approximation of the error function complement (1 - erf(x)). -xla::XlaOp ComputeErfc(xla::XlaBuilder* b, const xla::XlaOp& x, - PrimitiveType data_type) { +xla::XlaOp Erfc(xla::XlaBuilder* b, const xla::XlaOp& x, + PrimitiveType data_type) { xla::XlaOp zero = FloatLiteral(b, data_type, 0.0); xla::XlaOp two = FloatLiteral(b, data_type, 2.0); xla::XlaOp eight = FloatLiteral(b, data_type, 8.0); @@ -197,12 +197,57 @@ xla::XlaOp ComputeErfc(xla::XlaBuilder* b, const xla::XlaOp& x, } // Compute a polynomial approximation of the error function. -xla::XlaOp ComputeErf(xla::XlaBuilder* b, const xla::XlaOp& x, - PrimitiveType data_type) { +xla::XlaOp Erf(xla::XlaBuilder* b, const xla::XlaOp& x, + PrimitiveType data_type) { xla::XlaOp z = b->Mul(x, x); xla::XlaOp pt = EvaluatePolynomial(b, z, kErfTCoefficient, data_type); xla::XlaOp pu = EvaluatePolynomial(b, z, kErfUCoefficient, data_type); return b->Div(b->Mul(x, pt), pu); } +// Approximation for the inverse error function from +// Giles, M., "Approximating the erfinv function". +// The approximation has the form: +// w = -log((1 - x) * (1 + x)) +// if ( w < 5 ) { +// w = w - 2.5 +// p = sum_{i=1}^n lq[i]*w^i +// } else { +// w = sqrt(w) - 3 +// p = sum_{i=1}^n gq[i]*w^i +// } +// return p*x +StatusOr<XlaOp> ErfInv(xla::XlaBuilder* b, const xla::XlaOp& x) { + TF_ASSIGN_OR_RETURN(Shape shape, b->GetShape(x)); + constexpr int kDegree = 9; + constexpr std::array<float, 9> w_less_than_5_constants = { + 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, + -4.39150654e-06f, 0.00021858087f, -0.00125372503f, + -0.00417768164f, 0.246640727f, 1.50140941f}; + constexpr std::array<float, 9> w_greater_than_5_constants = { + -0.000200214257f, 0.000100950558f, 0.00134934322f, + -0.00367342844f, 0.00573950773f, -0.0076224613f, + 0.00943887047f, 1.00167406f, 2.83297682f}; + + auto one = b->ConstantR0<float>(1.0); + auto w = b->Neg(b->Log(b->Mul(b->Sub(one, x), b->Add(one, x)))); + + auto lt = b->Lt(w, b->ConstantR0<float>(5.0)); + auto coefficient = [&](int i) { + return b->Select( + lt, + b->Broadcast(b->ConstantR0<float>(w_less_than_5_constants[i]), + AsInt64Slice(shape.dimensions())), + b->Broadcast(b->ConstantR0<float>(w_greater_than_5_constants[i]), + AsInt64Slice(shape.dimensions()))); + }; + w = b->Select(lt, b->Sub(w, b->ConstantR0<float>(2.5f)), + b->Sub(b->SqrtF32(w), b->ConstantR0<float>(3.0f))); + auto p = coefficient(0); + for (int i = 1; i < kDegree; ++i) { + p = b->Add(coefficient(i), b->Mul(p, w)); + } + return b->Mul(p, x); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.h b/tensorflow/compiler/xla/client/lib/arithmetic.h index f11cc00317..efdcc7e198 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.h +++ b/tensorflow/compiler/xla/client/lib/arithmetic.h @@ -62,12 +62,15 @@ xla::XlaOp EvaluatePolynomial(xla::XlaBuilder* b, const xla::XlaOp& x, PrimitiveType data_type); // Compute an approximation of the error function complement (1 - erf(x)). -xla::XlaOp ComputeErfc(xla::XlaBuilder* b, const xla::XlaOp& x, - PrimitiveType data_type); +xla::XlaOp Erfc(xla::XlaBuilder* b, const xla::XlaOp& x, + PrimitiveType data_type); // Compute an approximation of the error function. -xla::XlaOp ComputeErf(xla::XlaBuilder* b, const xla::XlaOp& x, - PrimitiveType data_type); +xla::XlaOp Erf(xla::XlaBuilder* b, const xla::XlaOp& x, + PrimitiveType data_type); + +// Compute an approximation of the inverse of the error function. +StatusOr<XlaOp> ErfInv(xla::XlaBuilder* b, const xla::XlaOp& x); } // namespace xla |