diff options
Diffstat (limited to 'tensorflow/compiler/tests/random_ops_test.py')
-rw-r--r-- | tensorflow/compiler/tests/random_ops_test.py | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py index 2e71b00ba6..14c5e7a975 100644 --- a/tensorflow/compiler/tests/random_ops_test.py +++ b/tensorflow/compiler/tests/random_ops_test.py @@ -22,7 +22,7 @@ import math import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -31,7 +31,7 @@ from tensorflow.python.ops.distributions import special_math from tensorflow.python.platform import googletest -class RandomOpsTest(XLATestCase): +class RandomOpsTest(xla_test.XLATestCase): """Test cases for random-number generating operators.""" def _random_types(self): @@ -140,10 +140,10 @@ class RandomOpsTest(XLATestCase): def testShuffle1d(self): with self.test_session() as sess: with self.test_scope(): - x = math_ops.range(20) + x = math_ops.range(1 << 16) shuffle = random_ops.random_shuffle(x) result = sess.run(shuffle) - expected = range(20) + expected = range(1 << 16) # Compare sets to avoid randomness behavior changes but make sure still # have all the values. self.assertAllEqual(set(result), set(expected)) |