aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tests/random_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tests/random_ops_test.py')
-rw-r--r--tensorflow/compiler/tests/random_ops_test.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py
index 2e71b00ba6..14c5e7a975 100644
--- a/tensorflow/compiler/tests/random_ops_test.py
+++ b/tensorflow/compiler/tests/random_ops_test.py
@@ -22,7 +22,7 @@ import math
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
@@ -31,7 +31,7 @@ from tensorflow.python.ops.distributions import special_math
from tensorflow.python.platform import googletest
-class RandomOpsTest(XLATestCase):
+class RandomOpsTest(xla_test.XLATestCase):
"""Test cases for random-number generating operators."""
def _random_types(self):
@@ -140,10 +140,10 @@ class RandomOpsTest(XLATestCase):
def testShuffle1d(self):
with self.test_session() as sess:
with self.test_scope():
- x = math_ops.range(20)
+ x = math_ops.range(1 << 16)
shuffle = random_ops.random_shuffle(x)
result = sess.run(shuffle)
- expected = range(20)
+ expected = range(1 << 16)
# Compare sets to avoid randomness behavior changes but make sure still
# have all the values.
self.assertAllEqual(set(result), set(expected))