diff options
author | 2018-05-18 12:17:05 -0700 | |
---|---|---|
committer | 2018-05-18 12:20:34 -0700 | |
commit | 514bb4f3a630612fd6f6aaf62d9bbc0e4c72d0ff (patch) | |
tree | 9f5a144f36c8967f8208bdefbad4051c52737f84 /tensorflow/contrib/distributions | |
parent | 77871562537ff726473ae9b69fee658f32738f63 (diff) |
Enable `SeedStream` construction from other `SeedStream` instances.
PiperOrigin-RevId: 197182686
Diffstat (limited to 'tensorflow/contrib/distributions')
-rw-r--r-- | tensorflow/contrib/distributions/python/kernel_tests/seed_stream_test.py | 10 | ||||
-rw-r--r-- | tensorflow/contrib/distributions/python/ops/seed_stream.py | 2 |
2 files changed, 11 insertions, 1 deletions
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/seed_stream_test.py b/tensorflow/contrib/distributions/python/kernel_tests/seed_stream_test.py index 9680573317..b91a610acf 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/seed_stream_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/seed_stream_test.py @@ -65,6 +65,16 @@ class SeedStreamTest(test.TestCase): self.assertAllUnique( outputs + [strm2() for _ in range(50)] + [strm3() for _ in range(50)]) + def testInitFromOtherSeedStream(self): + strm1 = seed_stream.SeedStream(seed=4, salt="salt") + strm2 = seed_stream.SeedStream(strm1, salt="salt") + strm3 = seed_stream.SeedStream(strm1, salt="another salt") + out1 = [strm1() for _ in range(50)] + out2 = [strm2() for _ in range(50)] + out3 = [strm3() for _ in range(50)] + self.assertAllEqual(out1, out2) + self.assertAllUnique(out1 + out3) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/ops/seed_stream.py b/tensorflow/contrib/distributions/python/ops/seed_stream.py index 056d349688..cf505ac627 100644 --- a/tensorflow/contrib/distributions/python/ops/seed_stream.py +++ b/tensorflow/contrib/distributions/python/ops/seed_stream.py @@ -169,7 +169,7 @@ class SeedStream(object): and TensorFlow Probability code base. See class docstring for rationale. """ - self._seed = seed + self._seed = seed.original_seed if isinstance(seed, SeedStream) else seed self._salt = salt self._counter = 0 |