author     A. Unique TensorFlower <gardener@tensorflow.org>  2018-06-06 12:39:44 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-06-06 12:43:30 -0700
commit     8f2e5f0b4a0221ca1573a40a68077326a32c9bc0 (patch)
tree       fff8d5d0285fab2169902ced613879485b692061 /tensorflow
parent     8b460629e51356485d4da80d81f22e5911a64788 (diff)
[TF:XLA] Add an implementation of RandomShuffle.
PiperOrigin-RevId: 199511721
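For orientation before the diff: the new kernel shuffles along dimension 0 by drawing a random swap target for every index, applying those swaps to an identity index vector inside XLA while-loops, and then gathering the input in the permuted order. Below is a minimal NumPy sketch of that approach; the function name and RNG handling are illustrative assumptions, not TensorFlow APIs.

    import numpy as np

    def random_shuffle_sketch(x, rng=None):
        # Hypothetical sketch mirroring the XLA kernel's structure:
        # random swaps -> permuted indices -> gather.
        rng = rng or np.random.default_rng()
        n = x.shape[0]
        if n <= 1 or x.size <= 1:
            return x  # mirrors the kernel's early exit: nothing to shuffle
        # swaps[i] is a random target index in [0, n), like the RngUniform call.
        swaps = rng.integers(0, n, size=n)
        # indices starts as range(n), like the index_init loop.
        indices = np.arange(n)
        # Swap indices[i] with indices[swaps[i]], like the swap loop.
        for i in range(n):
            j = swaps[i]
            indices[i], indices[j] = indices[j], indices[i]
        # Gather rows in the permuted order, like XlaGather with axis=0.
        return x[indices]

    print(random_shuffle_sketch(np.arange(20)))  # some permutation of 0..19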
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/compiler/tests/BUILD                   |  2
-rw-r--r--  tensorflow/compiler/tests/random_ops_test.py      | 38
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/random_ops.cc  | 92
3 files changed, 126 insertions(+), 6 deletions(-)
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index b51c11bf6e..e6c92f9720 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -545,7 +545,9 @@ tf_xla_py_test(
],
deps = [
":xla_test",
+ "//tensorflow/python:array_ops",
"//tensorflow/python:framework",
+ "//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:random_ops",
],
diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py
index 70be22936a..f13dff9620 100644
--- a/tensorflow/compiler/tests/random_ops_test.py
+++ b/tensorflow/compiler/tests/random_ops_test.py
@@ -22,6 +22,8 @@ import numpy as np
from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import googletest
@@ -47,18 +49,18 @@ class RandomOpsTest(XLATestCase):
# We use exact equality here. If the random-number generator is producing
# deterministic output, all three outputs will be bitwise identical.
self.assertTrue((not np.array_equal(y, z)) or
- (not np.array_equal(z, w)) or
- (not np.array_equal(y, w)))
+ (not np.array_equal(z, w)) or (not np.array_equal(y, w)))
def testRandomUniformIsNotConstant(self):
+
def rng(dtype):
- return random_ops.random_uniform(shape=[2], dtype=dtype,
- maxval=1000000)
+ return random_ops.random_uniform(shape=[2], dtype=dtype, maxval=1000000)
for dtype in self._random_types():
self._testRngIsNotConstant(rng, dtype)
def testRandomNormalIsNotConstant(self):
+
def rng(dtype):
return random_ops.random_normal(shape=[2], dtype=dtype)
@@ -70,13 +72,14 @@ class RandomOpsTest(XLATestCase):
for dtype in self._random_types():
with self.test_session() as sess:
with self.test_scope():
- x = random_ops.random_uniform(shape=[1000], dtype=dtype, minval=-2,
- maxval=33)
+ x = random_ops.random_uniform(
+ shape=[1000], dtype=dtype, minval=-2, maxval=33)
y = sess.run(x)
self.assertTrue((y >= -2).sum() == 1000)
self.assertTrue((y < 33).sum() == 1000)
def testTruncatedNormalIsNotConstant(self):
+
def rng(dtype):
return random_ops.truncated_normal(shape=[2], dtype=dtype)
@@ -94,6 +97,29 @@ class RandomOpsTest(XLATestCase):
self.assertTrue((y >= -2).sum() == count)
self.assertTrue((y <= 2).sum() == count)
+ def testShuffle1d(self):
+ with self.test_session() as sess:
+ with self.test_scope():
+ x = math_ops.range(20)
+ shuffle = random_ops.random_shuffle(x)
+ result = sess.run(shuffle)
+ expected = range(20)
+ # Compare sets so the check is robust to changes in RNG behavior, while
+ # still verifying that all the values are present.
+ self.assertAllEqual(set(result), set(expected))
+
+ def testShuffle2d(self):
+ with self.test_session() as sess:
+ with self.test_scope():
+ x = array_ops.diag(math_ops.range(20))
+ shuffle = random_ops.random_shuffle(x)
+ result = sess.run(shuffle)
+ expected = np.diag(range(20)).flatten()
+ # Compare sets so the check is robust to changes in RNG behavior, while
+ # still verifying that all the values are present.
+ self.assertAllEqual(len(result.flatten()), len(expected))
+ self.assertAllEqual(set(result.flatten()), set(expected))
+
if __name__ == '__main__':
googletest.main()
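A note on the verification strategy in the new tests: the inputs hold distinct values, so comparing sets is enough to prove the output is a permutation of the input without depending on any particular RNG output order. A standalone NumPy illustration (not part of the commit):

    import numpy as np

    original = np.arange(20)
    shuffled = np.random.permutation(original)  # stand-in for random_shuffle
    # Distinct values: a set comparison proves exactly these values survive.
    assert set(shuffled) == set(original)
    # With repeated values, compare sorted arrays (the full multiset) instead.
    assert np.array_equal(np.sort(shuffled), np.sort(original))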
diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc
index 39149d56ad..ebac5c4396 100644
--- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc
@@ -17,6 +17,8 @@ limitations under the License.
// TODO(misard,phawkins): handle random number generator seeds/states correctly.
// TODO(misard,phawkins): add tests.
+#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
+#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
@@ -56,6 +58,96 @@ class RandomUniformOp : public XlaOpKernel {
REGISTER_XLA_OP(Name("RandomUniform").CompileTimeConstInput("shape"),
RandomUniformOp);
+class RandomShuffleOp : public XlaOpKernel {
+ public:
+ explicit RandomShuffleOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ auto builder = ctx->builder();
+ xla::XlaOp input = ctx->Input(0);
+ TensorShape input_shape = ctx->InputShape(0);
+ const int64 n = input_shape.dim_size(0);
+ int64 num_elements = 1;
+ for (tensorflow::TensorShapeDim dimension : input_shape) {
+ num_elements *= dimension.size;
+ }
+ if (num_elements <= 1 || n <= 1) {
+ // No shuffling is required, so copy input directly to output
+ ctx->SetOutput(0, input);
+ } else {
+ // Generate the random swaps for the indices.
+ auto zero = builder->Broadcast(
+ builder->ConstantLiteral(xla::Literal::Zero(xla::S32)),
+ gtl::ArraySlice<int64>({n}));
+ auto n_maxval = builder->Broadcast(builder->ConstantR0<int32>(n),
+ gtl::ArraySlice<int64>({n}));
+ auto swaps_shape = xla::ShapeUtil::MakeShape(xla::S32, {n});
+ auto swaps = builder->RngUniform(zero, n_maxval, swaps_shape);
+
+ // Generate range(n) as the initial value for the indices to be swapped.
+ auto index_init_body_fn = [&](xla::XlaOp i,
+ gtl::ArraySlice<xla::XlaOp> loop_vars,
+ xla::XlaBuilder* builder)
+ -> xla::StatusOr<std::vector<xla::XlaOp>> {
+ auto indices = loop_vars[0];
+ i = builder->Reshape(i, {}, {1});
+ // indices[i] = i
+ indices = builder->DynamicUpdateSlice(indices, i, i);
+ return std::vector<xla::XlaOp>{indices};
+ };
+ // for i in range(n):
+ xla::XlaOp index_zeros = Zeros(builder, swaps_shape);
+ auto index_init_loop_result =
+ XlaForEachIndex(n, xla::S32, index_init_body_fn, {index_zeros},
+ "index_init_loop", builder)
+ .ValueOrDie();
+ auto indices = index_init_loop_result[0];
+
+ // Swap the indices at i and swaps[i].
+ auto swap_body_fn = [&](xla::XlaOp i,
+ gtl::ArraySlice<xla::XlaOp> loop_vars,
+ xla::XlaBuilder* builder)
+ -> xla::StatusOr<std::vector<xla::XlaOp>> {
+ auto swaps = loop_vars[0];
+ auto indices = loop_vars[1];
+ i = builder->Reshape(i, {}, {1});
+ // temp = indices[i]
+ auto temp = builder->DynamicSlice(indices, i, {1});
+ // swap_index = swaps[i]
+ auto swap_index = builder->DynamicSlice(swaps, i, {1});
+ // swap_value = indices[swaps[i]]
+ auto swap_value = builder->DynamicSlice(indices, swap_index, {1});
+ // indices[i] = indices[swaps[i]]
+ indices = builder->DynamicUpdateSlice(indices, swap_value, i);
+ // indices[swaps[i]] = temp
+ indices = builder->DynamicUpdateSlice(indices, temp, swap_index);
+ return std::vector<xla::XlaOp>{swaps, indices};
+ };
+ // for i in range(n):
+ auto swap_loop_result =
+ XlaForEachIndex(n, xla::S32, swap_body_fn, {swaps, indices},
+ "indices_swap_loop", builder)
+ .ValueOrDie();
+ auto swapped_indices = swap_loop_result[1];
+
+ // Gather the data using the swapped indices as the shuffled order.
+ auto indices_tensor_shape = TensorShape({n});
+ DataType type = ctx->expected_output_dtype(0);
+ xla::XlaOp gather;
+ OP_REQUIRES_OK(ctx, XlaGather(input, input_shape, swapped_indices,
+ indices_tensor_shape,
+ /*axis=*/0, /*indices_are_nd=*/false, type,
+ DT_INT32, builder, &gather));
+ ctx->SetOutput(0, gather);
+ }
+ }
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(RandomShuffleOp);
+};
+
+REGISTER_XLA_OP(Name("RandomShuffle"), RandomShuffleOp);
+
class RandomUniformIntOp : public XlaOpKernel {
public:
explicit RandomUniformIntOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
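For completeness: because RandomShuffle is registered without a CompileTimeConstInput constraint, a compiled tf.random_shuffle call can now lower through this kernel. A hedged TF 1.x-style usage sketch matching the idioms in the tests above; the XLA device string is an assumption about device placement, not something this commit establishes:

    import tensorflow as tf  # TF 1.x API, as in the tests above

    with tf.Session() as sess:
      with tf.device('/device:XLA_CPU:0'):  # assumed XLA device name
        x = tf.range(20)
        shuffled = tf.random_shuffle(x)
      print(sess.run(shuffled))  # some permutation of 0..19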