aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/kernels/random_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/random_ops.cc')
-rw-r--r--tensorflow/compiler/tf2xla/kernels/random_ops.cc163
1 files changed, 114 insertions, 49 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc
index 51f2cdc9f4..607cad798a 100644
--- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/lib/numeric.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
@@ -73,57 +74,121 @@ class RandomShuffleOp : public XlaOpKernel {
for (tensorflow::TensorShapeDim dimension : input_shape) {
num_elements *= dimension.size;
}
+
if (num_elements <= 1 || n <= 1) {
// No shuffling is required, so copy input directly to output
ctx->SetOutput(0, input);
- } else {
- // Generate the random swaps for the indices.
- auto swaps_shape = xla::ShapeUtil::MakeShape(xla::S32, {n});
- auto swaps =
- xla::RngUniform(xla::ConstantR0<int32>(builder, 0),
- xla::ConstantR0<int32>(builder, n), swaps_shape);
-
- // Generate range(n) as the initial value for the indices to be swapped.
- xla::XlaOp indices;
- TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, n, &indices));
-
- // Swap the indices at i and swaps[i].
- auto swap_body_fn = [&](xla::XlaOp i,
- gtl::ArraySlice<xla::XlaOp> loop_vars,
- xla::XlaBuilder* builder)
- -> xla::StatusOr<std::vector<xla::XlaOp>> {
- auto swaps = loop_vars[0];
- auto indices = loop_vars[1];
- i = xla::Reshape(i, {1});
- // temp = indices[i]
- auto temp = xla::DynamicSlice(indices, i, {1});
- // swap_index = swaps[i]
- auto swap_index = xla::DynamicSlice(swaps, i, {1});
- // swap_value = indices[swaps[i]]
- auto swap_value = xla::DynamicSlice(indices, swap_index, {1});
- // indices[i] = indices[swaps[i]]
- indices = xla::DynamicUpdateSlice(indices, swap_value, i);
- // indices[swaps[i]] = temp
- indices = xla::DynamicUpdateSlice(indices, temp, swap_index);
- return std::vector<xla::XlaOp>{swaps, indices};
- };
- // for i in range(n):
- auto swap_loop_result =
- XlaForEachIndex(n, xla::S32, swap_body_fn, {swaps, indices},
- "indices_swap_loop", builder)
- .ValueOrDie();
- auto swapped_indices = swap_loop_result[1];
-
- // Gather the data using the swapped indices as the shuffled order.
- auto indices_tensor_shape = TensorShape({n});
- DataType type = ctx->expected_output_dtype(0);
- xla::XlaOp gather;
- OP_REQUIRES_OK(ctx, XlaGather(input, input_shape, swapped_indices,
- indices_tensor_shape,
- /*axis=*/0, /*indices_are_nd=*/false, type,
- DT_INT32, builder, &gather));
- ctx->SetOutput(0, gather);
+ return;
+ }
+
+ if (input_shape.dims() == 1) {
+ // For R1s, shuffle values by sorting instead of the obvious Fisher-Yates
+ // algorithm. Fisher-Yates is simple to implement and correct, but not
+ // easily parallelizable. For a sufficiently parallel architecture, it is
+    // faster to sort many times than to Fisher-Yates shuffle once.
+
+ // Shuffle values by assigning each value a random key and sorting the
+ // keys. Keys can collide causing detectable patterns in the shuffled
+    // output. Collisions translate into more ascending sub-sequences in the
+ // shuffled output than would be expected by chance. To avoid collisions,
+ // the number of possible key values must be sufficiently large.
+
+ // How are more than 2^32 keys created? In each loop iteration, the
+ // algorithm sorts by random keys. Conceptually, the earlier iterations
+ // are sorting on the lower-order bits of larger keys that are never
+ // actually assembled.
+
+ // The expected number of collisions is n - d + d(1 - 1/d)^n, where d is
+ // the number of possible keys and n is the number of values. If d = n^2,
+ // then the limit as n goes to infinity is 1/2. If d = n^3, then the limit
+ // as n goes to infinity is zero.
+
+ // This implementation ensures that the key-space is greater than or equal
+ // to the cube of the number of values. The risk of collisions can be
+ // further reduced by increasing Exponent at the expense of
+ // performance.
+
+ // For Exponent = 2, the expected number of collisions per shuffle is
+ // maximized at n = floor((2^32-1)^(1/2)) = 65535 where the expectation is
+ // about 1/2.
+
+ // For Exponent = 3, the expected number of collisions per shuffle is
+ // maximized at n = floor((2^32-1)^(1/3)) = 1625 where the expectation is
+ // about 1/3255.
+
+ // For Exponent = 4, the expected number of collisions per shuffle is
+ // maximized at n = floor((2^32-1)^(1/4)) = 255 where the expectation is
+ // about 1/132622.
+ constexpr int Exponent = 3;
+ const int rounds = static_cast<int>(
+ std::ceil(Exponent * std::log(num_elements) / std::log(kuint32max)));
+
+ const xla::Shape key_shape =
+ xla::ShapeUtil::MakeShape(xla::U32, {num_elements});
+ xla::XlaOp zero = xla::ConstantR0(builder, 0U);
+
+ // Unfortunately, xla::RngUniform gives values in the half open interval
+ // rather than the closed interval, so instead of 2^32 possible keys there
+ // are only 2^32 - 1 (kuint32max).
+ xla::XlaOp max_value = xla::ConstantR0(builder, kuint32max);
+
+ xla::XlaOp curr = input;
+ for (int i = 0; i < rounds; ++i) {
+ xla::XlaOp keys = xla::RngUniform(zero, max_value, key_shape);
+ xla::XlaOp sorted = xla::Sort(keys, curr);
+ curr = xla::GetTupleElement(sorted, 1);
+ }
+
+ ctx->SetOutput(0, curr);
+ return;
}
+
+ // The Fisher-Yates algorithm.
+
+ // Generate the random swaps for the indices.
+ auto swaps_shape = xla::ShapeUtil::MakeShape(xla::S32, {n});
+ auto swaps =
+ xla::RngUniform(xla::ConstantR0<int32>(builder, 0),
+ xla::ConstantR0<int32>(builder, n), swaps_shape);
+
+ // Generate range(n) as the initial value for the indices to be swapped.
+ xla::XlaOp indices = xla::Iota(builder, xla::S32, n);
+
+ // Swap the indices at i and swaps[i].
+ auto swap_body_fn = [&](xla::XlaOp i, gtl::ArraySlice<xla::XlaOp> loop_vars,
+ xla::XlaBuilder* builder)
+ -> xla::StatusOr<std::vector<xla::XlaOp>> {
+ auto swaps = loop_vars[0];
+ auto indices = loop_vars[1];
+ i = xla::Reshape(i, {1});
+ // temp = indices[i]
+ auto temp = xla::DynamicSlice(indices, i, {1});
+ // swap_index = swaps[i]
+ auto swap_index = xla::DynamicSlice(swaps, i, {1});
+ // swap_value = indices[swaps[i]]
+ auto swap_value = xla::DynamicSlice(indices, swap_index, {1});
+ // indices[i] = indices[swaps[i]]
+ indices = xla::DynamicUpdateSlice(indices, swap_value, i);
+ // indices[swaps[i]] = temp
+ indices = xla::DynamicUpdateSlice(indices, temp, swap_index);
+ return std::vector<xla::XlaOp>{swaps, indices};
+ };
+ // for i in range(n):
+ auto swap_loop_result =
+ XlaForEachIndex(n, xla::S32, swap_body_fn, {swaps, indices},
+ "indices_swap_loop", builder)
+ .ValueOrDie();
+ auto swapped_indices = swap_loop_result[1];
+
+ // Gather the data using the swapped indices as the shuffled order.
+ auto indices_tensor_shape = TensorShape({n});
+ DataType type = ctx->expected_output_dtype(0);
+ xla::XlaOp gather;
+ OP_REQUIRES_OK(ctx, XlaGather(input, input_shape, swapped_indices,
+ indices_tensor_shape,
+ /*axis=*/0, /*indices_are_nd=*/false, type,
+ DT_INT32, builder, &gather));
+ ctx->SetOutput(0, gather);
}
private:
@@ -211,7 +276,7 @@ class TruncatedNormalOp : public XlaOpKernel {
xla::XlaOp min_positive =
XlaHelpers::FloatLiteral(b, dtype, std::numeric_limits<float>::min());
auto uniform = xla::RngUniform(min_positive, one, xla_shape);
- ctx->SetOutput(0, TruncatedNormal(dtype, uniform));
+ ctx->SetOutput(0, TruncatedNormal(uniform));
}
};
@@ -220,5 +285,5 @@ REGISTER_XLA_OP(Name("TruncatedNormal")
.TypeConstraint("dtype", DT_FLOAT),
TruncatedNormalOp);
-} // anonymous namespace
+} // namespace
} // namespace tensorflow