diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/random_ops.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/random_ops.cc | 163 |
1 file changed, 114 insertions, 49 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index 51f2cdc9f4..607cad798a 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" @@ -73,57 +74,121 @@ class RandomShuffleOp : public XlaOpKernel { for (tensorflow::TensorShapeDim dimension : input_shape) { num_elements *= dimension.size; } + if (num_elements <= 1 || n <= 1) { // No shuffling is required, so copy input directly to output ctx->SetOutput(0, input); - } else { - // Generate the random swaps for the indices. - auto swaps_shape = xla::ShapeUtil::MakeShape(xla::S32, {n}); - auto swaps = - xla::RngUniform(xla::ConstantR0<int32>(builder, 0), - xla::ConstantR0<int32>(builder, n), swaps_shape); - - // Generate range(n) as the initial value for the indices to be swapped. - xla::XlaOp indices; - TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, n, &indices)); - - // Swap the indices at i and swaps[i]. 
- auto swap_body_fn = [&](xla::XlaOp i, - gtl::ArraySlice<xla::XlaOp> loop_vars, - xla::XlaBuilder* builder) - -> xla::StatusOr<std::vector<xla::XlaOp>> { - auto swaps = loop_vars[0]; - auto indices = loop_vars[1]; - i = xla::Reshape(i, {1}); - // temp = indices[i] - auto temp = xla::DynamicSlice(indices, i, {1}); - // swap_index = swaps[i] - auto swap_index = xla::DynamicSlice(swaps, i, {1}); - // swap_value = indices[swaps[i]] - auto swap_value = xla::DynamicSlice(indices, swap_index, {1}); - // indices[i] = indices[swaps[i]] - indices = xla::DynamicUpdateSlice(indices, swap_value, i); - // indices[swaps[i]] = temp - indices = xla::DynamicUpdateSlice(indices, temp, swap_index); - return std::vector<xla::XlaOp>{swaps, indices}; - }; - // for i in range(n): - auto swap_loop_result = - XlaForEachIndex(n, xla::S32, swap_body_fn, {swaps, indices}, - "indices_swap_loop", builder) - .ValueOrDie(); - auto swapped_indices = swap_loop_result[1]; - - // Gather the data using the swapped indices as the shuffled order. - auto indices_tensor_shape = TensorShape({n}); - DataType type = ctx->expected_output_dtype(0); - xla::XlaOp gather; - OP_REQUIRES_OK(ctx, XlaGather(input, input_shape, swapped_indices, - indices_tensor_shape, - /*axis=*/0, /*indices_are_nd=*/false, type, - DT_INT32, builder, &gather)); - ctx->SetOutput(0, gather); + return; + } + + if (input_shape.dims() == 1) { + // For R1s, shuffle values by sorting instead of the obvious Fisher-Yates + // algorithm. Fisher-Yates is simple to implement and correct, but not + // easily parallelizable. For a sufficiently parallel architecture, it is + // faster to sort many times, than Fisher-Yates shuffle once. + + // Shuffle values by assigning each value a random key and sorting the + // keys. Keys can collide causing detectable patterns in the shuffled + // output. Collisions translates into more ascending sub-sequences in the + // shuffled output than would be expected by chance. 
To avoid collisions, + // the number of possible key values must be sufficiently large. + + // How are more than 2^32 keys created? In each loop iteration, the + // algorithm sorts by random keys. Conceptually, the earlier iterations + // are sorting on the lower-order bits of larger keys that are never + // actually assembled. + + // The expected number of collisions is n - d + d(1 - 1/d)^n, where d is + // the number of possible keys and n is the number of values. If d = n^2, + // then the limit as n goes to infinity is 1/2. If d = n^3, then the limit + // as n goes to infinity is zero. + + // This implementation ensures that the key-space is greater than or equal + // to the cube of the number of values. The risk of collisions can be + // further reduced by increasing Exponent at the expense of + // performance. + + // For Exponent = 2, the expected number of collisions per shuffle is + // maximized at n = floor((2^32-1)^(1/2)) = 65535 where the expectation is + // about 1/2. + + // For Exponent = 3, the expected number of collisions per shuffle is + // maximized at n = floor((2^32-1)^(1/3)) = 1625 where the expectation is + // about 1/3255. + + // For Exponent = 4, the expected number of collisions per shuffle is + // maximized at n = floor((2^32-1)^(1/4)) = 255 where the expectation is + // about 1/132622. + constexpr int Exponent = 3; + const int rounds = static_cast<int>( + std::ceil(Exponent * std::log(num_elements) / std::log(kuint32max))); + + const xla::Shape key_shape = + xla::ShapeUtil::MakeShape(xla::U32, {num_elements}); + xla::XlaOp zero = xla::ConstantR0(builder, 0U); + + // Unfortunately, xla::RngUniform gives values in the half open interval + // rather than the closed interval, so instead of 2^32 possible keys there + // are only 2^32 - 1 (kuint32max). 
+ xla::XlaOp max_value = xla::ConstantR0(builder, kuint32max); + + xla::XlaOp curr = input; + for (int i = 0; i < rounds; ++i) { + xla::XlaOp keys = xla::RngUniform(zero, max_value, key_shape); + xla::XlaOp sorted = xla::Sort(keys, curr); + curr = xla::GetTupleElement(sorted, 1); + } + + ctx->SetOutput(0, curr); + return; } + + // The Fisher-Yates algorithm. + + // Generate the random swaps for the indices. + auto swaps_shape = xla::ShapeUtil::MakeShape(xla::S32, {n}); + auto swaps = + xla::RngUniform(xla::ConstantR0<int32>(builder, 0), + xla::ConstantR0<int32>(builder, n), swaps_shape); + + // Generate range(n) as the initial value for the indices to be swapped. + xla::XlaOp indices = xla::Iota(builder, xla::S32, n); + + // Swap the indices at i and swaps[i]. + auto swap_body_fn = [&](xla::XlaOp i, gtl::ArraySlice<xla::XlaOp> loop_vars, + xla::XlaBuilder* builder) + -> xla::StatusOr<std::vector<xla::XlaOp>> { + auto swaps = loop_vars[0]; + auto indices = loop_vars[1]; + i = xla::Reshape(i, {1}); + // temp = indices[i] + auto temp = xla::DynamicSlice(indices, i, {1}); + // swap_index = swaps[i] + auto swap_index = xla::DynamicSlice(swaps, i, {1}); + // swap_value = indices[swaps[i]] + auto swap_value = xla::DynamicSlice(indices, swap_index, {1}); + // indices[i] = indices[swaps[i]] + indices = xla::DynamicUpdateSlice(indices, swap_value, i); + // indices[swaps[i]] = temp + indices = xla::DynamicUpdateSlice(indices, temp, swap_index); + return std::vector<xla::XlaOp>{swaps, indices}; + }; + // for i in range(n): + auto swap_loop_result = + XlaForEachIndex(n, xla::S32, swap_body_fn, {swaps, indices}, + "indices_swap_loop", builder) + .ValueOrDie(); + auto swapped_indices = swap_loop_result[1]; + + // Gather the data using the swapped indices as the shuffled order. 
+ auto indices_tensor_shape = TensorShape({n}); + DataType type = ctx->expected_output_dtype(0); + xla::XlaOp gather; + OP_REQUIRES_OK(ctx, XlaGather(input, input_shape, swapped_indices, + indices_tensor_shape, + /*axis=*/0, /*indices_are_nd=*/false, type, + DT_INT32, builder, &gather)); + ctx->SetOutput(0, gather); } private: @@ -211,7 +276,7 @@ class TruncatedNormalOp : public XlaOpKernel { xla::XlaOp min_positive = XlaHelpers::FloatLiteral(b, dtype, std::numeric_limits<float>::min()); auto uniform = xla::RngUniform(min_positive, one, xla_shape); - ctx->SetOutput(0, TruncatedNormal(dtype, uniform)); + ctx->SetOutput(0, TruncatedNormal(uniform)); } }; @@ -220,5 +285,5 @@ REGISTER_XLA_OP(Name("TruncatedNormal") .TypeConstraint("dtype", DT_FLOAT), TruncatedNormalOp); -} // anonymous namespace +} // namespace } // namespace tensorflow |