aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/learn
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-10 14:37:06 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-10 15:04:14 -0700
commitb828f89263e054bfa7c7a808cab1506834ab906d (patch)
treee31816a6850d177306f19ee8670e0836060fcfc9 /tensorflow/contrib/learn
parentacf0ee82092727afc2067316982407cf5e496f75 (diff)
Move from deprecated self.test_session() to self.cached_session().
self.test_session() has been deprecated in 9962eb5e84b15e309410071b06c2ed2d6148ed44 as its name confuses readers of the test. Moving to cached_session() instead which is more explicit about: * the fact that the session may be reused. * the session is not closed even when doing a "with self.test_session()" statement. PiperOrigin-RevId: 212336464
Diffstat (limited to 'tensorflow/contrib/learn')
-rw-r--r--tensorflow/contrib/learn/python/learn/ops/ops_test.py6
-rw-r--r--tensorflow/contrib/learn/python/learn/ops/seq2seq_ops_test.py6
2 files changed, 6 insertions, 6 deletions
diff --git a/tensorflow/contrib/learn/python/learn/ops/ops_test.py b/tensorflow/contrib/learn/python/learn/ops/ops_test.py
index 80d4923db3..ff190110c1 100644
--- a/tensorflow/contrib/learn/python/learn/ops/ops_test.py
+++ b/tensorflow/contrib/learn/python/learn/ops/ops_test.py
@@ -33,7 +33,7 @@ class OpsTest(test.TestCase):
"""Ops tests."""
def test_softmax_classifier(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
features = array_ops.placeholder(dtypes.float32, [None, 3])
labels = array_ops.placeholder(dtypes.float32, [None, 2])
weights = constant_op.constant([[0.1, 0.1], [0.1, 0.1], [0.1, 0.1]])
@@ -52,7 +52,7 @@ class OpsTest(test.TestCase):
ids_shape = (2, 3, 4)
embeds = np.random.randn(n_embed, d_embed)
ids = np.random.randint(0, n_embed, ids_shape)
- with self.test_session():
+ with self.cached_session():
embed_np = embeds[ids]
embed_tf = ops.embedding_lookup(embeds, ids).eval()
self.assertEqual(embed_np.shape, embed_tf.shape)
@@ -60,7 +60,7 @@ class OpsTest(test.TestCase):
def test_categorical_variable(self):
random_seed.set_random_seed(42)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
cat_var_idx = array_ops.placeholder(dtypes.int64, [2, 2])
embeddings = ops.categorical_variable(
cat_var_idx, n_classes=5, embedding_size=10, name="my_cat_var")
diff --git a/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops_test.py b/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops_test.py
index 95aec61955..5a7e4ebfea 100644
--- a/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops_test.py
+++ b/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops_test.py
@@ -31,7 +31,7 @@ class Seq2SeqOpsTest(test.TestCase):
"""Sequence-to-sequence tests."""
def test_sequence_classifier(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
decoding = [
array_ops.placeholder(dtypes.float32, [2, 2]) for _ in range(3)
]
@@ -60,7 +60,7 @@ class Seq2SeqOpsTest(test.TestCase):
def test_seq2seq_inputs(self):
inp = np.array([[[1, 0], [0, 1], [1, 0]], [[0, 1], [1, 0], [0, 1]]])
out = np.array([[[0, 1, 0], [1, 0, 0]], [[1, 0, 0], [0, 1, 0]]])
- with self.test_session() as session:
+ with self.cached_session() as session:
x = array_ops.placeholder(dtypes.float32, [2, 3, 2])
y = array_ops.placeholder(dtypes.float32, [2, 2, 3])
in_x, in_y, out_y = ops.seq2seq_inputs(x, y, 3, 2)
@@ -77,7 +77,7 @@ class Seq2SeqOpsTest(test.TestCase):
[[0, 0, 0], [0, 0, 0]]])
def test_rnn_decoder(self):
- with self.test_session():
+ with self.cached_session():
decoder_inputs = [
array_ops.placeholder(dtypes.float32, [2, 2]) for _ in range(3)
]