aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/parsing_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/parsing_ops_test.py')
-rw-r--r--tensorflow/python/kernel_tests/parsing_ops_test.py18
1 files changed, 9 insertions, 9 deletions
diff --git a/tensorflow/python/kernel_tests/parsing_ops_test.py b/tensorflow/python/kernel_tests/parsing_ops_test.py
index 7dff4501cc..71d8b60d3c 100644
--- a/tensorflow/python/kernel_tests/parsing_ops_test.py
+++ b/tensorflow/python/kernel_tests/parsing_ops_test.py
@@ -89,7 +89,7 @@ def _compare_output_to_expected(tester, dict_tensors, expected_tensors,
class ParseExampleTest(test.TestCase):
def _test(self, kwargs, expected_values=None, expected_err=None):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
if expected_err:
with self.assertRaisesWithPredicateMatch(expected_err[0],
expected_err[1]):
@@ -937,7 +937,7 @@ class ParseExampleTest(test.TestCase):
class ParseSingleExampleTest(test.TestCase):
def _test(self, kwargs, expected_values=None, expected_err=None):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
if expected_err:
with self.assertRaisesWithPredicateMatch(expected_err[0],
expected_err[1]):
@@ -1054,7 +1054,7 @@ class ParseSequenceExampleTest(test.TestCase):
expected_feat_list_values = expected_feat_list_values or {}
expected_length_values = expected_length_values or {}
- with self.test_session() as sess:
+ with self.cached_session() as sess:
if expected_err:
with self.assertRaisesWithPredicateMatch(expected_err[0],
expected_err[1]):
@@ -1606,7 +1606,7 @@ class ParseSequenceExampleTest(test.TestCase):
class DecodeJSONExampleTest(test.TestCase):
def _testRoundTrip(self, examples):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
examples = np.array(examples, dtype=np.object)
json_tensor = constant_op.constant(
@@ -1696,7 +1696,7 @@ class DecodeJSONExampleTest(test.TestCase):
])
def testInvalidSyntax(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
json_tensor = constant_op.constant(["{]"])
binary_tensor = parsing_ops.decode_json_example(json_tensor)
with self.assertRaisesOpError("Error while parsing JSON"):
@@ -1706,7 +1706,7 @@ class DecodeJSONExampleTest(test.TestCase):
class ParseTensorOpTest(test.TestCase):
def testToFloat32(self):
- with self.test_session():
+ with self.cached_session():
expected = np.random.rand(3, 4, 5).astype(np.float32)
tensor_proto = tensor_util.make_tensor_proto(expected)
@@ -1719,7 +1719,7 @@ class ParseTensorOpTest(test.TestCase):
self.assertAllEqual(expected, result)
def testToUint8(self):
- with self.test_session():
+ with self.cached_session():
expected = np.random.rand(3, 4, 5).astype(np.uint8)
tensor_proto = tensor_util.make_tensor_proto(expected)
@@ -1732,7 +1732,7 @@ class ParseTensorOpTest(test.TestCase):
self.assertAllEqual(expected, result)
def testTypeMismatch(self):
- with self.test_session():
+ with self.cached_session():
expected = np.random.rand(3, 4, 5).astype(np.uint8)
tensor_proto = tensor_util.make_tensor_proto(expected)
@@ -1745,7 +1745,7 @@ class ParseTensorOpTest(test.TestCase):
tensor.eval(feed_dict={serialized: tensor_proto.SerializeToString()})
def testInvalidInput(self):
- with self.test_session():
+ with self.cached_session():
serialized = array_ops.placeholder(dtypes.string)
tensor = parsing_ops.parse_tensor(serialized, dtypes.uint16)