aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/data/util/convert_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/data/util/convert_test.py')
-rw-r--r--tensorflow/python/data/util/convert_test.py16
1 files changed, 8 insertions, 8 deletions
diff --git a/tensorflow/python/data/util/convert_test.py b/tensorflow/python/data/util/convert_test.py
index 6a67093e48..89c3afb296 100644
--- a/tensorflow/python/data/util/convert_test.py
+++ b/tensorflow/python/data/util/convert_test.py
@@ -30,28 +30,28 @@ class ConvertTest(test.TestCase):
def testInteger(self):
resp = convert.optional_param_to_tensor("foo", 3)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(3, sess.run(resp))
def testIntegerDefault(self):
resp = convert.optional_param_to_tensor("foo", None)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(0, sess.run(resp))
def testStringDefault(self):
resp = convert.optional_param_to_tensor("bar", None, "default",
dtypes.string)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(compat.as_bytes("default"), sess.run(resp))
def testString(self):
resp = convert.optional_param_to_tensor("bar", "value", "default",
dtypes.string)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(compat.as_bytes("value"), sess.run(resp))
def testPartialShapeToTensorKnownDimension(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual([1], sess.run(convert.partial_shape_to_tensor(
tensor_shape.TensorShape([1]))))
self.assertAllEqual([1], sess.run(convert.partial_shape_to_tensor((1,))))
@@ -60,7 +60,7 @@ class ConvertTest(test.TestCase):
constant_op.constant([1], dtype=dtypes.int64))))
def testPartialShapeToTensorUnknownDimension(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual([-1], sess.run(convert.partial_shape_to_tensor(
tensor_shape.TensorShape([None]))))
self.assertAllEqual([-1], sess.run(convert.partial_shape_to_tensor(
@@ -84,7 +84,7 @@ class ConvertTest(test.TestCase):
convert.partial_shape_to_tensor(constant_op.constant([1., 1.]))
def testPartialShapeToTensorMultipleDimensions(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual([3, 6], sess.run(convert.partial_shape_to_tensor(
tensor_shape.TensorShape([3, 6]))))
self.assertAllEqual([3, 6], sess.run(convert.partial_shape_to_tensor(
@@ -113,7 +113,7 @@ class ConvertTest(test.TestCase):
constant_op.constant([-1, -1], dtype=dtypes.int64))))
def testPartialShapeToTensorScalar(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual([], sess.run(convert.partial_shape_to_tensor(
tensor_shape.TensorShape([]))))
self.assertAllEqual([], sess.run(convert.partial_shape_to_tensor(())))