diff options
author | 2018-06-11 11:43:45 -0700 | |
---|---|---|
committer | 2018-06-11 11:46:49 -0700 | |
commit | c73cd1afce146aa2559cafa4ac72fe638db43860 (patch) | |
tree | fff37b8d1ed44079f3410349b10348a761ac6040 | |
parent | 81682566acf8ea5b5691a9e36d7740953e3c7ef7 (diff) |
[TF:XLA] Small performance tweaks for tf.random_shuffle, but still too slow.
PiperOrigin-RevId: 200086551
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/random_ops.cc | 30 |
1 files changed, 6 insertions, 24 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index ebac5c4396..105be38fe2 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -76,32 +76,14 @@ class RandomShuffleOp : public XlaOpKernel { ctx->SetOutput(0, input); } else { // Generate the random swaps for the indices. - auto zero = builder->Broadcast( - builder->ConstantLiteral(xla::Literal::Zero(xla::S32)), - gtl::ArraySlice<int64>({n})); - auto n_maxval = builder->Broadcast(builder->ConstantR0<int32>(n), - gtl::ArraySlice<int64>({n})); auto swaps_shape = xla::ShapeUtil::MakeShape(xla::S32, {n}); - auto swaps = builder->RngUniform(zero, n_maxval, swaps_shape); + auto swaps = + builder->RngUniform(builder->ConstantR0<int32>(0), + builder->ConstantR0<int32>(n), swaps_shape); // Generate range(n) as the initial value for the indices to be swapped. - auto index_init_body_fn = [&](xla::XlaOp i, - gtl::ArraySlice<xla::XlaOp> loop_vars, - xla::XlaBuilder* builder) - -> xla::StatusOr<std::vector<xla::XlaOp>> { - auto indices = loop_vars[0]; - i = builder->Reshape(i, {}, {1}); - // indices[i] = i - indices = builder->DynamicUpdateSlice(indices, i, i); - return std::vector<xla::XlaOp>{indices}; - }; - // for i in range(n): - xla::XlaOp index_zeros = Zeros(builder, swaps_shape); - auto index_init_loop_result = - XlaForEachIndex(n, xla::S32, index_init_body_fn, {index_zeros}, - "index_init_loop", builder) - .ValueOrDie(); - auto indices = index_init_loop_result[0]; + xla::XlaOp indices; + TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, n, &indices)); // Swap the indices at i and swaps[i]. auto swap_body_fn = [&](xla::XlaOp i, @@ -110,7 +92,7 @@ class RandomShuffleOp : public XlaOpKernel { -> xla::StatusOr<std::vector<xla::XlaOp>> { auto swaps = loop_vars[0]; auto indices = loop_vars[1]; - i = builder->Reshape(i, {}, {1}); + i = builder->Reshape(i, {1}); // temp = indices[i] auto temp = builder->DynamicSlice(indices, i, {1}); // swap_index = swaps[i] |