diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-06-06 12:39:44 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-06 12:43:30 -0700 |
commit | 8f2e5f0b4a0221ca1573a40a68077326a32c9bc0 (patch) | |
tree | fff8d5d0285fab2169902ced613879485b692061 | |
parent | 8b460629e51356485d4da80d81f22e5911a64788 (diff) |
[TF:XLA] Add an implementation of RandomShuffle.
PiperOrigin-RevId: 199511721
-rw-r--r-- | tensorflow/compiler/tests/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/compiler/tests/random_ops_test.py | 38 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/random_ops.cc | 92 |
3 files changed, 126 insertions, 6 deletions
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index b51c11bf6e..e6c92f9720 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -545,7 +545,9 @@ tf_xla_py_test( ], deps = [ ":xla_test", + "//tensorflow/python:array_ops", "//tensorflow/python:framework", + "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:random_ops", ], diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py index 70be22936a..f13dff9620 100644 --- a/tensorflow/compiler/tests/random_ops_test.py +++ b/tensorflow/compiler/tests/random_ops_test.py @@ -22,6 +22,8 @@ import numpy as np from tensorflow.compiler.tests.xla_test import XLATestCase from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.platform import googletest @@ -47,18 +49,18 @@ class RandomOpsTest(XLATestCase): # We use exact equality here. If the random-number generator is producing # deterministic output, all three outputs will be bitwise identical. 
self.assertTrue((not np.array_equal(y, z)) or - (not np.array_equal(z, w)) or - (not np.array_equal(y, w))) + (not np.array_equal(z, w)) or (not np.array_equal(y, w))) def testRandomUniformIsNotConstant(self): + def rng(dtype): - return random_ops.random_uniform(shape=[2], dtype=dtype, - maxval=1000000) + return random_ops.random_uniform(shape=[2], dtype=dtype, maxval=1000000) for dtype in self._random_types(): self._testRngIsNotConstant(rng, dtype) def testRandomNormalIsNotConstant(self): + def rng(dtype): return random_ops.random_normal(shape=[2], dtype=dtype) @@ -70,13 +72,14 @@ class RandomOpsTest(XLATestCase): for dtype in self._random_types(): with self.test_session() as sess: with self.test_scope(): - x = random_ops.random_uniform(shape=[1000], dtype=dtype, minval=-2, - maxval=33) + x = random_ops.random_uniform( + shape=[1000], dtype=dtype, minval=-2, maxval=33) y = sess.run(x) self.assertTrue((y >= -2).sum() == 1000) self.assertTrue((y < 33).sum() == 1000) def testTruncatedNormalIsNotConstant(self): + def rng(dtype): return random_ops.truncated_normal(shape=[2], dtype=dtype) @@ -94,6 +97,29 @@ class RandomOpsTest(XLATestCase): self.assertTrue((y >= -2).sum() == count) self.assertTrue((y <= 2).sum() == count) + def testShuffle1d(self): + with self.test_session() as sess: + with self.test_scope(): + x = math_ops.range(20) + shuffle = random_ops.random_shuffle(x) + result = sess.run(shuffle) + expected = range(20) + # Compare sets to avoid randomness behavior changes but make sure still + # have all the values. + self.assertAllEqual(set(result), set(expected)) + + def testShuffle2d(self): + with self.test_session() as sess: + with self.test_scope(): + x = array_ops.diag(math_ops.range(20)) + shuffle = random_ops.random_shuffle(x) + result = sess.run(shuffle) + expected = np.diag(range(20)).flatten() + # Compare sets to avoid randomness behavior changes but make sure still + # have all the values. 
+ self.assertAllEqual(len(result.flatten()), len(expected)) + self.assertAllEqual(set(result.flatten()), set(expected)) + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index 39149d56ad..ebac5c4396 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -17,6 +17,8 @@ limitations under the License. // TODO(misard,phawkins): handle random number generator seeds/states correctly. // TODO(misard,phawkins): add tests. +#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" +#include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" @@ -56,6 +58,96 @@ class RandomUniformOp : public XlaOpKernel { REGISTER_XLA_OP(Name("RandomUniform").CompileTimeConstInput("shape"), RandomUniformOp); +class RandomShuffleOp : public XlaOpKernel { + public: + explicit RandomShuffleOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + auto builder = ctx->builder(); + xla::XlaOp input = ctx->Input(0); + TensorShape input_shape = ctx->InputShape(0); + const int64 n = input_shape.dim_size(0); + int64 num_elements = 1; + for (tensorflow::TensorShapeDim dimension : input_shape) { + num_elements *= dimension.size; + } + if (num_elements <= 1 || n <= 1) { + // No shuffling is required, so copy input directly to output + ctx->SetOutput(0, input); + } else { + // Generate the random swaps for the indices. 
+ auto zero = builder->Broadcast( + builder->ConstantLiteral(xla::Literal::Zero(xla::S32)), + gtl::ArraySlice<int64>({n})); + auto n_maxval = builder->Broadcast(builder->ConstantR0<int32>(n), + gtl::ArraySlice<int64>({n})); + auto swaps_shape = xla::ShapeUtil::MakeShape(xla::S32, {n}); + auto swaps = builder->RngUniform(zero, n_maxval, swaps_shape); + + // Generate range(n) as the initial value for the indices to be swapped. + auto index_init_body_fn = [&](xla::XlaOp i, + gtl::ArraySlice<xla::XlaOp> loop_vars, + xla::XlaBuilder* builder) + -> xla::StatusOr<std::vector<xla::XlaOp>> { + auto indices = loop_vars[0]; + i = builder->Reshape(i, {}, {1}); + // indices[i] = i + indices = builder->DynamicUpdateSlice(indices, i, i); + return std::vector<xla::XlaOp>{indices}; + }; + // for i in range(n): + xla::XlaOp index_zeros = Zeros(builder, swaps_shape); + auto index_init_loop_result = + XlaForEachIndex(n, xla::S32, index_init_body_fn, {index_zeros}, + "index_init_loop", builder) + .ValueOrDie(); + auto indices = index_init_loop_result[0]; + + // Swap the indices at i and swaps[i]. 
+ auto swap_body_fn = [&](xla::XlaOp i, + gtl::ArraySlice<xla::XlaOp> loop_vars, + xla::XlaBuilder* builder) + -> xla::StatusOr<std::vector<xla::XlaOp>> { + auto swaps = loop_vars[0]; + auto indices = loop_vars[1]; + i = builder->Reshape(i, {}, {1}); + // temp = indices[i] + auto temp = builder->DynamicSlice(indices, i, {1}); + // swap_index = swaps[i] + auto swap_index = builder->DynamicSlice(swaps, i, {1}); + // swap_value = indices[swaps[i]] + auto swap_value = builder->DynamicSlice(indices, swap_index, {1}); + // indices[i] = indices[swaps[i]] + indices = builder->DynamicUpdateSlice(indices, swap_value, i); + // indices[swaps[i]] = temp + indices = builder->DynamicUpdateSlice(indices, temp, swap_index); + return std::vector<xla::XlaOp>{swaps, indices}; + }; + // for i in range(n): + auto swap_loop_result = + XlaForEachIndex(n, xla::S32, swap_body_fn, {swaps, indices}, + "indices_swap_loop", builder) + .ValueOrDie(); + auto swapped_indices = swap_loop_result[1]; + + // Gather the data using the swapped indices as the shuffled order. + auto indices_tensor_shape = TensorShape({n}); + DataType type = ctx->expected_output_dtype(0); + xla::XlaOp gather; + OP_REQUIRES_OK(ctx, XlaGather(input, input_shape, swapped_indices, + indices_tensor_shape, + /*axis=*/0, /*indices_are_nd=*/false, type, + DT_INT32, builder, &gather)); + ctx->SetOutput(0, gather); + } + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(RandomShuffleOp); +}; + +REGISTER_XLA_OP(Name("RandomShuffle"), RandomShuffleOp); + class RandomUniformIntOp : public XlaOpKernel { public: explicit RandomUniformIntOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} |