diff options
author | 2018-09-10 14:36:26 -0700 | |
---|---|---|
committer | 2018-09-10 14:49:41 -0700 | |
commit | 890e16594a005fe703a5556530b0dc3e6527fa47 (patch) | |
tree | 99140efb13f392ae13a58f08c08754c61bf66f13 /tensorflow/contrib/crf | |
parent | 132babebf5b1026cb33cad7c4eb7e03810c2acdf (diff) |
Move from deprecated self.test_session() to self.cached_session().
self.test_session() has been deprecated in 9962eb5e84b15e309410071b06c2ed2d6148ed44 as its name confuses readers of the test. Moving to cached_session() instead which is more explicit about:
* the fact that the session may be reused.
* the session is not closed even when doing a "with self.test_session()" statement.
PiperOrigin-RevId: 212336321
Diffstat (limited to 'tensorflow/contrib/crf')
-rw-r--r-- | tensorflow/contrib/crf/python/kernel_tests/crf_test.py | 20 |
1 files changed, 10 insertions, 10 deletions
diff --git a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py index 8cfe142059..556d731840 100644 --- a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py +++ b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py @@ -61,7 +61,7 @@ class CrfTest(test.TestCase): for sequence_lengths, inputs, tag_indices in zip(sequence_lengths_list, inputs_list, tag_indices_list): - with self.test_session() as sess: + with self.cached_session() as sess: sequence_score = crf.crf_sequence_score( inputs=array_ops.expand_dims(inputs, 0), tag_indices=array_ops.expand_dims(tag_indices, 0), @@ -96,7 +96,7 @@ class CrfTest(test.TestCase): ] for sequence_lengths, inputs, tag_bitmap in zip( sequence_lengths_list, inputs_list, tag_bitmap_list): - with self.test_session() as sess: + with self.cached_session() as sess: sequence_score = crf.crf_multitag_sequence_score( inputs=array_ops.expand_dims(inputs, 0), tag_bitmap=array_ops.expand_dims(tag_bitmap, 0), @@ -124,7 +124,7 @@ class CrfTest(test.TestCase): for dtype in (np.int32, np.int64): tag_indices = np.array([1, 2, 1, 0], dtype=dtype) sequence_lengths = np.array(3, dtype=np.int32) - with self.test_session() as sess: + with self.cached_session() as sess: unary_score = crf.crf_unary_score( tag_indices=array_ops.expand_dims(tag_indices, 0), sequence_lengths=array_ops.expand_dims(sequence_lengths, 0), @@ -140,7 +140,7 @@ class CrfTest(test.TestCase): transition_params = np.array( [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32) sequence_lengths = np.array(3, dtype=np.int32) - with self.test_session() as sess: + with self.cached_session() as sess: binary_score = crf.crf_binary_score( tag_indices=array_ops.expand_dims(tag_indices, 0), sequence_lengths=array_ops.expand_dims(sequence_lengths, 0), @@ -176,7 +176,7 @@ class CrfTest(test.TestCase): tag_indices_list): num_words = inputs.shape[0] num_tags = inputs.shape[1] - with self.test_session() as sess: + with self.cached_session() as sess: all_sequence_scores = [] # Compare the dynamic program with brute force computation. @@ -206,7 +206,7 @@ class CrfTest(test.TestCase): """ Test `crf_log_norm` when `sequence_lengths` contains one or more zeros. """ - with self.test_session() as sess: + with self.cached_session() as sess: inputs = constant_op.constant(np.ones([2, 10, 5], dtype=np.float32)) transition_params = constant_op.constant(np.ones([5, 5], @@ -226,7 +226,7 @@ class CrfTest(test.TestCase): sequence_lengths = np.array(3, dtype=np.int32) num_words = inputs.shape[0] num_tags = inputs.shape[1] - with self.test_session() as sess: + with self.cached_session() as sess: all_sequence_log_likelihoods = [] # Make sure all probabilities sum to 1. @@ -254,7 +254,7 @@ class CrfTest(test.TestCase): num_words = inputs.shape[0] num_tags = inputs.shape[1] - with self.test_session() as sess: + with self.cached_session() as sess: all_sequence_scores = [] all_sequences = [] @@ -310,7 +310,7 @@ class CrfTest(test.TestCase): num_words = inputs.shape[0] num_tags = inputs.shape[1] - with self.test_session() as sess: + with self.cached_session() as sess: all_sequence_scores = [] all_sequences = [] @@ -351,7 +351,7 @@ class CrfTest(test.TestCase): """ Test that crf_decode works when sequence_length contains one or more zeros. """ - with self.test_session() as sess: + with self.cached_session() as sess: inputs = constant_op.constant(np.ones([2, 10, 5], dtype=np.float32)) transition_params = constant_op.constant(np.ones([5, 5], |