aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-11 11:43:45 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-11 11:46:49 -0700
commitc73cd1afce146aa2559cafa4ac72fe638db43860 (patch)
treefff37b8d1ed44079f3410349b10348a761ac6040
parent81682566acf8ea5b5691a9e36d7740953e3c7ef7 (diff)
[TF:XLA] Small performance tweaks for tf.random_shuffle, but still too slow.
PiperOrigin-RevId: 200086551
-rw-r--r--tensorflow/compiler/tf2xla/kernels/random_ops.cc30
1 files changed, 6 insertions, 24 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc
index ebac5c4396..105be38fe2 100644
--- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc
@@ -76,32 +76,14 @@ class RandomShuffleOp : public XlaOpKernel {
ctx->SetOutput(0, input);
} else {
// Generate the random swaps for the indices.
- auto zero = builder->Broadcast(
- builder->ConstantLiteral(xla::Literal::Zero(xla::S32)),
- gtl::ArraySlice<int64>({n}));
- auto n_maxval = builder->Broadcast(builder->ConstantR0<int32>(n),
- gtl::ArraySlice<int64>({n}));
auto swaps_shape = xla::ShapeUtil::MakeShape(xla::S32, {n});
- auto swaps = builder->RngUniform(zero, n_maxval, swaps_shape);
+ auto swaps =
+ builder->RngUniform(builder->ConstantR0<int32>(0),
+ builder->ConstantR0<int32>(n), swaps_shape);
// Generate range(n) as the initial value for the indices to be swapped.
- auto index_init_body_fn = [&](xla::XlaOp i,
- gtl::ArraySlice<xla::XlaOp> loop_vars,
- xla::XlaBuilder* builder)
- -> xla::StatusOr<std::vector<xla::XlaOp>> {
- auto indices = loop_vars[0];
- i = builder->Reshape(i, {}, {1});
- // indices[i] = i
- indices = builder->DynamicUpdateSlice(indices, i, i);
- return std::vector<xla::XlaOp>{indices};
- };
- // for i in range(n):
- xla::XlaOp index_zeros = Zeros(builder, swaps_shape);
- auto index_init_loop_result =
- XlaForEachIndex(n, xla::S32, index_init_body_fn, {index_zeros},
- "index_init_loop", builder)
- .ValueOrDie();
- auto indices = index_init_loop_result[0];
+ xla::XlaOp indices;
+ TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, n, &indices));
// Swap the indices at i and swaps[i].
auto swap_body_fn = [&](xla::XlaOp i,
@@ -110,7 +92,7 @@ class RandomShuffleOp : public XlaOpKernel {
-> xla::StatusOr<std::vector<xla::XlaOp>> {
auto swaps = loop_vars[0];
auto indices = loop_vars[1];
- i = builder->Reshape(i, {}, {1});
+ i = builder->Reshape(i, {1});
// temp = indices[i]
auto temp = builder->DynamicSlice(indices, i, {1});
// swap_index = swaps[i]