diff options
Diffstat (limited to 'tensorflow/python/data/kernel_tests/batch_dataset_op_test.py')
-rw-r--r-- | tensorflow/python/data/kernel_tests/batch_dataset_op_test.py | 22 |
1 files changed, 11 insertions, 11 deletions
diff --git a/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py b/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py index 89de55dd4f..c48708a2b9 100644 --- a/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py @@ -82,7 +82,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): self.assertEqual([[dim0] + list(c.shape[1:]) for c in components], [t.shape.as_list() for t in get_next]) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -111,7 +111,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): iterator = (dataset_ops.Dataset.range(10).batch(0).make_one_shot_iterator()) get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaises(errors.InvalidArgumentError): sess.run(get_next) @@ -131,7 +131,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for i in range(2): actual = sess.run(get_next) @@ -158,7 +158,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for i in range(2): actual = sess.run(get_next) @@ -188,7 +188,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) actual = sess.run(get_next) expected = sparse_tensor.SparseTensorValue( @@ -214,7 +214,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): .make_initializable_iterator()) next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) with self.assertRaisesRegexp( errors.InvalidArgumentError, @@ -262,7 +262,7 @@ class PaddedBatchDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -307,7 +307,7 @@ class PaddedBatchDatasetTest(test.TestCase, parameterized.TestCase): batch_size=4, padded_shapes=[5]).make_one_shot_iterator()) get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaises(errors.DataLossError): sess.run(get_next) @@ -318,7 +318,7 @@ class PaddedBatchDatasetTest(test.TestCase, parameterized.TestCase): batch_size=4, padded_shapes=[-1]).make_one_shot_iterator()) get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: result = sess.run(get_next) self.assertAllEqual([[], [], [], []], result) with self.assertRaises(errors.OutOfRangeError): @@ -342,7 +342,7 @@ class PaddedBatchDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: # Test with random sequence lengths, and max padding. random_seq_lens = np.random.randint(20, size=(32,)).astype(np.int32) sess.run( @@ -381,7 +381,7 @@ class PaddedBatchDatasetTest(test.TestCase, parameterized.TestCase): (tensor_shape.TensorShape([None]), tensor_shape.TensorShape([None]))) padded_dataset = dataset.padded_batch( 2, padded_shapes=([None], [None]), padding_values=('', 0)) - with self.test_session() as sess: + with self.cached_session() as sess: next_element = padded_dataset.make_one_shot_iterator().get_next() sess.run(next_element) |