aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/stateless
diff options
context:
space:
mode:
authorGravatar Geoffrey Irving <geoffreyi@google.com>2017-04-17 13:30:47 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-04-17 14:50:08 -0700
commitcc45456e4ad0eff16127d1727d0cf48afb71ca0e (patch)
tree6c9af119ffbc40b789dc3e6f91da842ad7e87b4b /tensorflow/contrib/stateless
parent3d0380476b16d8180cd89c34b1b9e7d6e7275e7f (diff)
Add stateless random ops for custom control of seeding
The new ops are in tf.contrib.stateless. They reuse the same Philox kernels as the stateful random ops, but this may change in future. RELNOTES: Add tf.contrib.stateless for random ops with custom seed control. Change: 153388998
Diffstat (limited to 'tensorflow/contrib/stateless')
-rw-r--r--tensorflow/contrib/stateless/BUILD50
-rw-r--r--tensorflow/contrib/stateless/__init__.py38
-rw-r--r--tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py84
3 files changed, 172 insertions, 0 deletions
diff --git a/tensorflow/contrib/stateless/BUILD b/tensorflow/contrib/stateless/BUILD
new file mode 100644
index 0000000000..1d9c1ffa50
--- /dev/null
+++ b/tensorflow/contrib/stateless/BUILD
@@ -0,0 +1,50 @@
+# Stateless random ops
+
+package(default_visibility = ["//tensorflow:__subpackages__"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
+load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
+
+tf_gen_op_wrapper_py(
+ name = "stateless_random_ops",
+ out = "gen_stateless_random_ops.py", # cmake chokes without this
+ deps = ["//tensorflow/core:stateless_random_ops_op_lib"],
+)
+
+py_library(
+ name = "stateless",
+ srcs = ["__init__.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":stateless_random_ops",
+ ],
+)
+
+cuda_py_test(
+ name = "stateless_random_ops_test",
+ srcs = ["python/kernel_tests/stateless_random_ops_test.py"],
+ additional_deps = [
+ ":stateless",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:random_ops",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/stateless/__init__.py b/tensorflow/contrib/stateless/__init__.py
new file mode 100644
index 0000000000..82e5d36ce4
--- /dev/null
+++ b/tensorflow/contrib/stateless/__init__.py
@@ -0,0 +1,38 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Stateless random ops which take seed as a tensor input.
+
+Instead of taking `seed` as an attr which initializes a mutable state within
+the op, these random ops take `seed` as an input, and the random numbers are
+a deterministic function of `shape` and `seed`.
+
+WARNING: These ops are in contrib, and are not stable. They should be
+consistent across multiple runs on the same hardware, but only for the same
+version of the code.
+
+@@stateless_random_uniform
+@@stateless_random_normal
+@@stateless_truncated_normal
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=wildcard-import
+from tensorflow.contrib.stateless.gen_stateless_random_ops import *
+
+from tensorflow.python.util.all_util import remove_undocumented
+remove_undocumented(__name__)
diff --git a/tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py b/tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py
new file mode 100644
index 0000000000..9a36bdc2f9
--- /dev/null
+++ b/tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py
@@ -0,0 +1,84 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for stateless random ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from tensorflow.contrib import stateless
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import random_seed
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.platform import test
+
+CASES = [(stateless.stateless_random_uniform, random_ops.random_uniform),
+ (stateless.stateless_random_normal, random_ops.random_normal),
+ (stateless.stateless_truncated_normal, random_ops.truncated_normal)]
+
+
+def invert_philox(key, value):
+ """Invert the Philox bijection."""
+ key = np.array(key, dtype=np.uint32)
+ value = np.array(value, dtype=np.uint32)
+ step = np.array([0x9E3779B9, 0xBB67AE85], dtype=np.uint32)
+ for n in range(10)[::-1]:
+ key0, key1 = key + n * step
+ v0 = value[3] * 0x991a7cdb & 0xffffffff
+ v2 = value[1] * 0x6d7cae67 & 0xffffffff
+ hi0 = v0 * 0xD2511F53 >> 32
+ hi1 = v2 * 0xCD9E8D57 >> 32
+ v1 = hi1 ^ value[0] ^ key0
+ v3 = hi0 ^ value[2] ^ key1
+ value = v0, v1, v2, v3
+ return np.array(value)
+
+
+class StatelessOpsTest(test.TestCase):
+
+ def testMatchStateful(self):
+ # Stateless ops should be the same as stateful ops on the first call
+ # after seed scrambling.
+ key = 0x3ec8f720, 0x02461e29
+ for seed in (7, 17), (11, 5), (2, 3):
+ preseed = invert_philox(key, (seed[0], 0, seed[1], 0)).astype(np.uint64)
+ preseed = preseed[::2] | preseed[1::2] << 32
+ random_seed.set_random_seed(seed[0])
+ with self.test_session(use_gpu=True):
+ for stateless_op, stateful_op in CASES:
+ for shape in (), (3,), (2, 5):
+ stateful = stateful_op(shape, seed=seed[1])
+ pure = stateless_op(shape, seed=preseed)
+ self.assertAllEqual(stateful.eval(), pure.eval())
+
+ def testDeterminism(self):
+ # Stateless values should be equal iff the seeds are equal (roughly)
+ with self.test_session(use_gpu=True):
+ seed_t = array_ops.placeholder(dtypes.int64, shape=[2])
+ seeds = [(x, y) for x in range(5) for y in range(5)] * 3
+ for stateless_op, _ in CASES:
+ for shape in (), (3,), (2, 5):
+ pure = stateless_op(shape, seed=seed_t)
+ values = [(seed, pure.eval(feed_dict={seed_t: seed}))
+ for seed in seeds]
+ for s0, v0 in values:
+ for s1, v1 in values:
+ self.assertEqual(s0 == s1, np.all(v0 == v1))
+
+
+if __name__ == '__main__':
+ test.main()