diff options
author | Peter Hawkins <phawkins@google.com> | 2018-07-13 18:18:12 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-13 18:21:51 -0700 |
commit | 4424e3270e4056ef7318fbdd83727cb93bec6858 (patch) | |
tree | ca9171110968f2eb01f59491c8c74d4c46495267 | |
parent | d722c3e93fa180e4dad7678cf32868ed18f6ef84 (diff) |
[XLA] Move implementation of ThreeFry stateless PRNG into xla/client/lib
PiperOrigin-RevId: 204557470
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc | 168 | ||||
-rw-r--r-- | tensorflow/compiler/xla/client/lib/BUILD | 15 | ||||
-rw-r--r-- | tensorflow/compiler/xla/client/lib/prng.cc | 150 | ||||
-rw-r--r-- | tensorflow/compiler/xla/client/lib/prng.h | 34 |
5 files changed, 234 insertions, 134 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 5a335aa43c..d88a34dfd9 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -127,6 +127,7 @@ tf_kernel_library( "//tensorflow/compiler/xla/client/lib:constants", "//tensorflow/compiler/xla/client/lib:math", "//tensorflow/compiler/xla/client/lib:numeric", + "//tensorflow/compiler/xla/client/lib:prng", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/core:framework", "//tensorflow/core:image_ops_op_lib", diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc index a6f5769e7b..cc4b13d3b9 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/lib/prng.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" @@ -33,134 +34,6 @@ limitations under the License. namespace tensorflow { namespace { -// Rotates a 32-bit integer 'v' left by 'distance' bits. -xla::XlaOp RotateLeftS32(xla::XlaBuilder* builder, const xla::XlaOp& v, - int distance) { - return xla::Or( - xla::ShiftLeft(v, xla::ConstantR0<int>(builder, distance)), - xla::ShiftRightLogical(v, xla::ConstantR0<int>(builder, 32 - distance))); -} - -using ThreeFry2x32State = std::array<xla::XlaOp, 2>; - -// Implements the ThreeFry counter-based PRNG algorithm. -// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3. -// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf -ThreeFry2x32State ThreeFry2x32(xla::XlaBuilder* builder, - ThreeFry2x32State input, ThreeFry2x32State key) { - // Rotation distances specified by the Threefry2x32 algorithm. - constexpr std::array<int, 8> rotations = {13, 15, 26, 6, 17, 29, 16, 24}; - ThreeFry2x32State x; - - std::array<xla::XlaOp, 3> ks; - // 0x1BD11BDA is a parity constant specified by the ThreeFry2x32 algorithm. - ks[2] = xla::ConstantR0<int32>(builder, 0x1BD11BDA); - for (int i = 0; i < 2; ++i) { - ks[i] = key[i]; - x[i] = input[i]; - ks[2] = xla::Xor(ks[2], key[i]); - } - - x[0] = xla::Add(x[0], ks[0]); - x[1] = xla::Add(x[1], ks[1]); - - // Performs a single round of the Threefry2x32 algorithm, with a rotation - // amount 'rotation'. - auto round = [builder](ThreeFry2x32State v, int rotation) { - v[0] = xla::Add(v[0], v[1]); - v[1] = RotateLeftS32(builder, v[1], rotation); - v[1] = xla::Xor(v[0], v[1]); - return v; - }; - - // There are no known statistical flaws with 13 rounds of Threefry2x32. - // We are conservative and use 20 rounds. - x = round(x, rotations[0]); - x = round(x, rotations[1]); - x = round(x, rotations[2]); - x = round(x, rotations[3]); - x[0] = xla::Add(x[0], ks[1]); - x[1] = xla::Add(xla::Add(x[1], ks[2]), xla::ConstantR0<int32>(builder, 1)); - - x = round(x, rotations[4]); - x = round(x, rotations[5]); - x = round(x, rotations[6]); - x = round(x, rotations[7]); - x[0] = xla::Add(x[0], ks[2]); - x[1] = xla::Add(xla::Add(x[1], ks[0]), xla::ConstantR0<int32>(builder, 2)); - - x = round(x, rotations[0]); - x = round(x, rotations[1]); - x = round(x, rotations[2]); - x = round(x, rotations[3]); - x[0] = xla::Add(x[0], ks[0]); - x[1] = xla::Add(xla::Add(x[1], ks[1]), xla::ConstantR0<int32>(builder, 3)); - - x = round(x, rotations[4]); - x = round(x, rotations[5]); - x = round(x, rotations[6]); - x = round(x, rotations[7]); - x[0] = xla::Add(x[0], ks[1]); - x[1] = xla::Add(xla::Add(x[1], ks[2]), xla::ConstantR0<int32>(builder, 4)); - - x = round(x, rotations[0]); - x = round(x, rotations[1]); - x = round(x, rotations[2]); - x = round(x, rotations[3]); - x[0] = xla::Add(x[0], ks[2]); - x[1] = xla::Add(xla::Add(x[1], ks[0]), xla::ConstantR0<int32>(builder, 5)); - - return x; -} - -// Returns a tensor of 'shape' random values uniformly distributed in the range -// [minval, maxval) -xla::XlaOp RandomUniform(xla::XlaBuilder* builder, const xla::XlaOp& seed, - const TensorShape& shape, double minval, - double maxval) { - // Split the seed into two 32-bit scalars to form a key. - auto seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {}); - auto seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {}); - ThreeFry2x32State key = {seed0, seed1}; - const int64 size = shape.num_elements(); - - const int64 half_size = MathUtil::CeilOfRatio<int64>(size, 2); - const bool size_is_odd = (half_size * 2 != size); - - // Fill the generator inputs with unique counter values. - ThreeFry2x32State inputs; - inputs[0] = xla::Iota(builder, xla::S32, half_size); - inputs[1] = xla::Add(inputs[0], xla::ConstantR0<int32>(builder, half_size)); - ThreeFry2x32State outputs = ThreeFry2x32(builder, inputs, key); - - if (size_is_odd) { - outputs[1] = xla::Slice(outputs[1], {0}, {half_size - 1}, {1}); - } - - auto bits = - xla::Reshape(xla::ConcatInDim(builder, outputs, 0), shape.dim_sizes()); - - // Form 22 random mantissa bits, with a leading 1 bit. The leading 1 bit - // forces the random bits into the mantissa. - constexpr int kFloatBits = 32; - constexpr int kMantissaBits = 23; - bits = xla::Or( - xla::ShiftRightLogical( - bits, xla::ConstantR0<int32>(builder, kFloatBits - kMantissaBits)), - xla::ConstantR0<int32>(builder, bit_cast<int32>(1.0f))); - auto floats = xla::BitcastConvertType(bits, xla::F32); - - // We have a floating point number in the range [1.0, 2.0). - // Subtract 1.0f to shift to the range [0.0, 1.0) - floats = xla::Sub(floats, xla::ConstantR0<float>(builder, 1.0f)); - // Multiply and add to shift to the range [minval, maxval). - floats = xla::Mul(floats, xla::ConstantR0<float>(builder, maxval - minval)); - floats = xla::Add(floats, xla::ConstantR0<float>(builder, minval)); - return floats; -} - -} // namespace - class StatelessRandomUniformOp : public XlaOpKernel { public: explicit StatelessRandomUniformOp(OpKernelConstruction* ctx) @@ -177,7 +50,17 @@ class StatelessRandomUniformOp : public XlaOpKernel { errors::InvalidArgument("seed must have shape [2], not ", seed_shape.DebugString())); xla::XlaOp seed = ctx->Input(1); - ctx->SetOutput(0, RandomUniform(builder, seed, shape, 0.0, 1.0)); + + xla::Shape xla_shape; + OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape)); + + auto seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {}); + auto seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {}); + + auto uniform = xla::StatelessRngUniform( + {seed0, seed1}, xla_shape, xla::ConstantR0<float>(builder, 0.0), + xla::ConstantR0<float>(builder, 1.0)); + ctx->SetOutput(0, uniform); } private: @@ -206,8 +89,16 @@ class StatelessRandomNormalOp : public XlaOpKernel { seed_shape.DebugString())); xla::XlaOp seed = ctx->Input(1); xla::XlaBuilder* builder = ctx->builder(); - auto uniform = - RandomUniform(builder, seed, shape, std::nextafter(-1.0f, 0.0f), 1.0); + xla::Shape xla_shape; + OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape)); + + auto seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {}); + auto seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {}); + + auto uniform = xla::StatelessRngUniform( + {seed0, seed1}, xla_shape, + xla::ConstantR0<float>(builder, std::nextafter(-1.0f, 0.0f)), + xla::ConstantR0<float>(builder, 1.0)); // Convert uniform distribution to normal distribution by computing // sqrt(2) * erfinv(x) auto normal = @@ -240,10 +131,18 @@ class StatelessTruncatedNormalOp : public XlaOpKernel { errors::InvalidArgument("seed must have shape [2], not ", seed_shape.DebugString())); xla::XlaOp seed = ctx->Input(1); - xla::XlaBuilder* b = ctx->builder(); + xla::XlaBuilder* builder = ctx->builder(); + + auto seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {}); + auto seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {}); + + xla::Shape xla_shape; + OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape)); + auto uniform = xla::StatelessRngUniform( + {seed0, seed1}, xla_shape, + xla::ConstantR0<float>(builder, std::numeric_limits<float>::min()), + xla::ConstantR0<float>(builder, 1.0)); - auto uniform = - RandomUniform(b, seed, shape, std::numeric_limits<float>::min(), 1.0); ctx->SetOutput(0, TruncatedNormal(uniform)); } @@ -257,4 +156,5 @@ REGISTER_XLA_OP(Name("StatelessTruncatedNormal") .TypeConstraint("Tseed", DT_INT32), StatelessTruncatedNormalOp); +} // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index 6933e9a838..ece5a885b5 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -119,6 +119,21 @@ xla_test( ) cc_library( + name = "prng", + srcs = ["prng.cc"], + hdrs = ["prng.h"], + deps = [ + ":constants", + ":math", + ":numeric", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/core:lib", + ], +) + +cc_library( name = "testing", srcs = ["testing.cc"], hdrs = ["testing.h"], diff --git a/tensorflow/compiler/xla/client/lib/prng.cc b/tensorflow/compiler/xla/client/lib/prng.cc new file mode 100644 index 0000000000..299a6ac2b6 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/prng.cc @@ -0,0 +1,150 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <cmath> + +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/casts.h" + +namespace xla { +namespace { + +// Rotates a 32-bit integer 'v' left by 'distance' bits. +XlaOp RotateLeftS32(XlaOp v, int distance) { + return (v << ConstantR0<int32>(v.builder(), distance)) | + ShiftRightLogical(v, ConstantR0<int32>(v.builder(), 32 - distance)); +} + +using ThreeFry2x32State = std::array<XlaOp, 2>; + +// Implements the ThreeFry counter-based PRNG algorithm. +// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3. +// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf +ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key) { + XlaBuilder* builder = input[0].builder(); + // Rotation distances specified by the Threefry2x32 algorithm. + constexpr std::array<int, 8> rotations = {13, 15, 26, 6, 17, 29, 16, 24}; + ThreeFry2x32State x; + + std::array<XlaOp, 3> ks; + // 0x1BD11BDA is a parity constant specified by the ThreeFry2x32 algorithm. + ks[2] = ConstantR0<int32>(builder, 0x1BD11BDA); + for (int i = 0; i < 2; ++i) { + ks[i] = key[i]; + x[i] = input[i]; + ks[2] = ks[2] ^ key[i]; + } + + x[0] = x[0] + ks[0]; + x[1] = x[1] + ks[1]; + + // Performs a single round of the Threefry2x32 algorithm, with a rotation + // amount 'rotation'. + auto round = [builder](ThreeFry2x32State v, int rotation) { + v[0] = v[0] + v[1]; + v[1] = RotateLeftS32(v[1], rotation); + v[1] = v[0] ^ v[1]; + return v; + }; + + // There are no known statistical flaws with 13 rounds of Threefry2x32. + // We are conservative and use 20 rounds. + x = round(x, rotations[0]); + x = round(x, rotations[1]); + x = round(x, rotations[2]); + x = round(x, rotations[3]); + x[0] = x[0] + ks[1]; + x[1] = x[1] + ks[2] + ConstantR0<int32>(builder, 1); + + x = round(x, rotations[4]); + x = round(x, rotations[5]); + x = round(x, rotations[6]); + x = round(x, rotations[7]); + x[0] = x[0] + ks[2]; + x[1] = x[1] + ks[0] + ConstantR0<int32>(builder, 2); + + x = round(x, rotations[0]); + x = round(x, rotations[1]); + x = round(x, rotations[2]); + x = round(x, rotations[3]); + x[0] = x[0] + ks[0]; + x[1] = x[1] + ks[1] + ConstantR0<int32>(builder, 3); + + x = round(x, rotations[4]); + x = round(x, rotations[5]); + x = round(x, rotations[6]); + x = round(x, rotations[7]); + x[0] = x[0] + ks[1]; + x[1] = x[1] + ks[2] + ConstantR0<int32>(builder, 4); + + x = round(x, rotations[0]); + x = round(x, rotations[1]); + x = round(x, rotations[2]); + x = round(x, rotations[3]); + x[0] = x[0] + ks[2]; + x[1] = x[1] + ks[0] + ConstantR0<int32>(builder, 5); + + return x; +} + +} // namespace + +XlaOp StatelessRngUniform(std::array<XlaOp, 2> seeds, const Shape& shape, + XlaOp minval, XlaOp maxval) { + XlaBuilder* builder = seeds[0].builder(); + if (shape.element_type() != F32) { + return builder->ReportError(Unimplemented( + "Types other than F32 are not implemented by StatelessRngUniform.")); + } + ThreeFry2x32State key = seeds; + const int64 size = ShapeUtil::ElementsIn(shape); + + const int64 half_size = CeilOfRatio<int64>(size, 2); + const bool size_is_odd = (half_size * 2 != size); + + // Fill the generator inputs with unique counter values. + ThreeFry2x32State inputs; + inputs[0] = Iota(builder, S32, half_size); + inputs[1] = inputs[0] + ConstantR0<int32>(builder, half_size); + ThreeFry2x32State outputs = ThreeFry2x32(inputs, key); + + if (size_is_odd) { + outputs[1] = Slice(outputs[1], {0}, {half_size - 1}, {1}); + } + + auto bits = Reshape(ConcatInDim(builder, outputs, 0), + AsInt64Slice(shape.dimensions())); + + // Form 23 random mantissa bits, with a leading 1 bit. The leading 1 bit + // forces the random bits into the mantissa. + constexpr int kFloatBits = 32; + constexpr int kMantissaBits = 23; + bits = ShiftRightLogical( + bits, ConstantR0<int32>(builder, kFloatBits - kMantissaBits)) | + ConstantR0<int32>(builder, tensorflow::bit_cast<int32>(1.0f)); + auto floats = BitcastConvertType(bits, F32); + + // We have a floating point number in the range [1.0, 2.0). + // Subtract 1.0f to shift to the range [0.0, 1.0) + floats = floats - ConstantR0<float>(builder, 1.0f); + // Multiply and add to shift to the range [minval, maxval). + return floats * (maxval - minval) + minval; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/prng.h b/tensorflow/compiler/xla/client/lib/prng.h new file mode 100644 index 0000000000..ac86390239 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/prng.h @@ -0,0 +1,34 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_PRNG_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_PRNG_H_ + +#include <array> + +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +// Returns a tensor containing 'shape' random values uniformly distributed in +// the range [minval, maxval). Requires 2 32-bit integer seeds. +// Currently only 'shape's of type F32 are implemented. +XlaOp StatelessRngUniform(std::array<XlaOp, 2> seeds, const Shape& shape, + XlaOp minval, XlaOp maxval); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_PRNG_H_ |