aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/crf
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-10 14:36:26 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-10 14:49:41 -0700
commit890e16594a005fe703a5556530b0dc3e6527fa47 (patch)
tree99140efb13f392ae13a58f08c08754c61bf66f13 /tensorflow/contrib/crf
parent132babebf5b1026cb33cad7c4eb7e03810c2acdf (diff)
Move from deprecated self.test_session() to self.cached_session().
self.test_session() has been deprecated in 9962eb5e84b15e309410071b06c2ed2d6148ed44 as its name confuses readers of the test. Moving to cached_session() instead which is more explicit about: * the fact that the session may be reused. * the session is not closed even when doing a "with self.test_session()" statement. PiperOrigin-RevId: 212336321
Diffstat (limited to 'tensorflow/contrib/crf')
-rw-r--r--tensorflow/contrib/crf/python/kernel_tests/crf_test.py20
1 files changed, 10 insertions, 10 deletions
diff --git a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py
index 8cfe142059..556d731840 100644
--- a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py
+++ b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py
@@ -61,7 +61,7 @@ class CrfTest(test.TestCase):
for sequence_lengths, inputs, tag_indices in zip(sequence_lengths_list,
inputs_list,
tag_indices_list):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sequence_score = crf.crf_sequence_score(
inputs=array_ops.expand_dims(inputs, 0),
tag_indices=array_ops.expand_dims(tag_indices, 0),
@@ -96,7 +96,7 @@ class CrfTest(test.TestCase):
]
for sequence_lengths, inputs, tag_bitmap in zip(
sequence_lengths_list, inputs_list, tag_bitmap_list):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sequence_score = crf.crf_multitag_sequence_score(
inputs=array_ops.expand_dims(inputs, 0),
tag_bitmap=array_ops.expand_dims(tag_bitmap, 0),
@@ -124,7 +124,7 @@ class CrfTest(test.TestCase):
for dtype in (np.int32, np.int64):
tag_indices = np.array([1, 2, 1, 0], dtype=dtype)
sequence_lengths = np.array(3, dtype=np.int32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
unary_score = crf.crf_unary_score(
tag_indices=array_ops.expand_dims(tag_indices, 0),
sequence_lengths=array_ops.expand_dims(sequence_lengths, 0),
@@ -140,7 +140,7 @@ class CrfTest(test.TestCase):
transition_params = np.array(
[[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32)
sequence_lengths = np.array(3, dtype=np.int32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
binary_score = crf.crf_binary_score(
tag_indices=array_ops.expand_dims(tag_indices, 0),
sequence_lengths=array_ops.expand_dims(sequence_lengths, 0),
@@ -176,7 +176,7 @@ class CrfTest(test.TestCase):
tag_indices_list):
num_words = inputs.shape[0]
num_tags = inputs.shape[1]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
all_sequence_scores = []
# Compare the dynamic program with brute force computation.
@@ -206,7 +206,7 @@ class CrfTest(test.TestCase):
"""
Test `crf_log_norm` when `sequence_lengths` contains one or more zeros.
"""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inputs = constant_op.constant(np.ones([2, 10, 5],
dtype=np.float32))
transition_params = constant_op.constant(np.ones([5, 5],
@@ -226,7 +226,7 @@ class CrfTest(test.TestCase):
sequence_lengths = np.array(3, dtype=np.int32)
num_words = inputs.shape[0]
num_tags = inputs.shape[1]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
all_sequence_log_likelihoods = []
# Make sure all probabilities sum to 1.
@@ -254,7 +254,7 @@ class CrfTest(test.TestCase):
num_words = inputs.shape[0]
num_tags = inputs.shape[1]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
all_sequence_scores = []
all_sequences = []
@@ -310,7 +310,7 @@ class CrfTest(test.TestCase):
num_words = inputs.shape[0]
num_tags = inputs.shape[1]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
all_sequence_scores = []
all_sequences = []
@@ -351,7 +351,7 @@ class CrfTest(test.TestCase):
"""
Test that crf_decode works when sequence_length contains one or more zeros.
"""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inputs = constant_op.constant(np.ones([2, 10, 5],
dtype=np.float32))
transition_params = constant_op.constant(np.ones([5, 5],