aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Igor Ganichev <iga@google.com>2018-08-10 13:40:00 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-10 13:48:35 -0700
commitba76f5ca8f722d8c66e4687d8a2161d858e9b407 (patch)
tree11320ba2e0de313af3bb72a568f210889393bd25
parent79c95b754c241afe3dab741a895ffdbb9646bd65 (diff)
Use global counter for XLA rng seed
PiperOrigin-RevId: 208260479
-rw-r--r--tensorflow/compiler/jit/BUILD1
-rw-r--r--tensorflow/compiler/jit/kernels/BUILD1
-rw-r--r--tensorflow/compiler/jit/kernels/xla_launch_op.cc3
-rw-r--r--tensorflow/compiler/jit/xla_compile_on_demand_op.cc3
-rw-r--r--tensorflow/compiler/tests/eager_test.py9
-rw-r--r--tensorflow/compiler/tests/random_ops_test.py14
-rw-r--r--tensorflow/compiler/tf2xla/tf2xla_util.cc26
-rw-r--r--tensorflow/compiler/tf2xla/tf2xla_util.h3
8 files changed, 48 insertions, 12 deletions
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 15f9ba217f..55b98da472 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -160,6 +160,7 @@ cc_library(
"//tensorflow/compiler/jit/ops:xla_ops",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:dump_graph",
+ "//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla:util",
diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD
index 00a6f4075f..8f78c110cb 100644
--- a/tensorflow/compiler/jit/kernels/BUILD
+++ b/tensorflow/compiler/jit/kernels/BUILD
@@ -16,6 +16,7 @@ cc_library(
"//tensorflow/compiler/jit:xla_device",
"//tensorflow/compiler/jit:xla_launch_util",
"//tensorflow/compiler/tf2xla:common",
+ "//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:client_library",
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
index b313d48011..37a2f3b5ac 100644
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc
+++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
@@ -199,7 +200,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
run_options.set_stream(stream);
run_options.set_allocator(xla_allocator);
run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
- run_options.set_rng_seed(ctx->step_id());
+ run_options.set_rng_seed(GetXLARandomSeed());
Env* env = Env::Default();
auto start_time = env->NowMicros();
diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
index d288d37bc7..f65f89ebf5 100644
--- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
+++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
+#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -71,7 +72,7 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
run_options.set_stream(stream);
run_options.set_allocator(client->backend().memory_allocator());
run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
- run_options.set_rng_seed(ctx->step_id());
+ run_options.set_rng_seed(GetXLARandomSeed());
xla::StatusOr<xla::ScopedShapedBuffer> run_result =
executable->Run(launch_context.arguments(), run_options);
diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py
index 422f36d43b..ff097f80f1 100644
--- a/tensorflow/compiler/tests/eager_test.py
+++ b/tensorflow/compiler/tests/eager_test.py
@@ -32,6 +32,7 @@ from tensorflow.python.layers import convolutional
from tensorflow.python.layers import pooling
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
+from tensorflow.python.ops import gen_random_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
@@ -122,6 +123,14 @@ class EagerTest(xla_test.XLATestCase):
with self.test_scope():
self.assertAllEqual(2, array_ops.identity(2))
+ def testRandomOps(self):
+ with self.test_scope():
+ tensor = gen_random_ops.random_uniform((2, 2), dtypes.float32)
+ row0 = tensor[0].numpy()
+ row1 = tensor[1].numpy()
+ # It should be very unlikely for the rng to generate two equal rows.
+ self.assertFalse((row0 == row1).all())
+
def testIdentityOnVariable(self):
with self.test_scope():
v = resource_variable_ops.ResourceVariable(True)
diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py
index cc0e9b2f98..8c4e16e4e0 100644
--- a/tensorflow/compiler/tests/random_ops_test.py
+++ b/tensorflow/compiler/tests/random_ops_test.py
@@ -101,7 +101,7 @@ class RandomOpsTest(xla_test.XLATestCase):
for dtype in [dtypes.float32]:
with self.test_session() as sess:
with self.test_scope():
- x = random_ops.truncated_normal(shape=[count], dtype=dtype, seed=42)
+ x = random_ops.truncated_normal(shape=[count], dtype=dtype)
y = sess.run(x)
def normal_cdf(x):
@@ -130,24 +130,18 @@ class RandomOpsTest(xla_test.XLATestCase):
# Department of Scientific Computing website. Florida State University.
expected_mean = mu + (normal_pdf(alpha) - normal_pdf(beta)) / z * sigma
actual_mean = np.mean(y)
- atol = 2e-4
- if self.device in ["XLA_GPU", "XLA_CPU"]:
- atol = 2.2e-4
- self.assertAllClose(actual_mean, expected_mean, atol=atol)
+ self.assertAllClose(actual_mean, expected_mean, atol=2e-3)
expected_median = mu + probit(
(normal_cdf(alpha) + normal_cdf(beta)) / 2.) * sigma
actual_median = np.median(y)
- self.assertAllClose(actual_median, expected_median, atol=1e-3)
+ self.assertAllClose(actual_median, expected_median, atol=1e-2)
expected_variance = sigma**2 * (1 + (
(alpha * normal_pdf(alpha) - beta * normal_pdf(beta)) / z) - (
(normal_pdf(alpha) - normal_pdf(beta)) / z)**2)
actual_variance = np.var(y)
- rtol = 1e-3
- if self.device in ["XLA_GPU", "XLA_CPU"]:
- rtol = 4e-4
- self.assertAllClose(actual_variance, expected_variance, rtol=rtol)
+ self.assertAllClose(actual_variance, expected_variance, rtol=2*1e-3)
def testShuffle1d(self):
# TODO(b/26783907): this test requires the CPU backend to implement sort.
diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc
index 9203e8d9e6..0e07485d18 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_util.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include <queue>
+#include <random>
#include <set>
#include <unordered_map>
@@ -297,4 +298,29 @@ void AddDtypeToKernalDefConstraint(StringPiece name, DataType dtype,
}
}
+namespace {
+uint32 InitialRandomSeed() {
+ // Support for plumbing the TF seed through to XLA is being worked on.
+ // If a user wants deterministic behavior, their best option
+ // is to start with a known checkpoint. This also handles issues when
+ // multiple random calls can be invoked in any order by TF executor.
+ // Another option is to use stateless random ops. They have much cleaner
+ // semantics.
+ // If a user really wants to set a deterministic seed for XLA-based
+ // devices, this is the place to do it.
+ std::random_device rd;
+ // Make the starting value odd.
+ return rd() | 1;
+}
+} // namespace
+
+uint32 GetXLARandomSeed() {
+ // We initialize counter with an odd number and increment it by two
+ // every time. This ensures that it will never be zero, even
+ // after an overflow. When seeded with zero, some XLA backends
+ // can return all zeros instead of random numbers.
+ static std::atomic<uint32> counter(InitialRandomSeed());
+ return counter.fetch_add(2);
+}
+
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.h b/tensorflow/compiler/tf2xla/tf2xla_util.h
index 745beb39c1..33620ef810 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_util.h
+++ b/tensorflow/compiler/tf2xla/tf2xla_util.h
@@ -56,6 +56,9 @@ Status SetNodeShardingFromNeighbors(Node* n, bool out_edges);
void AddDtypeToKernalDefConstraint(StringPiece name, DataType dtype,
KernelDef* kdef);
+// Returns the next random seed to use for seeding xla rng.
+uint32 GetXLARandomSeed();
+
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_