diff options
Diffstat (limited to 'tensorflow/python/data/util/convert_test.py')
-rw-r--r-- | tensorflow/python/data/util/convert_test.py | 16 |
1 files changed, 8 insertions, 8 deletions
diff --git a/tensorflow/python/data/util/convert_test.py b/tensorflow/python/data/util/convert_test.py index 6a67093e48..89c3afb296 100644 --- a/tensorflow/python/data/util/convert_test.py +++ b/tensorflow/python/data/util/convert_test.py @@ -30,28 +30,28 @@ class ConvertTest(test.TestCase): def testInteger(self): resp = convert.optional_param_to_tensor("foo", 3) - with self.test_session() as sess: + with self.cached_session() as sess: self.assertEqual(3, sess.run(resp)) def testIntegerDefault(self): resp = convert.optional_param_to_tensor("foo", None) - with self.test_session() as sess: + with self.cached_session() as sess: self.assertEqual(0, sess.run(resp)) def testStringDefault(self): resp = convert.optional_param_to_tensor("bar", None, "default", dtypes.string) - with self.test_session() as sess: + with self.cached_session() as sess: self.assertEqual(compat.as_bytes("default"), sess.run(resp)) def testString(self): resp = convert.optional_param_to_tensor("bar", "value", "default", dtypes.string) - with self.test_session() as sess: + with self.cached_session() as sess: self.assertEqual(compat.as_bytes("value"), sess.run(resp)) def testPartialShapeToTensorKnownDimension(self): - with self.test_session() as sess: + with self.cached_session() as sess: self.assertAllEqual([1], sess.run(convert.partial_shape_to_tensor( tensor_shape.TensorShape([1])))) self.assertAllEqual([1], sess.run(convert.partial_shape_to_tensor((1,)))) @@ -60,7 +60,7 @@ class ConvertTest(test.TestCase): constant_op.constant([1], dtype=dtypes.int64)))) def testPartialShapeToTensorUnknownDimension(self): - with self.test_session() as sess: + with self.cached_session() as sess: self.assertAllEqual([-1], sess.run(convert.partial_shape_to_tensor( tensor_shape.TensorShape([None])))) self.assertAllEqual([-1], sess.run(convert.partial_shape_to_tensor( @@ -84,7 +84,7 @@ class ConvertTest(test.TestCase): convert.partial_shape_to_tensor(constant_op.constant([1., 1.])) def testPartialShapeToTensorMultipleDimensions(self): - with self.test_session() as sess: + with self.cached_session() as sess: self.assertAllEqual([3, 6], sess.run(convert.partial_shape_to_tensor( tensor_shape.TensorShape([3, 6])))) self.assertAllEqual([3, 6], sess.run(convert.partial_shape_to_tensor( @@ -113,7 +113,7 @@ class ConvertTest(test.TestCase): constant_op.constant([-1, -1], dtype=dtypes.int64)))) def testPartialShapeToTensorScalar(self): - with self.test_session() as sess: + with self.cached_session() as sess: self.assertAllEqual([], sess.run(convert.partial_shape_to_tensor( tensor_shape.TensorShape([])))) self.assertAllEqual([], sess.run(convert.partial_shape_to_tensor(()))) |