diff options
Diffstat (limited to 'tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py')
-rw-r--r-- | tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py | 54 |
1 files changed, 27 insertions, 27 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py index 67242fecfe..8e368bf2bc 100644 --- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py @@ -57,7 +57,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for start in range(0, len(components), 4): @@ -85,7 +85,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for start in range(0, len(components), 4): @@ -123,7 +123,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: # Initialize with an input tensor of incompatible rank. sess.run(init_op, feed_dict={input_tensor: [[1]]}) with self.assertRaisesRegexp(errors.InvalidArgumentError, @@ -148,7 +148,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): iterator = data.make_one_shot_iterator() op = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(10): self.assertEqual((i,) * 3, sess.run(op)) @@ -168,7 +168,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): iterator = data.make_one_shot_iterator() op = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(10): self.assertEqual((i, compat.as_bytes(str(i)), i), sess.run(op)) @@ -187,7 +187,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): iterator = data.make_one_shot_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(10): st_row = sess.run(next_element) self.assertEqual([i], st_row.indices) @@ -208,7 +208,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): iterator = data.make_one_shot_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(10): dense_elem, st_row = sess.run(next_element) self.assertEqual(i, dense_elem) @@ -230,7 +230,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): iterator = data.make_one_shot_iterator() op = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(10): self.assertEqual(((i,),) * 3, sess.run(op)) @@ -250,7 +250,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): iterator = data.make_one_shot_iterator() op = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(10): self.assertEqual(((i, b"hi"), (10 + i, b"hi"), (20 + i, b"hi")), sess.run(op)) @@ -266,7 +266,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): iterator = data.make_one_shot_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) @@ -284,7 +284,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): iterator = data.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: # Mismatch in the 0th dimension. sess.run( iterator.initializer, @@ -319,7 +319,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for test_batch_size in [1, 3, 7, 10]: sess.run(iterator.initializer, feed_dict={batch_size: test_batch_size}) num_batches = 7 // test_batch_size @@ -343,7 +343,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for i in range(2): actual = sess.run(get_next) @@ -374,7 +374,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for test_batch_size in [1, 3, 7, 10]: sess.run(iterator.initializer, feed_dict={batch_size: test_batch_size}) num_batches = 7 // test_batch_size @@ -461,7 +461,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): self.assertEqual([[None] + list(c.shape[1:]) for c in components], [t.shape.as_list() for t in get_next]) - with self.test_session() as sess: + with self.cached_session() as sess: # Batch of a finite input, where the batch_size divides the # total number of elements. sess.run(init_op, feed_dict={count: 28, batch_size: 14}) @@ -520,7 +520,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): else: self.assertEqual([None, 1], iterator.output_shapes.as_list()) next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element)) self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element)) if not drop_remainder: @@ -535,7 +535,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): .make_one_shot_iterator()) self.assertEqual([None, 1], iterator.output_shapes.as_list()) next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element)) self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element)) self.assertAllEqual([[64], [81]], sess.run(next_element)) @@ -549,7 +549,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): elements = [] for _ in range(100): elements.append(iterator.get_next()) - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(5): got = sess.run(elements) got.sort(key=lambda x: x[0]) @@ -569,7 +569,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): elements = [] for _ in range(100): elements.append(iterator.get_next()) - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(4): got = sess.run(elements) got.sort(key=lambda x: x[0]) @@ -591,7 +591,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for i in range(2): actual = sess.run(get_next) @@ -614,7 +614,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): dataset.apply(batching.map_and_batch(lambda x: x, batch_size)) .make_initializable_iterator()) init_op = iterator.initializer - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"): sess.run(init_op, feed_dict={batch_size: 14}) @@ -635,7 +635,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): .make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) with self.assertRaisesRegexp(errors.InvalidArgumentError, "number of elements does not match"): @@ -659,7 +659,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): iterator = dataset.make_one_shot_iterator() get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for _ in range(3): sess.run(get_next) @@ -686,7 +686,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): batch_size=10)).make_one_shot_iterator()) get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(threshold // 10): self.assertAllEqual([i * 10 + j for j in range(10)], sess.run(get_next)) if threshold % 10 != 0: @@ -718,7 +718,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): get_next = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for _ in range(10): self.assertAllEqual([element for _ in range(10)], sess.run(get_next)) @@ -784,7 +784,7 @@ class RestructuredDatasetTest(test.TestCase): iterator = result.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for _ in range(5): sess.run(get_next) @@ -908,7 +908,7 @@ class RestructuredDatasetTest(test.TestCase): .make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) with self.assertRaises(errors.InvalidArgumentError): sess.run(get_next) |