aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/batching
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-21 00:02:49 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-21 00:07:39 -0700
commit2952f5134905af795ba90ae1eb97e39091ba9843 (patch)
treef73bc5cd0342d9449114bd933863c2aa55610aa2 /tensorflow/contrib/batching
parentcf047f7755f3400ee128db2571042091fe9f8314 (diff)
Move from deprecated self.test_session() to self.cached_session().
self.test_session() has been deprecated in 9962eb5e84b15e309410071b06c2ed2d6148ed44 as its name confuses readers of the test. Moving to cached_session() instead which is more explicit about: * the fact that the session may be reused. * the session is not closed even when doing a "with self.test_session()" statement. PiperOrigin-RevId: 213944355
Diffstat (limited to 'tensorflow/contrib/batching')
-rw-r--r--tensorflow/contrib/batching/python/ops/batch_ops_test.py29
1 files changed, 15 insertions, 14 deletions
diff --git a/tensorflow/contrib/batching/python/ops/batch_ops_test.py b/tensorflow/contrib/batching/python/ops/batch_ops_test.py
index 7846814546..01ee8703a9 100644
--- a/tensorflow/contrib/batching/python/ops/batch_ops_test.py
+++ b/tensorflow/contrib/batching/python/ops/batch_ops_test.py
@@ -43,7 +43,7 @@ class BatchOpsTest(test.TestCase):
def testBasicBatch(self):
"""Tests that a single batched tensor executes together and only once."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
batched, index, _ = batch_ops.batch(
[inp], num_batch_threads=1, max_batch_size=2,
@@ -83,7 +83,7 @@ class BatchOpsTest(test.TestCase):
def testBatchWithPadding(self):
"""Test that batching with padding up to an allowed batch size works."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[2])
batched, index, _ = batch_ops.batch(
[inp], num_batch_threads=1, max_batch_size=10,
@@ -113,7 +113,7 @@ class BatchOpsTest(test.TestCase):
def testMultipleBatch(self):
"""Tests that multiple batched tensors execute together."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
batched, _, _ = batch_ops.batch(
@@ -152,7 +152,7 @@ class BatchOpsTest(test.TestCase):
def testIllegalBatchDifferentDim0Sizes(self):
"""Tests illegally feeding tensors with different dim0 sizes."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[2])
batched, index, _ = batch_ops.batch(
@@ -166,7 +166,7 @@ class BatchOpsTest(test.TestCase):
def testBasicUnbatch(self):
"""Tests that batch and unbatch work together."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
batched, index, id_t = batch_ops.batch(
[inp], num_batch_threads=1, max_batch_size=10,
@@ -190,7 +190,8 @@ class BatchOpsTest(test.TestCase):
def testBasicUnbatchV1Decorated(self):
"""Tests that the batch_function_v1 decorator works."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
+
@batch_ops.batch_function_v1(1, 10, 100000)
def computation(in_t):
return in_t + 1
@@ -211,7 +212,7 @@ class BatchOpsTest(test.TestCase):
def testBasicUnbatchDecorated(self):
"""Tests that the batch_function decorator works."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# TODO(apassos): Removing this line causes test flakiness! Ideally should
# be investigated.
default_inp = array_ops.placeholder_with_default(2, shape=[]) # pylint: disable=unused-variable
@@ -236,7 +237,7 @@ class BatchOpsTest(test.TestCase):
def testBatchDecoratedWithCapturedInput(self):
"""Tests that the batch_function decorator works."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
captured_inp0 = array_ops.placeholder_with_default(2, shape=[])
captured_inp1 = array_ops.placeholder_with_default(1, shape=[])
@@ -260,7 +261,7 @@ class BatchOpsTest(test.TestCase):
def testBatchFunctionOp(self):
"""Tests that the batch_function op works."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
@function.Defun(dtypes.int32)
def computation(in_t):
@@ -289,7 +290,7 @@ class BatchOpsTest(test.TestCase):
def testBatchFunctionOpWithCapturedInput(self):
"""Tests that batch_function op works with captured input."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
captured_inp0 = array_ops.placeholder_with_default(2, shape=[])
captured_inp1 = array_ops.placeholder_with_default(1, shape=[])
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
@@ -323,7 +324,7 @@ class BatchOpsTest(test.TestCase):
def testBatchFunctionOpWithInputError(self):
"""Tests that batch_function op works with error in the inputs."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
@function.Defun(dtypes.int32, dtypes.int32)
@@ -346,7 +347,7 @@ class BatchOpsTest(test.TestCase):
def testBasicUnbatchDecoratedWithReshape(self):
"""Tests that the batch_function decorator works."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
@batch_ops.batch_function(1, 10, 100000)
def computation(in_t):
@@ -368,7 +369,7 @@ class BatchOpsTest(test.TestCase):
def testUnbatchTimeout(self):
"""Tests that the unbatch timeout works."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
batched, index, id_t = batch_ops.batch(
[inp], num_batch_threads=1, max_batch_size=2,
@@ -410,7 +411,7 @@ class BatchOpsTest(test.TestCase):
def testUnbatchGrad(self):
"""Tests that batch and unbatch are differentiable."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inp = array_ops.placeholder(dtype=dtypes.float32, shape=[1])
batched, index, id_t = batch_ops.batch(
[inp], num_batch_threads=1, max_batch_size=2,