diff options
author | 2018-09-10 14:37:06 -0700 | |
---|---|---|
committer | 2018-09-10 15:04:14 -0700 | |
commit | b828f89263e054bfa7c7a808cab1506834ab906d (patch) | |
tree | e31816a6850d177306f19ee8670e0836060fcfc9 /tensorflow/python/data | |
parent | acf0ee82092727afc2067316982407cf5e496f75 (diff) |
Move from deprecated self.test_session() to self.cached_session().
self.test_session() has been deprecated in 9962eb5e84b15e309410071b06c2ed2d6148ed44 as its name confuses readers of the test. Moving to cached_session() instead which is more explicit about:
* the fact that the session may be reused.
* the session is not closed even when doing a "with self.test_session()" statement.
PiperOrigin-RevId: 212336464
Diffstat (limited to 'tensorflow/python/data')
-rw-r--r-- | tensorflow/python/data/util/convert_test.py | 16 | ||||
-rw-r--r-- | tensorflow/python/data/util/sparse_test.py | 2 |
2 files changed, 9 insertions, 9 deletions
diff --git a/tensorflow/python/data/util/convert_test.py b/tensorflow/python/data/util/convert_test.py index 6a67093e48..89c3afb296 100644 --- a/tensorflow/python/data/util/convert_test.py +++ b/tensorflow/python/data/util/convert_test.py @@ -30,28 +30,28 @@ class ConvertTest(test.TestCase): def testInteger(self): resp = convert.optional_param_to_tensor("foo", 3) - with self.test_session() as sess: + with self.cached_session() as sess: self.assertEqual(3, sess.run(resp)) def testIntegerDefault(self): resp = convert.optional_param_to_tensor("foo", None) - with self.test_session() as sess: + with self.cached_session() as sess: self.assertEqual(0, sess.run(resp)) def testStringDefault(self): resp = convert.optional_param_to_tensor("bar", None, "default", dtypes.string) - with self.test_session() as sess: + with self.cached_session() as sess: self.assertEqual(compat.as_bytes("default"), sess.run(resp)) def testString(self): resp = convert.optional_param_to_tensor("bar", "value", "default", dtypes.string) - with self.test_session() as sess: + with self.cached_session() as sess: self.assertEqual(compat.as_bytes("value"), sess.run(resp)) def testPartialShapeToTensorKnownDimension(self): - with self.test_session() as sess: + with self.cached_session() as sess: self.assertAllEqual([1], sess.run(convert.partial_shape_to_tensor( tensor_shape.TensorShape([1])))) self.assertAllEqual([1], sess.run(convert.partial_shape_to_tensor((1,)))) @@ -60,7 +60,7 @@ class ConvertTest(test.TestCase): constant_op.constant([1], dtype=dtypes.int64)))) def testPartialShapeToTensorUnknownDimension(self): - with self.test_session() as sess: + with self.cached_session() as sess: self.assertAllEqual([-1], sess.run(convert.partial_shape_to_tensor( tensor_shape.TensorShape([None])))) self.assertAllEqual([-1], sess.run(convert.partial_shape_to_tensor( @@ -84,7 +84,7 @@ class ConvertTest(test.TestCase): convert.partial_shape_to_tensor(constant_op.constant([1., 1.])) def testPartialShapeToTensorMultipleDimensions(self): - with self.test_session() as sess: + with self.cached_session() as sess: self.assertAllEqual([3, 6], sess.run(convert.partial_shape_to_tensor( tensor_shape.TensorShape([3, 6])))) self.assertAllEqual([3, 6], sess.run(convert.partial_shape_to_tensor( @@ -113,7 +113,7 @@ class ConvertTest(test.TestCase): constant_op.constant([-1, -1], dtype=dtypes.int64)))) def testPartialShapeToTensorScalar(self): - with self.test_session() as sess: + with self.cached_session() as sess: self.assertAllEqual([], sess.run(convert.partial_shape_to_tensor( tensor_shape.TensorShape([])))) self.assertAllEqual([], sess.run(convert.partial_shape_to_tensor(()))) diff --git a/tensorflow/python/data/util/sparse_test.py b/tensorflow/python/data/util/sparse_test.py index d49b3ff34b..056b32480f 100644 --- a/tensorflow/python/data/util/sparse_test.py +++ b/tensorflow/python/data/util/sparse_test.py @@ -291,7 +291,7 @@ class SparseTest(test.TestCase): self.assertEqual(a, b) return self.assertTrue(isinstance(b, sparse_tensor.SparseTensor)) - with self.test_session(): + with self.cached_session(): self.assertAllEqual(a.eval().indices, b.eval().indices) self.assertAllEqual(a.eval().values, b.eval().values) self.assertAllEqual(a.eval().dense_shape, b.eval().dense_shape) |