aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/input_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/training/input_test.py')
-rw-r--r--tensorflow/python/training/input_test.py94
1 files changed, 47 insertions, 47 deletions
diff --git a/tensorflow/python/training/input_test.py b/tensorflow/python/training/input_test.py
index 1b1e89cb26..a9b05dcc73 100644
--- a/tensorflow/python/training/input_test.py
+++ b/tensorflow/python/training/input_test.py
@@ -51,7 +51,7 @@ class MatchFilenamesOnceTest(test_lib.TestCase):
for name in additional:
open(name, "w").write("Some contents")
filenames = list(set(filenames + additional))
- with self.test_session():
+ with self.cached_session():
star = inp.match_filenames_once(os.path.join(self.get_temp_dir(), "*"))
question = inp.match_filenames_once(
os.path.join(self.get_temp_dir(), "match_filenames.?"))
@@ -66,7 +66,7 @@ class MatchFilenamesOnceTest(test_lib.TestCase):
class LimitEpochsTest(test_lib.TestCase):
def testNoLimit(self):
- with self.test_session():
+ with self.cached_session():
seven = constant_op.constant(7)
seven_forever = inp.limit_epochs(seven)
variables.local_variables_initializer().run()
@@ -74,7 +74,7 @@ class LimitEpochsTest(test_lib.TestCase):
self.assertEqual(7, seven_forever.eval())
def testLimit(self):
- with self.test_session():
+ with self.cached_session():
love_me = constant_op.constant("Love Me")
love_me_two_times = inp.limit_epochs(love_me, num_epochs=2)
variables.global_variables_initializer().run()
@@ -88,7 +88,7 @@ class LimitEpochsTest(test_lib.TestCase):
class InputProducerTest(test_lib.TestCase):
def testNoShuffle(self):
- with self.test_session():
+ with self.cached_session():
input_tensor = [[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]]
@@ -111,7 +111,7 @@ class InputProducerTest(test_lib.TestCase):
thread.join()
def testNoShapeInference(self):
- with self.test_session():
+ with self.cached_session():
# Disable shape inference for the input.
input_value = [[1, 2, 3, 4],
[5, 6, 7, 8],
@@ -144,7 +144,7 @@ class InputProducerTest(test_lib.TestCase):
class StringInputProducerTest(test_lib.TestCase):
def testNoShuffle(self):
- with self.test_session():
+ with self.cached_session():
strings = [b"to", b"be", b"or", b"not", b"to", b"be"]
num_epochs = 3
queue = inp.string_input_producer(
@@ -166,7 +166,7 @@ class StringInputProducerTest(test_lib.TestCase):
thread.join()
def testShuffle(self):
- with self.test_session():
+ with self.cached_session():
strings = [b"a", b"b", b"c"]
num_epochs = 600
queue = inp.string_input_producer(
@@ -206,7 +206,7 @@ class StringInputProducerTest(test_lib.TestCase):
def testNullStringPython(self):
# Graph-construction time check for empty string list:
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
_ = inp.string_input_producer([])
@@ -214,7 +214,7 @@ class StringInputProducerTest(test_lib.TestCase):
# Runtime check for empty string list. This is slightly oblique:
# The queue runner should die with an assertion error on the null
# input tensor, causing the dequeue to fail with an OutOfRangeError.
- with self.test_session():
+ with self.cached_session():
coord = coordinator.Coordinator()
queue = inp.string_input_producer(
constant_op.constant(
@@ -230,7 +230,7 @@ class StringInputProducerTest(test_lib.TestCase):
thread.join()
def testSharedName(self):
- with self.test_session():
+ with self.cached_session():
strings = [b"to", b"be", b"or", b"not", b"to", b"be"]
queue = inp.string_input_producer(
strings, shared_name="SHARED_NAME_XYZ", name="Q")
@@ -238,7 +238,7 @@ class StringInputProducerTest(test_lib.TestCase):
queue.queue_ref.op.node_def.attr["shared_name"])
def testConstructionRace(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
strings = [b"to", b"be", b"or", b"not", b"to", b"be"]
queue = inp.string_input_producer(strings, shuffle=False)
coord = coordinator.Coordinator()
@@ -260,7 +260,7 @@ class StringInputProducerTest(test_lib.TestCase):
class RangeInputProducerTest(test_lib.TestCase):
def testNoShuffle(self):
- with self.test_session():
+ with self.cached_session():
num_epochs = 3
range_size = 5
queue = inp.range_input_producer(
@@ -282,7 +282,7 @@ class RangeInputProducerTest(test_lib.TestCase):
thread.join()
def testShuffle(self):
- with self.test_session():
+ with self.cached_session():
num_epochs = 200
range_size = 2
queue = inp.range_input_producer(
@@ -321,7 +321,7 @@ class RangeInputProducerTest(test_lib.TestCase):
thread.join()
def testSharedName(self):
- with self.test_session():
+ with self.cached_session():
range_size = 5
queue = inp.range_input_producer(
range_size, shared_name="SHARED_NAME_XYZ", name="Q")
@@ -332,7 +332,7 @@ class RangeInputProducerTest(test_lib.TestCase):
class SliceInputProducerTest(test_lib.TestCase):
def testNoShuffle(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_epochs = 3
source_strings = [b"Alpha", b"Beta", b"Delta", b"Gamma"]
source_ints = [2, 3, 5, 7]
@@ -356,7 +356,7 @@ class SliceInputProducerTest(test_lib.TestCase):
thread.join()
def testShuffle(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_epochs = 1200
source_strings = ["A", "B", "D", "G"]
source_ints = [7, 3, 5, 2]
@@ -400,7 +400,7 @@ class SliceInputProducerTest(test_lib.TestCase):
thread.join()
def testSharedName(self):
- with self.test_session():
+ with self.cached_session():
source_strings = ["A", "B", "D", "G"]
source_ints = [7, 3, 5, 2]
slices = inp.slice_input_producer(
@@ -440,7 +440,7 @@ class DictHelperTest(test_lib.TestCase):
class BatchTest(test_lib.TestCase):
def _testOneThreadHelper(self, use_dict):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = 10
num_batches = 3
zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -500,7 +500,7 @@ class BatchTest(test_lib.TestCase):
def testUint32DataTypes(self):
values = constant_op.constant([0, 1, 2, 3, 4, 5], dtype=dtypes.uint32)
batched = inp.batch([values], batch_size=2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord)
sess.run(batched)
@@ -511,7 +511,7 @@ class BatchTest(test_lib.TestCase):
def testUint64DataTypes(self):
values = constant_op.constant([0, 1, 2, 3, 4, 5], dtype=dtypes.uint64)
batched = inp.batch([values], batch_size=2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord)
sess.run(batched)
@@ -520,7 +520,7 @@ class BatchTest(test_lib.TestCase):
thread.join()
def testOneThreadDynamicPad(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = 10
num_batches = 3
zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -550,7 +550,7 @@ class BatchTest(test_lib.TestCase):
thread.join()
def testOneThreadEnqueueMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = 10
num_batches = 3
zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -585,7 +585,7 @@ class BatchTest(test_lib.TestCase):
thread.join()
def testManyThreads(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = 10
num_batches = 3
zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -625,7 +625,7 @@ class BatchTest(test_lib.TestCase):
thread.join()
def testOneThreadSmallerBatch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = 10
num_batches = 3
extra_elements = 5
@@ -682,7 +682,7 @@ class BatchTest(test_lib.TestCase):
thread.join()
def testManyThreadsSmallerBatch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = 10
num_batches = 3
extra_elements = 5
@@ -737,7 +737,7 @@ class BatchTest(test_lib.TestCase):
thread.join()
def testSharedName(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 10
num_batches = 3
zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -754,7 +754,7 @@ class BatchTest(test_lib.TestCase):
batched[0].op.inputs[0].op.node_def.attr["shared_name"])
def testCannotInferRankError(self):
- with self.test_session():
+ with self.cached_session():
x = array_ops.placeholder(dtype=dtypes.int64)
with self.assertRaisesRegexp(ValueError, "Cannot infer Tensor's rank"):
inp.batch([x], batch_size=2)
@@ -797,7 +797,7 @@ class BatchTest(test_lib.TestCase):
def _testKeepInputHelper(self, num_threads, enqueue_many,
keep_input_vector=False):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = 5
num_batches = 4
examples = variables.Variable(0)
@@ -934,7 +934,7 @@ class BatchTest(test_lib.TestCase):
batched = inp.maybe_batch(
[sparse_t], keep_input=keep, batch_size=1, enqueue_many=True)
- with self.test_session():
+ with self.cached_session():
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(coord=coord)
@@ -952,7 +952,7 @@ class BatchTest(test_lib.TestCase):
class BatchJoinTest(test_lib.TestCase):
def _testTwoThreadsHelper(self, use_dict):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Two threads, the first generates (0..69, "a").
num_a = 70
zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -1069,7 +1069,7 @@ class BatchJoinTest(test_lib.TestCase):
batch_size=8)
def DISABLED_testTwoThreadsDynamicPad(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Two threads, the first generates (0..69, ["a"] * 1..70).
num_a = 70
zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -1144,7 +1144,7 @@ class BatchJoinTest(test_lib.TestCase):
thread.join()
def DISABLED_testTwoThreadsSmallerBatch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
extra_elements = 2
# Two threads, the first generates (0..69, "a").
num_a = 70 + extra_elements
@@ -1243,7 +1243,7 @@ class BatchJoinTest(test_lib.TestCase):
thread.join()
def DISABLED_testTwoThreadsDynamicPadSmallerBatch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
extra_elements = 2
# Two threads, the first generates (0..69, ["a"] * 1..70).
num_a = 70 + extra_elements
@@ -1338,7 +1338,7 @@ class BatchJoinTest(test_lib.TestCase):
thread.join()
def testSharedName(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 10
num_batches = 3
zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -1360,7 +1360,7 @@ class BatchJoinTest(test_lib.TestCase):
batched[0].op.inputs[0].op.node_def.attr["shared_name"])
def testCannotInferRankError(self):
- with self.test_session():
+ with self.cached_session():
x = array_ops.placeholder(dtype=dtypes.int64)
with self.assertRaisesRegexp(ValueError, "Cannot infer Tensor's rank"):
inp.batch_join([[x]], batch_size=2)
@@ -1371,7 +1371,7 @@ class BatchJoinTest(test_lib.TestCase):
def _testKeepInputHelper(self, num_threads, enqueue_many,
keep_input_vector=False):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = 5
num_batches = 4
examples = variables.Variable(0)
@@ -1511,7 +1511,7 @@ class BatchJoinTest(test_lib.TestCase):
batched = inp.maybe_batch_join(
[[sparse]], keep_input=keep, batch_size=1, enqueue_many=True)
- with self.test_session():
+ with self.cached_session():
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(coord=coord)
@@ -1529,7 +1529,7 @@ class BatchJoinTest(test_lib.TestCase):
class ShuffleBatchTest(test_lib.TestCase):
def _testOneThreadHelper(self, use_dict):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = 10
num_batches = 3
zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -1594,7 +1594,7 @@ class ShuffleBatchTest(test_lib.TestCase):
self._testOneThreadHelper(use_dict=True)
def testOneThreadSmallerBatch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = 10
num_batches = 3
extra_elements = 5
@@ -1650,7 +1650,7 @@ class ShuffleBatchTest(test_lib.TestCase):
thread.join()
def testManyThreads(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = 10
num_batches = 3
zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -1697,7 +1697,7 @@ class ShuffleBatchTest(test_lib.TestCase):
thread.join()
def testManyThreadsSmallerBatch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = 10
num_batches = 3
extra_elements = 5
@@ -1755,7 +1755,7 @@ class ShuffleBatchTest(test_lib.TestCase):
thread.join()
def testSharedName(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 10
num_batches = 3
zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -1775,7 +1775,7 @@ class ShuffleBatchTest(test_lib.TestCase):
def _testKeepInputHelper(self, num_threads, enqueue_many,
keep_input_vector=False):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = 5
num_batches = 4
examples = variables.Variable(0)
@@ -1906,7 +1906,7 @@ class ShuffleBatchTest(test_lib.TestCase):
class ShuffleBatchJoinTest(test_lib.TestCase):
def _testTwoThreadsHelper(self, use_dict):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Two threads, the first generates (0..24, "a").
num_a = 25
zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -2017,7 +2017,7 @@ class ShuffleBatchJoinTest(test_lib.TestCase):
self._testTwoThreadsHelper(use_dict=True)
def testTwoThreadsSmallerBatch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Two threads, the first generates (0..26, "a").
extra_elements = 2
num_a = 25 + extra_elements
@@ -2137,7 +2137,7 @@ class ShuffleBatchJoinTest(test_lib.TestCase):
seed=223607)
def testSharedName(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 10
num_batches = 3
zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -2162,7 +2162,7 @@ class ShuffleBatchJoinTest(test_lib.TestCase):
def _testKeepInputHelper(self, num_threads, enqueue_many,
keep_input_vector=False):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = 5
num_batches = 4
examples = variables.Variable(0)