aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/autograph/operators/py_builtins_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/autograph/operators/py_builtins_test.py')
-rw-r--r--tensorflow/python/autograph/operators/py_builtins_test.py23
1 files changed, 15 insertions, 8 deletions
diff --git a/tensorflow/python/autograph/operators/py_builtins_test.py b/tensorflow/python/autograph/operators/py_builtins_test.py
index a021263ffa..c94a918d5a 100644
--- a/tensorflow/python/autograph/operators/py_builtins_test.py
+++ b/tensorflow/python/autograph/operators/py_builtins_test.py
@@ -36,7 +36,7 @@ class PyBuiltinsTest(test.TestCase):
def test_abs(self):
self.assertEqual(py_builtins.abs_(-1), 1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
t = py_builtins.abs_(constant_op.constant(-1))
self.assertEqual(sess.run(t), 1)
t = py_builtins.abs_(constant_op.constant([-1, 2, -3]))
@@ -45,7 +45,7 @@ class PyBuiltinsTest(test.TestCase):
def test_float(self):
self.assertEqual(py_builtins.float_(10), 10.0)
self.assertEqual(py_builtins.float_('10.0'), 10.0)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
t = py_builtins.float_(constant_op.constant(1, dtype=dtypes.int64))
self.assertEqual(sess.run(t), 1.0)
st = py_builtins.float_(constant_op.constant('1.0'))
@@ -54,7 +54,7 @@ class PyBuiltinsTest(test.TestCase):
def test_int(self):
self.assertEqual(py_builtins.int_(10.0), 10)
self.assertEqual(py_builtins.int_('11', 2), 3)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
t = py_builtins.int_(constant_op.constant(1, dtype=dtypes.float64))
self.assertEqual(sess.run(t), 1)
st = py_builtins.int_(constant_op.constant('1'))
@@ -69,7 +69,7 @@ class PyBuiltinsTest(test.TestCase):
def test_len(self):
self.assertEqual(py_builtins.len_([1, 2, 3]), 3)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
t = py_builtins.len_(constant_op.constant([[1], [2], [3]]))
self.assertEqual(t, 3)
ta = py_builtins.len_(tensor_array_ops.TensorArray(dtypes.int32, size=5))
@@ -82,7 +82,7 @@ class PyBuiltinsTest(test.TestCase):
py_builtins.len_(constant_op.constant(1))
def test_len_dynamic_shape(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
p = array_ops.placeholder(dtype=dtypes.int32, shape=None)
t = py_builtins.len_(p)
self.assertEqual(sess.run(t, {p: [1, 2, 3]}), 3)
@@ -95,7 +95,7 @@ class PyBuiltinsTest(test.TestCase):
try:
out_capturer = six.StringIO()
sys.stdout = out_capturer
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(py_builtins.print_(constant_op.constant('test message'), 1))
self.assertEqual(out_capturer.getvalue(), 'test message 1\n')
finally:
@@ -105,7 +105,7 @@ class PyBuiltinsTest(test.TestCase):
try:
out_capturer = six.StringIO()
sys.stdout = out_capturer
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
py_builtins.print_(constant_op.constant('test message'), [1, 2]))
self.assertEqual(out_capturer.getvalue(), 'test message [1, 2]\n')
@@ -118,7 +118,7 @@ class PyBuiltinsTest(test.TestCase):
self.assertListEqual(list(py_builtins.range_(2, 0, -1)), [2, 1])
def test_range_tensor(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
r = py_builtins.range_(constant_op.constant(3))
self.assertAllEqual(sess.run(r), [0, 1, 2])
r = py_builtins.range_(1, constant_op.constant(3))
@@ -126,6 +126,13 @@ class PyBuiltinsTest(test.TestCase):
r = py_builtins.range_(2, 0, constant_op.constant(-1))
self.assertAllEqual(sess.run(r), [2, 1])
+ def test_range_tensor_empty_range(self):
+ with self.test_session() as sess:
+ r = py_builtins.range_(constant_op.constant(-3))
+ self.assertAllEqual(sess.run(r), [])
+ r = py_builtins.range_(5, constant_op.constant(2))
+ self.assertAllEqual(sess.run(r), [])
+
if __name__ == '__main__':
test.main()