diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/parsing_ops_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/parsing_ops_test.py | 18 |
1 files changed, 9 insertions, 9 deletions
diff --git a/tensorflow/python/kernel_tests/parsing_ops_test.py b/tensorflow/python/kernel_tests/parsing_ops_test.py index 7dff4501cc..71d8b60d3c 100644 --- a/tensorflow/python/kernel_tests/parsing_ops_test.py +++ b/tensorflow/python/kernel_tests/parsing_ops_test.py @@ -89,7 +89,7 @@ def _compare_output_to_expected(tester, dict_tensors, expected_tensors, class ParseExampleTest(test.TestCase): def _test(self, kwargs, expected_values=None, expected_err=None): - with self.test_session() as sess: + with self.cached_session() as sess: if expected_err: with self.assertRaisesWithPredicateMatch(expected_err[0], expected_err[1]): @@ -937,7 +937,7 @@ class ParseExampleTest(test.TestCase): class ParseSingleExampleTest(test.TestCase): def _test(self, kwargs, expected_values=None, expected_err=None): - with self.test_session() as sess: + with self.cached_session() as sess: if expected_err: with self.assertRaisesWithPredicateMatch(expected_err[0], expected_err[1]): @@ -1054,7 +1054,7 @@ class ParseSequenceExampleTest(test.TestCase): expected_feat_list_values = expected_feat_list_values or {} expected_length_values = expected_length_values or {} - with self.test_session() as sess: + with self.cached_session() as sess: if expected_err: with self.assertRaisesWithPredicateMatch(expected_err[0], expected_err[1]): @@ -1606,7 +1606,7 @@ class ParseSequenceExampleTest(test.TestCase): class DecodeJSONExampleTest(test.TestCase): def _testRoundTrip(self, examples): - with self.test_session() as sess: + with self.cached_session() as sess: examples = np.array(examples, dtype=np.object) json_tensor = constant_op.constant( @@ -1696,7 +1696,7 @@ class DecodeJSONExampleTest(test.TestCase): ]) def testInvalidSyntax(self): - with self.test_session() as sess: + with self.cached_session() as sess: json_tensor = constant_op.constant(["{]"]) binary_tensor = parsing_ops.decode_json_example(json_tensor) with self.assertRaisesOpError("Error while parsing JSON"): @@ -1706,7 +1706,7 @@ class DecodeJSONExampleTest(test.TestCase): class ParseTensorOpTest(test.TestCase): def testToFloat32(self): - with self.test_session(): + with self.cached_session(): expected = np.random.rand(3, 4, 5).astype(np.float32) tensor_proto = tensor_util.make_tensor_proto(expected) @@ -1719,7 +1719,7 @@ class ParseTensorOpTest(test.TestCase): self.assertAllEqual(expected, result) def testToUint8(self): - with self.test_session(): + with self.cached_session(): expected = np.random.rand(3, 4, 5).astype(np.uint8) tensor_proto = tensor_util.make_tensor_proto(expected) @@ -1732,7 +1732,7 @@ class ParseTensorOpTest(test.TestCase): self.assertAllEqual(expected, result) def testTypeMismatch(self): - with self.test_session(): + with self.cached_session(): expected = np.random.rand(3, 4, 5).astype(np.uint8) tensor_proto = tensor_util.make_tensor_proto(expected) @@ -1745,7 +1745,7 @@ class ParseTensorOpTest(test.TestCase): tensor.eval(feed_dict={serialized: tensor_proto.SerializeToString()}) def testInvalidInput(self): - with self.test_session(): + with self.cached_session(): serialized = array_ops.placeholder(dtypes.string) tensor = parsing_ops.parse_tensor(serialized, dtypes.uint16) |