diff options
Diffstat (limited to 'tensorflow/python/training/input_test.py')
-rw-r--r-- | tensorflow/python/training/input_test.py | 94 |
1 files changed, 47 insertions, 47 deletions
diff --git a/tensorflow/python/training/input_test.py b/tensorflow/python/training/input_test.py index 1b1e89cb26..a9b05dcc73 100644 --- a/tensorflow/python/training/input_test.py +++ b/tensorflow/python/training/input_test.py @@ -51,7 +51,7 @@ class MatchFilenamesOnceTest(test_lib.TestCase): for name in additional: open(name, "w").write("Some contents") filenames = list(set(filenames + additional)) - with self.test_session(): + with self.cached_session(): star = inp.match_filenames_once(os.path.join(self.get_temp_dir(), "*")) question = inp.match_filenames_once( os.path.join(self.get_temp_dir(), "match_filenames.?")) @@ -66,7 +66,7 @@ class MatchFilenamesOnceTest(test_lib.TestCase): class LimitEpochsTest(test_lib.TestCase): def testNoLimit(self): - with self.test_session(): + with self.cached_session(): seven = constant_op.constant(7) seven_forever = inp.limit_epochs(seven) variables.local_variables_initializer().run() @@ -74,7 +74,7 @@ class LimitEpochsTest(test_lib.TestCase): self.assertEqual(7, seven_forever.eval()) def testLimit(self): - with self.test_session(): + with self.cached_session(): love_me = constant_op.constant("Love Me") love_me_two_times = inp.limit_epochs(love_me, num_epochs=2) variables.global_variables_initializer().run() @@ -88,7 +88,7 @@ class LimitEpochsTest(test_lib.TestCase): class InputProducerTest(test_lib.TestCase): def testNoShuffle(self): - with self.test_session(): + with self.cached_session(): input_tensor = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]] @@ -111,7 +111,7 @@ class InputProducerTest(test_lib.TestCase): thread.join() def testNoShapeInference(self): - with self.test_session(): + with self.cached_session(): # Disable shape inference for the input. input_value = [[1, 2, 3, 4], [5, 6, 7, 8], @@ -144,7 +144,7 @@ class InputProducerTest(test_lib.TestCase): class StringInputProducerTest(test_lib.TestCase): def testNoShuffle(self): - with self.test_session(): + with self.cached_session(): strings = [b"to", b"be", b"or", b"not", b"to", b"be"] num_epochs = 3 queue = inp.string_input_producer( @@ -166,7 +166,7 @@ class StringInputProducerTest(test_lib.TestCase): thread.join() def testShuffle(self): - with self.test_session(): + with self.cached_session(): strings = [b"a", b"b", b"c"] num_epochs = 600 queue = inp.string_input_producer( @@ -206,7 +206,7 @@ class StringInputProducerTest(test_lib.TestCase): def testNullStringPython(self): # Graph-construction time check for empty string list: - with self.test_session(): + with self.cached_session(): with self.assertRaises(ValueError): _ = inp.string_input_producer([]) @@ -214,7 +214,7 @@ class StringInputProducerTest(test_lib.TestCase): # Runtime check for empty string list. This is slightly oblique: # The queue runner should die with an assertion error on the null # input tensor, causing the dequeue to fail with an OutOfRangeError. - with self.test_session(): + with self.cached_session(): coord = coordinator.Coordinator() queue = inp.string_input_producer( constant_op.constant( @@ -230,7 +230,7 @@ class StringInputProducerTest(test_lib.TestCase): thread.join() def testSharedName(self): - with self.test_session(): + with self.cached_session(): strings = [b"to", b"be", b"or", b"not", b"to", b"be"] queue = inp.string_input_producer( strings, shared_name="SHARED_NAME_XYZ", name="Q") @@ -238,7 +238,7 @@ class StringInputProducerTest(test_lib.TestCase): queue.queue_ref.op.node_def.attr["shared_name"]) def testConstructionRace(self): - with self.test_session() as sess: + with self.cached_session() as sess: strings = [b"to", b"be", b"or", b"not", b"to", b"be"] queue = inp.string_input_producer(strings, shuffle=False) coord = coordinator.Coordinator() @@ -260,7 +260,7 @@ class StringInputProducerTest(test_lib.TestCase): class RangeInputProducerTest(test_lib.TestCase): def testNoShuffle(self): - with self.test_session(): + with self.cached_session(): num_epochs = 3 range_size = 5 queue = inp.range_input_producer( @@ -282,7 +282,7 @@ class RangeInputProducerTest(test_lib.TestCase): thread.join() def testShuffle(self): - with self.test_session(): + with self.cached_session(): num_epochs = 200 range_size = 2 queue = inp.range_input_producer( @@ -321,7 +321,7 @@ class RangeInputProducerTest(test_lib.TestCase): thread.join() def testSharedName(self): - with self.test_session(): + with self.cached_session(): range_size = 5 queue = inp.range_input_producer( range_size, shared_name="SHARED_NAME_XYZ", name="Q") @@ -332,7 +332,7 @@ class RangeInputProducerTest(test_lib.TestCase): class SliceInputProducerTest(test_lib.TestCase): def testNoShuffle(self): - with self.test_session() as sess: + with self.cached_session() as sess: num_epochs = 3 source_strings = [b"Alpha", b"Beta", b"Delta", b"Gamma"] source_ints = [2, 3, 5, 7] @@ -356,7 +356,7 @@ class SliceInputProducerTest(test_lib.TestCase): thread.join() def testShuffle(self): - with self.test_session() as sess: + with self.cached_session() as sess: num_epochs = 1200 source_strings = ["A", "B", "D", "G"] source_ints = [7, 3, 5, 2] @@ -400,7 +400,7 @@ class SliceInputProducerTest(test_lib.TestCase): thread.join() def testSharedName(self): - with self.test_session(): + with self.cached_session(): source_strings = ["A", "B", "D", "G"] source_ints = [7, 3, 5, 2] slices = inp.slice_input_producer( @@ -440,7 +440,7 @@ class DictHelperTest(test_lib.TestCase): class BatchTest(test_lib.TestCase): def _testOneThreadHelper(self, use_dict): - with self.test_session() as sess: + with self.cached_session() as sess: batch_size = 10 num_batches = 3 zero64 = constant_op.constant(0, dtype=dtypes.int64) @@ -500,7 +500,7 @@ class BatchTest(test_lib.TestCase): def testUint32DataTypes(self): values = constant_op.constant([0, 1, 2, 3, 4, 5], dtype=dtypes.uint32) batched = inp.batch([values], batch_size=2) - with self.test_session() as sess: + with self.cached_session() as sess: coord = coordinator.Coordinator() threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord) sess.run(batched) @@ -511,7 +511,7 @@ class BatchTest(test_lib.TestCase): def testUint64DataTypes(self): values = constant_op.constant([0, 1, 2, 3, 4, 5], dtype=dtypes.uint64) batched = inp.batch([values], batch_size=2) - with self.test_session() as sess: + with self.cached_session() as sess: coord = coordinator.Coordinator() threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord) sess.run(batched) @@ -520,7 +520,7 @@ class BatchTest(test_lib.TestCase): thread.join() def testOneThreadDynamicPad(self): - with self.test_session() as sess: + with self.cached_session() as sess: batch_size = 10 num_batches = 3 zero64 = constant_op.constant(0, dtype=dtypes.int64) @@ -550,7 +550,7 @@ class BatchTest(test_lib.TestCase): thread.join() def testOneThreadEnqueueMany(self): - with self.test_session() as sess: + with self.cached_session() as sess: batch_size = 10 num_batches = 3 zero64 = constant_op.constant(0, dtype=dtypes.int64) @@ -585,7 +585,7 @@ class BatchTest(test_lib.TestCase): thread.join() def testManyThreads(self): - with self.test_session() as sess: + with self.cached_session() as sess: batch_size = 10 num_batches = 3 zero64 = constant_op.constant(0, dtype=dtypes.int64) @@ -625,7 +625,7 @@ class BatchTest(test_lib.TestCase): thread.join() def testOneThreadSmallerBatch(self): - with self.test_session() as sess: + with self.cached_session() as sess: batch_size = 10 num_batches = 3 extra_elements = 5 @@ -682,7 +682,7 @@ class BatchTest(test_lib.TestCase): thread.join() def testManyThreadsSmallerBatch(self): - with self.test_session() as sess: + with self.cached_session() as sess: batch_size = 10 num_batches = 3 extra_elements = 5 @@ -737,7 +737,7 @@ class BatchTest(test_lib.TestCase): thread.join() def testSharedName(self): - with self.test_session(): + with self.cached_session(): batch_size = 10 num_batches = 3 zero64 = constant_op.constant(0, dtype=dtypes.int64) @@ -754,7 +754,7 @@ class BatchTest(test_lib.TestCase): batched[0].op.inputs[0].op.node_def.attr["shared_name"]) def testCannotInferRankError(self): - with self.test_session(): + with self.cached_session(): x = array_ops.placeholder(dtype=dtypes.int64) with self.assertRaisesRegexp(ValueError, "Cannot infer Tensor's rank"): inp.batch([x], batch_size=2) @@ -797,7 +797,7 @@ class BatchTest(test_lib.TestCase): def _testKeepInputHelper(self, num_threads, enqueue_many, keep_input_vector=False): - with self.test_session() as sess: + with self.cached_session() as sess: batch_size = 5 num_batches = 4 examples = variables.Variable(0) @@ -934,7 +934,7 @@ class BatchTest(test_lib.TestCase): batched = inp.maybe_batch( [sparse_t], keep_input=keep, batch_size=1, enqueue_many=True) - with self.test_session(): + with self.cached_session(): coord = coordinator.Coordinator() threads = queue_runner_impl.start_queue_runners(coord=coord) @@ -952,7 +952,7 @@ class BatchTest(test_lib.TestCase): class BatchJoinTest(test_lib.TestCase): def _testTwoThreadsHelper(self, use_dict): - with self.test_session() as sess: + with self.cached_session() as sess: # Two threads, the first generates (0..69, "a"). num_a = 70 zero64 = constant_op.constant(0, dtype=dtypes.int64) @@ -1069,7 +1069,7 @@ class BatchJoinTest(test_lib.TestCase): batch_size=8) def DISABLED_testTwoThreadsDynamicPad(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Two threads, the first generates (0..69, ["a"] * 1..70). num_a = 70 zero64 = constant_op.constant(0, dtype=dtypes.int64) @@ -1144,7 +1144,7 @@ class BatchJoinTest(test_lib.TestCase): thread.join() def DISABLED_testTwoThreadsSmallerBatch(self): - with self.test_session() as sess: + with self.cached_session() as sess: extra_elements = 2 # Two threads, the first generates (0..69, "a"). num_a = 70 + extra_elements @@ -1243,7 +1243,7 @@ class BatchJoinTest(test_lib.TestCase): thread.join() def DISABLED_testTwoThreadsDynamicPadSmallerBatch(self): - with self.test_session() as sess: + with self.cached_session() as sess: extra_elements = 2 # Two threads, the first generates (0..69, ["a"] * 1..70). num_a = 70 + extra_elements @@ -1338,7 +1338,7 @@ class BatchJoinTest(test_lib.TestCase): thread.join() def testSharedName(self): - with self.test_session(): + with self.cached_session(): batch_size = 10 num_batches = 3 zero64 = constant_op.constant(0, dtype=dtypes.int64) @@ -1360,7 +1360,7 @@ class BatchJoinTest(test_lib.TestCase): batched[0].op.inputs[0].op.node_def.attr["shared_name"]) def testCannotInferRankError(self): - with self.test_session(): + with self.cached_session(): x = array_ops.placeholder(dtype=dtypes.int64) with self.assertRaisesRegexp(ValueError, "Cannot infer Tensor's rank"): inp.batch_join([[x]], batch_size=2) @@ -1371,7 +1371,7 @@ class BatchJoinTest(test_lib.TestCase): def _testKeepInputHelper(self, num_threads, enqueue_many, keep_input_vector=False): - with self.test_session() as sess: + with self.cached_session() as sess: batch_size = 5 num_batches = 4 examples = variables.Variable(0) @@ -1511,7 +1511,7 @@ class BatchJoinTest(test_lib.TestCase): batched = inp.maybe_batch_join( [[sparse]], keep_input=keep, batch_size=1, enqueue_many=True) - with self.test_session(): + with self.cached_session(): coord = coordinator.Coordinator() threads = queue_runner_impl.start_queue_runners(coord=coord) @@ -1529,7 +1529,7 @@ class BatchJoinTest(test_lib.TestCase): class ShuffleBatchTest(test_lib.TestCase): def _testOneThreadHelper(self, use_dict): - with self.test_session() as sess: + with self.cached_session() as sess: batch_size = 10 num_batches = 3 zero64 = constant_op.constant(0, dtype=dtypes.int64) @@ -1594,7 +1594,7 @@ class ShuffleBatchTest(test_lib.TestCase): self._testOneThreadHelper(use_dict=True) def testOneThreadSmallerBatch(self): - with self.test_session() as sess: + with self.cached_session() as sess: batch_size = 10 num_batches = 3 extra_elements = 5 @@ -1650,7 +1650,7 @@ class ShuffleBatchTest(test_lib.TestCase): thread.join() def testManyThreads(self): - with self.test_session() as sess: + with self.cached_session() as sess: batch_size = 10 num_batches = 3 zero64 = constant_op.constant(0, dtype=dtypes.int64) @@ -1697,7 +1697,7 @@ class ShuffleBatchTest(test_lib.TestCase): thread.join() def testManyThreadsSmallerBatch(self): - with self.test_session() as sess: + with self.cached_session() as sess: batch_size = 10 num_batches = 3 extra_elements = 5 @@ -1755,7 +1755,7 @@ class ShuffleBatchTest(test_lib.TestCase): thread.join() def testSharedName(self): - with self.test_session(): + with self.cached_session(): batch_size = 10 num_batches = 3 zero64 = constant_op.constant(0, dtype=dtypes.int64) @@ -1775,7 +1775,7 @@ class ShuffleBatchTest(test_lib.TestCase): def _testKeepInputHelper(self, num_threads, enqueue_many, keep_input_vector=False): - with self.test_session() as sess: + with self.cached_session() as sess: batch_size = 5 num_batches = 4 examples = variables.Variable(0) @@ -1906,7 +1906,7 @@ class ShuffleBatchTest(test_lib.TestCase): class ShuffleBatchJoinTest(test_lib.TestCase): def _testTwoThreadsHelper(self, use_dict): - with self.test_session() as sess: + with self.cached_session() as sess: # Two threads, the first generates (0..24, "a"). num_a = 25 zero64 = constant_op.constant(0, dtype=dtypes.int64) @@ -2017,7 +2017,7 @@ class ShuffleBatchJoinTest(test_lib.TestCase): self._testTwoThreadsHelper(use_dict=True) def testTwoThreadsSmallerBatch(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Two threads, the first generates (0..26, "a"). extra_elements = 2 num_a = 25 + extra_elements @@ -2137,7 +2137,7 @@ class ShuffleBatchJoinTest(test_lib.TestCase): seed=223607) def testSharedName(self): - with self.test_session(): + with self.cached_session(): batch_size = 10 num_batches = 3 zero64 = constant_op.constant(0, dtype=dtypes.int64) @@ -2162,7 +2162,7 @@ class ShuffleBatchJoinTest(test_lib.TestCase): def _testKeepInputHelper(self, num_threads, enqueue_many, keep_input_vector=False): - with self.test_session() as sess: + with self.cached_session() as sess: batch_size = 5 num_batches = 4 examples = variables.Variable(0) |