about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
author: David Majnemer <majnemer@google.com> 2018-06-19 22:07:22 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-06-19 22:10:17 -0700
commit: 081f30a7bc2a11e2556629a14cdab2c3c313312e (patch)
tree: 345ec4824ce6f011ef081e3574943dee2b1cd4e1
parent: 9ab04addfb80cbf9334bb330acee5fca09353d23 (diff)
[TF2XLA] Optimize TruncatedNormalOp
Re-sampling every time a rejected value is encountered can be quite slow. If we instead apply the inverse CDF of the normal distribution (the probit function) directly to a uniform sample, we avoid the need to resample.

PiperOrigin-RevId: 201296864
-rw-r--r-- tensorflow/compiler/tests/random_ops_test.py 2
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/random_ops.cc 77
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc 49
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/unary_ops.cc 12
-rw-r--r-- tensorflow/compiler/xla/client/lib/arithmetic.cc 53
-rw-r--r-- tensorflow/compiler/xla/client/lib/arithmetic.h 11
6 files changed, 101 insertions, 103 deletions
diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py
index 8c6366faa6..2e71b00ba6 100644
--- a/tensorflow/compiler/tests/random_ops_test.py
+++ b/tensorflow/compiler/tests/random_ops_test.py
@@ -124,7 +124,7 @@ class RandomOpsTest(XLATestCase):
# Department of Scientific Computing website. Florida State University.
expected_mean = mu + (normal_pdf(alpha) - normal_pdf(beta)) / z * sigma
actual_mean = np.mean(y)
- self.assertAllClose(actual_mean, expected_mean, atol=3e-4)
+ self.assertAllClose(actual_mean, expected_mean, atol=2e-4)
expected_median = mu + probit(
(normal_cdf(alpha) + normal_cdf(beta)) / 2.) * sigma
diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc
index a08654b12b..aa4d242a11 100644
--- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc
@@ -17,6 +17,8 @@ limitations under the License.
// TODO(misard,phawkins): handle random number generator seeds/states correctly.
// TODO(misard,phawkins): add tests.
+#include <limits>
+
#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
@@ -205,53 +207,44 @@ class TruncatedNormalOp : public XlaOpKernel {
xla::XlaBuilder* b = ctx->builder();
- auto out_of_range_mask = [dtype](xla::XlaOp candidate, xla::XlaBuilder* b) {
- xla::XlaOp two_sd = XlaHelpers::FloatLiteral(b, dtype, 2.0);
- return b->Gt(b->Abs(candidate), two_sd);
+ auto normal_cdf = [](double x) {
+ return (1.0 + std::erf(x / std::sqrt(2.0))) / 2.0;
};
- // The algorithm we're using is roughly:
- //
- // while (any(candidate < mean-2*sd || candidate > mean+2*sd)) {
- // out_of_range_mask := candidate < mean-2*sd || candidate > mean+2*sd
- // candidate = select(out_of_range_mask, rng_normal(), candidate)
- // }
- std::vector<xla::XlaOp> initial_values = {
- // The current candidate.
- b->Broadcast(XlaHelpers::Zero(b, dtype), shape.dim_sizes()),
- // The to_resample mask, where 'true' identifies a location in the
- // current candidate that is out of range and must be regenerated.
- b->Broadcast(b->ConstantR0<bool>(true), shape.dim_sizes()),
- // Is any element in the mask true?
- b->ConstantR0<bool>(true)};
- auto condition = [&](gtl::ArraySlice<xla::XlaOp> values,
- xla::XlaBuilder* b) -> xla::StatusOr<xla::XlaOp> {
- // Continue while any element in the mask is true.
- return values[2];
- };
- auto body =
- [&](gtl::ArraySlice<xla::XlaOp> values,
- xla::XlaBuilder* b) -> xla::StatusOr<std::vector<xla::XlaOp>> {
- xla::XlaOp candidate = values[0];
- xla::XlaOp to_resample = values[1];
- xla::XlaOp mean = XlaHelpers::Zero(b, dtype);
- xla::XlaOp stddev = XlaHelpers::One(b, dtype);
- candidate = b->Select(to_resample, b->RngNormal(mean, stddev, xla_shape),
- candidate);
- // Compute a new to_resample mask, and determine whether any value is
- // still out of range.
- to_resample = out_of_range_mask(candidate, b);
- TF_ASSIGN_OR_RETURN(xla::XlaOp done, Any(to_resample, b));
- return std::vector<xla::XlaOp>{candidate, to_resample, done};
- };
- auto result =
- XlaWhileLoop(condition, body, initial_values, "truncated_normal", b);
- OP_REQUIRES_OK(ctx, result.status());
- ctx->SetOutput(0, result.ValueOrDie()[0]);
+ const double kA = -2.0;
+ const double kB = 2.0;
+ const double kMu = 0.0;
+ const double kSigma = 1.0;
+ const double kAlpha = (kA - kMu) / kSigma;
+ const double kBeta = (kB - kMu) / kSigma;
+ const double kAlphaNormalCdf = normal_cdf(kAlpha);
+ const double kBetaNormalCdf = normal_cdf(kBeta);
+ const double kZ = kBetaNormalCdf - kAlphaNormalCdf;
+
+ xla::XlaOp one = XlaHelpers::FloatLiteral(b, dtype, 1.0);
+ xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype, 2.0);
+ xla::XlaOp sqrt_2 = XlaHelpers::FloatLiteral(b, dtype, std::sqrt(2.0));
+ xla::XlaOp min_positive =
+ XlaHelpers::FloatLiteral(b, dtype, std::numeric_limits<float>::min());
+
+ xla::XlaOp z = XlaHelpers::FloatLiteral(b, dtype, kZ);
+ xla::XlaOp alpha_normal_cdf =
+ XlaHelpers::FloatLiteral(b, dtype, kAlphaNormalCdf);
+
+ auto uniform = b->RngUniform(min_positive, one, xla_shape);
+ // probit(p) = sqrt(2) * erfinv(2*p-1)
+ auto p = b->Add(alpha_normal_cdf, b->Mul(z, uniform));
+ auto erfinv_input = b->Sub(b->Mul(p, two), one);
+ auto erfinv_or_status = ErfInv(b, erfinv_input);
+ OP_REQUIRES_OK(ctx, erfinv_or_status.status());
+ auto probit = b->Mul(sqrt_2, erfinv_or_status.ValueOrDie());
+ ctx->SetOutput(0, probit);
}
};
-REGISTER_XLA_OP(Name("TruncatedNormal").CompileTimeConstInput("shape"),
+REGISTER_XLA_OP(Name("TruncatedNormal")
+ .CompileTimeConstInput("shape")
+ .TypeConstraint("dtype", DT_FLOAT),
TruncatedNormalOp);
} // anonymous namespace
diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc
index a99d4ddc7c..58c5dc5aa9 100644
--- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc
@@ -163,51 +163,6 @@ xla::XlaOp RandomUniform(xla::XlaBuilder* builder, const xla::XlaOp& seed,
return floats;
}
-// Approximation for the inverse error function from
-// Giles, M., "Approximating the erfinv function".
-// The approximation has the form:
-// w = -log((1 - x) * (1 + x))
-// if ( w < 5 ) {
-// w = w - 2.5
-// p = sum_{i=1}^n lq[i]*w^i
-// } else {
-// w = sqrt(w) - 3
-// p = sum_{i=1}^n gq[i]*w^i
-// }
-// return p*x
-xla::XlaOp ErfInvF32(xla::XlaBuilder* b, const xla::XlaOp& x,
- const TensorShape& shape) {
- constexpr int kDegree = 9;
- constexpr std::array<float, 9> w_less_than_5_constants = {
- 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f,
- -4.39150654e-06f, 0.00021858087f, -0.00125372503f,
- -0.00417768164f, 0.246640727f, 1.50140941f};
- constexpr std::array<float, 9> w_greater_than_5_constants = {
- -0.000200214257f, 0.000100950558f, 0.00134934322f,
- -0.00367342844f, 0.00573950773f, -0.0076224613f,
- 0.00943887047f, 1.00167406f, 2.83297682f};
-
- auto one = b->ConstantR0<float>(1.0);
- auto w = b->Neg(b->Log(b->Mul(b->Sub(one, x), b->Add(one, x))));
-
- auto lt = b->Lt(w, b->ConstantR0<float>(5.0));
- auto coefficient = [&](int i) {
- return b->Select(
- lt,
- b->Broadcast(b->ConstantR0<float>(w_less_than_5_constants[i]),
- shape.dim_sizes()),
- b->Broadcast(b->ConstantR0<float>(w_greater_than_5_constants[i]),
- shape.dim_sizes()));
- };
- w = b->Select(lt, b->Sub(w, b->ConstantR0<float>(2.5f)),
- b->Sub(b->SqrtF32(w), b->ConstantR0<float>(3.0f)));
- auto p = coefficient(0);
- for (int i = 1; i < kDegree; ++i) {
- p = b->Add(coefficient(i), b->Mul(p, w));
- }
- return b->Mul(p, x);
-}
-
} // namespace
class StatelessRandomUniformOp : public XlaOpKernel {
@@ -259,8 +214,10 @@ class StatelessRandomNormalOp : public XlaOpKernel {
RandomUniform(builder, seed, shape, std::nextafter(-1.0f, 0.0f), 1.0);
// Convert uniform distribution to normal distribution by computing
// sqrt(2) * erfinv(x)
+ auto erfinv_or_status = ErfInv(builder, uniform);
+ OP_REQUIRES_OK(ctx, erfinv_or_status.status());
auto normal = builder->Mul(builder->ConstantR0<float>(std::sqrt(2.0)),
- ErfInvF32(builder, uniform, shape));
+ erfinv_or_status.ValueOrDie());
ctx->SetOutput(0, normal);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
index 2521445e86..1d078de211 100644
--- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
@@ -202,9 +202,9 @@ class ErfOp : public XlaOpKernel {
OP_REQUIRES_OK(ctx,
DataTypeToPrimitiveType(input_type(0), &primitive_type));
- auto y = b->Select(b->Gt(abs_x, one),
- b->Sub(one, ComputeErfc(b, x, primitive_type)),
- ComputeErf(b, x, primitive_type));
+ auto y =
+ b->Select(b->Gt(abs_x, one), b->Sub(one, Erfc(b, x, primitive_type)),
+ Erf(b, x, primitive_type));
ctx->SetOutput(0, y);
}
};
@@ -223,9 +223,9 @@ class ErfcOp : public XlaOpKernel {
OP_REQUIRES_OK(ctx,
DataTypeToPrimitiveType(input_type(0), &primitive_type));
- auto y = b->Select(b->Lt(abs_x, one),
- b->Sub(one, ComputeErf(b, x, primitive_type)),
- ComputeErfc(b, x, primitive_type));
+ auto y =
+ b->Select(b->Lt(abs_x, one), b->Sub(one, Erf(b, x, primitive_type)),
+ Erfc(b, x, primitive_type));
ctx->SetOutput(0, y);
}
};
diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc
index 639f85737f..f095ec9213 100644
--- a/tensorflow/compiler/xla/client/lib/arithmetic.cc
+++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc
@@ -176,8 +176,8 @@ xla::XlaOp EvaluatePolynomial(xla::XlaBuilder* b, const xla::XlaOp& x,
}
// Compute an approximation of the error function complement (1 - erf(x)).
-xla::XlaOp ComputeErfc(xla::XlaBuilder* b, const xla::XlaOp& x,
- PrimitiveType data_type) {
+xla::XlaOp Erfc(xla::XlaBuilder* b, const xla::XlaOp& x,
+ PrimitiveType data_type) {
xla::XlaOp zero = FloatLiteral(b, data_type, 0.0);
xla::XlaOp two = FloatLiteral(b, data_type, 2.0);
xla::XlaOp eight = FloatLiteral(b, data_type, 8.0);
@@ -197,12 +197,57 @@ xla::XlaOp ComputeErfc(xla::XlaBuilder* b, const xla::XlaOp& x,
}
// Compute a polynomial approximation of the error function.
-xla::XlaOp ComputeErf(xla::XlaBuilder* b, const xla::XlaOp& x,
- PrimitiveType data_type) {
+xla::XlaOp Erf(xla::XlaBuilder* b, const xla::XlaOp& x,
+ PrimitiveType data_type) {
xla::XlaOp z = b->Mul(x, x);
xla::XlaOp pt = EvaluatePolynomial(b, z, kErfTCoefficient, data_type);
xla::XlaOp pu = EvaluatePolynomial(b, z, kErfUCoefficient, data_type);
return b->Div(b->Mul(x, pt), pu);
}
+// Builds an approximation of the inverse error function of x, following
+// Giles, M., "Approximating the erfinv function".
+// The approximation has the form (lq / gq are the w < 5 / w >= 5 tables below):
+// w = -log((1 - x) * (1 + x))
+// if ( w < 5 ) {   (realized with Select on the coefficients, not control flow)
+// w = w - 2.5
+// p = lq[0]*w^8 + lq[1]*w^7 + ... + lq[8]   (Horner form, highest power first)
+// } else {
+// w = sqrt(w) - 3
+// p = gq[0]*w^8 + gq[1]*w^7 + ... + gq[8]   (Horner form, highest power first)
+// }
+// return p*x   (all constants are built as float; a GetShape failure is returned)
+StatusOr<XlaOp> ErfInv(xla::XlaBuilder* b, const xla::XlaOp& x) {
+ TF_ASSIGN_OR_RETURN(Shape shape, b->GetShape(x));
+ constexpr int kDegree = 9;
+ constexpr std::array<float, 9> w_less_than_5_constants = {
+ 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f,
+ -4.39150654e-06f, 0.00021858087f, -0.00125372503f,
+ -0.00417768164f, 0.246640727f, 1.50140941f};
+ constexpr std::array<float, 9> w_greater_than_5_constants = {
+ -0.000200214257f, 0.000100950558f, 0.00134934322f,
+ -0.00367342844f, 0.00573950773f, -0.0076224613f,
+ 0.00943887047f, 1.00167406f, 2.83297682f};
+
+ auto one = b->ConstantR0<float>(1.0);
+ auto w = b->Neg(b->Log(b->Mul(b->Sub(one, x), b->Add(one, x))));
+
+ auto lt = b->Lt(w, b->ConstantR0<float>(5.0));
+ auto coefficient = [&](int i) {
+ return b->Select(
+ lt,
+ b->Broadcast(b->ConstantR0<float>(w_less_than_5_constants[i]),
+ AsInt64Slice(shape.dimensions())),
+ b->Broadcast(b->ConstantR0<float>(w_greater_than_5_constants[i]),
+ AsInt64Slice(shape.dimensions())));
+ };
+ w = b->Select(lt, b->Sub(w, b->ConstantR0<float>(2.5f)),
+ b->Sub(b->SqrtF32(w), b->ConstantR0<float>(3.0f)));
+ auto p = coefficient(0);
+ for (int i = 1; i < kDegree; ++i) {
+ p = b->Add(coefficient(i), b->Mul(p, w));
+ }
+ return b->Mul(p, x);
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.h b/tensorflow/compiler/xla/client/lib/arithmetic.h
index f11cc00317..efdcc7e198 100644
--- a/tensorflow/compiler/xla/client/lib/arithmetic.h
+++ b/tensorflow/compiler/xla/client/lib/arithmetic.h
@@ -62,12 +62,15 @@ xla::XlaOp EvaluatePolynomial(xla::XlaBuilder* b, const xla::XlaOp& x,
PrimitiveType data_type);
// Compute an approximation of the error function complement (1 - erf(x)).
-xla::XlaOp ComputeErfc(xla::XlaBuilder* b, const xla::XlaOp& x,
- PrimitiveType data_type);
+xla::XlaOp Erfc(xla::XlaBuilder* b, const xla::XlaOp& x,
+ PrimitiveType data_type);
// Compute an approximation of the error function.
-xla::XlaOp ComputeErf(xla::XlaBuilder* b, const xla::XlaOp& x,
- PrimitiveType data_type);
+xla::XlaOp Erf(xla::XlaBuilder* b, const xla::XlaOp& x,
+ PrimitiveType data_type);
+
+// Compute an approximation of the inverse of the error function.
+StatusOr<XlaOp> ErfInv(xla::XlaBuilder* b, const xla::XlaOp& x);
} // namespace xla