aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distributions
diff options
context:
space:
mode:
authorGravatar Joshua V. Dillon <jvdillon@google.com>2018-05-18 12:17:05 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-18 12:20:34 -0700
commit514bb4f3a630612fd6f6aaf62d9bbc0e4c72d0ff (patch)
tree9f5a144f36c8967f8208bdefbad4051c52737f84 /tensorflow/contrib/distributions
parent77871562537ff726473ae9b69fee658f32738f63 (diff)
Enable `SeedStream` construction from other `SeedStream` instances.
PiperOrigin-RevId: 197182686
Diffstat (limited to 'tensorflow/contrib/distributions')
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/seed_stream_test.py10
-rw-r--r--tensorflow/contrib/distributions/python/ops/seed_stream.py2
2 files changed, 11 insertions, 1 deletions
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/seed_stream_test.py b/tensorflow/contrib/distributions/python/kernel_tests/seed_stream_test.py
index 9680573317..b91a610acf 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/seed_stream_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/seed_stream_test.py
@@ -65,6 +65,16 @@ class SeedStreamTest(test.TestCase):
self.assertAllUnique(
outputs + [strm2() for _ in range(50)] + [strm3() for _ in range(50)])
+ def testInitFromOtherSeedStream(self):
+ strm1 = seed_stream.SeedStream(seed=4, salt="salt")
+ strm2 = seed_stream.SeedStream(strm1, salt="salt")
+ strm3 = seed_stream.SeedStream(strm1, salt="another salt")
+ out1 = [strm1() for _ in range(50)]
+ out2 = [strm2() for _ in range(50)]
+ out3 = [strm3() for _ in range(50)]
+ self.assertAllEqual(out1, out2)
+ self.assertAllUnique(out1 + out3)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/distributions/python/ops/seed_stream.py b/tensorflow/contrib/distributions/python/ops/seed_stream.py
index 056d349688..cf505ac627 100644
--- a/tensorflow/contrib/distributions/python/ops/seed_stream.py
+++ b/tensorflow/contrib/distributions/python/ops/seed_stream.py
@@ -169,7 +169,7 @@ class SeedStream(object):
and TensorFlow Probability code base. See class docstring for
rationale.
"""
- self._seed = seed
+ self._seed = seed.original_seed if isinstance(seed, SeedStream) else seed
self._salt = salt
self._counter = 0