diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-21 00:02:49 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-21 00:07:39 -0700 |
commit | 2952f5134905af795ba90ae1eb97e39091ba9843 (patch) | |
tree | f73bc5cd0342d9449114bd933863c2aa55610aa2 /tensorflow/contrib/batching | |
parent | cf047f7755f3400ee128db2571042091fe9f8314 (diff) |
Move from deprecated self.test_session() to self.cached_session().
self.test_session() has been deprecated in 9962eb5e84b15e309410071b06c2ed2d6148ed44 as its name confuses readers of the test. Moving to cached_session() instead, which is more explicit about:
* the fact that the session may be reused.
* the session is not closed even when doing a "with self.cached_session()" statement.
PiperOrigin-RevId: 213944355
Diffstat (limited to 'tensorflow/contrib/batching')
-rw-r--r-- | tensorflow/contrib/batching/python/ops/batch_ops_test.py | 29 |
1 file changed, 15 insertions, 14 deletions
diff --git a/tensorflow/contrib/batching/python/ops/batch_ops_test.py b/tensorflow/contrib/batching/python/ops/batch_ops_test.py index 7846814546..01ee8703a9 100644 --- a/tensorflow/contrib/batching/python/ops/batch_ops_test.py +++ b/tensorflow/contrib/batching/python/ops/batch_ops_test.py @@ -43,7 +43,7 @@ class BatchOpsTest(test.TestCase): def testBasicBatch(self): """Tests that a single batched tensor executes together and only once.""" - with self.test_session() as sess: + with self.cached_session() as sess: inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) batched, index, _ = batch_ops.batch( [inp], num_batch_threads=1, max_batch_size=2, @@ -83,7 +83,7 @@ class BatchOpsTest(test.TestCase): def testBatchWithPadding(self): """Test that batching with padding up to an allowed batch size works.""" - with self.test_session() as sess: + with self.cached_session() as sess: inp = array_ops.placeholder(dtype=dtypes.int32, shape=[2]) batched, index, _ = batch_ops.batch( [inp], num_batch_threads=1, max_batch_size=10, @@ -113,7 +113,7 @@ class BatchOpsTest(test.TestCase): def testMultipleBatch(self): """Tests that multiple batched tensors execute together.""" - with self.test_session() as sess: + with self.cached_session() as sess: inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) batched, _, _ = batch_ops.batch( @@ -152,7 +152,7 @@ class BatchOpsTest(test.TestCase): def testIllegalBatchDifferentDim0Sizes(self): """Tests illegally feeding tensors with different dim0 sizes.""" - with self.test_session() as sess: + with self.cached_session() as sess: inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[2]) batched, index, _ = batch_ops.batch( @@ -166,7 +166,7 @@ class BatchOpsTest(test.TestCase): def testBasicUnbatch(self): """Tests that batch and unbatch work together.""" - with self.test_session() as sess: + with 
self.cached_session() as sess: inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) batched, index, id_t = batch_ops.batch( [inp], num_batch_threads=1, max_batch_size=10, @@ -190,7 +190,8 @@ class BatchOpsTest(test.TestCase): def testBasicUnbatchV1Decorated(self): """Tests that the batch_function_v1 decorator works.""" - with self.test_session() as sess: + with self.cached_session() as sess: + @batch_ops.batch_function_v1(1, 10, 100000) def computation(in_t): return in_t + 1 @@ -211,7 +212,7 @@ class BatchOpsTest(test.TestCase): def testBasicUnbatchDecorated(self): """Tests that the batch_function decorator works.""" - with self.test_session() as sess: + with self.cached_session() as sess: # TODO(apassos): Removing this line causes test flakiness! Ideally should # be investigated. default_inp = array_ops.placeholder_with_default(2, shape=[]) # pylint: disable=unused-variable @@ -236,7 +237,7 @@ class BatchOpsTest(test.TestCase): def testBatchDecoratedWithCapturedInput(self): """Tests that the batch_function decorator works.""" - with self.test_session() as sess: + with self.cached_session() as sess: captured_inp0 = array_ops.placeholder_with_default(2, shape=[]) captured_inp1 = array_ops.placeholder_with_default(1, shape=[]) @@ -260,7 +261,7 @@ class BatchOpsTest(test.TestCase): def testBatchFunctionOp(self): """Tests that the batch_function op works.""" - with self.test_session() as sess: + with self.cached_session() as sess: @function.Defun(dtypes.int32) def computation(in_t): @@ -289,7 +290,7 @@ class BatchOpsTest(test.TestCase): def testBatchFunctionOpWithCapturedInput(self): """Tests that batch_function op works with captured input.""" - with self.test_session() as sess: + with self.cached_session() as sess: captured_inp0 = array_ops.placeholder_with_default(2, shape=[]) captured_inp1 = array_ops.placeholder_with_default(1, shape=[]) inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) @@ -323,7 +324,7 @@ class BatchOpsTest(test.TestCase): def 
testBatchFunctionOpWithInputError(self): """Tests that batch_function op works with error in the inputs.""" - with self.test_session() as sess: + with self.cached_session() as sess: inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) @function.Defun(dtypes.int32, dtypes.int32) @@ -346,7 +347,7 @@ class BatchOpsTest(test.TestCase): def testBasicUnbatchDecoratedWithReshape(self): """Tests that the batch_function decorator works.""" - with self.test_session() as sess: + with self.cached_session() as sess: @batch_ops.batch_function(1, 10, 100000) def computation(in_t): @@ -368,7 +369,7 @@ class BatchOpsTest(test.TestCase): def testUnbatchTimeout(self): """Tests that the unbatch timeout works.""" - with self.test_session() as sess: + with self.cached_session() as sess: inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) batched, index, id_t = batch_ops.batch( [inp], num_batch_threads=1, max_batch_size=2, @@ -410,7 +411,7 @@ class BatchOpsTest(test.TestCase): def testUnbatchGrad(self): """Tests that batch and unbatch are differentiable.""" - with self.test_session() as sess: + with self.cached_session() as sess: inp = array_ops.placeholder(dtype=dtypes.float32, shape=[1]) batched, index, id_t = batch_ops.batch( [inp], num_batch_threads=1, max_batch_size=2, |