diff options
author | 2018-08-21 19:53:43 -0700 | |
---|---|---|
committer | 2018-08-21 20:00:41 -0700 | |
commit | 47c0bda0e7f736a9328aaf76aba7c8006e24556f (patch) | |
tree | ad2a6ab71adddc0d07c7f306c270122937b6a5b0 /tensorflow/contrib/learn | |
parent | 1ab795b54274a26a92690f36eff65674fb500f91 (diff) |
Move from deprecated self.test_session() to self.cached_session().
self.test_session() has been deprecated in 9962eb5e84b15e309410071b06c2ed2d6148ed44 as its name confuses readers of the test. Moving to cached_session() instead which is more explicit about:
* the fact that the session may be reused.
* the session is not closed even when doing a "with self.test_session()" statement.
PiperOrigin-RevId: 209703607
Diffstat (limited to 'tensorflow/contrib/learn')
3 files changed, 5 insertions, 5 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py index c9a11f27f1..1d8a59281a 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py @@ -155,7 +155,7 @@ class DynamicRnnEstimatorTest(test.TestCase): sequence_input = dynamic_rnn_estimator.build_sequence_input( self.GetColumnsToTensors(), self.sequence_feature_columns, self.context_feature_columns) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) sess.run(lookup_ops.tables_initializer()) sequence_input_val = sess.run(sequence_input) @@ -330,7 +330,7 @@ class DynamicRnnEstimatorTest(test.TestCase): actual_state = dynamic_rnn_estimator.dict_to_state_tuple(state_dict, cell) flattened_state = dynamic_rnn_estimator.state_tuple_to_dict(actual_state) - with self.test_session() as sess: + with self.cached_session() as sess: (state_dict_val, actual_state_val, flattened_state_val) = sess.run( [state_dict, actual_state, flattened_state]) diff --git a/tensorflow/contrib/learn/python/learn/estimators/rnn_common_test.py b/tensorflow/contrib/learn/python/learn/estimators/rnn_common_test.py index 82563141cc..ebf5f5617d 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/rnn_common_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/rnn_common_test.py @@ -44,7 +44,7 @@ class RnnCommonTest(test.TestCase): constant_op.constant(labels, dtype=dtypes.int32), constant_op.constant(sequence_length, dtype=dtypes.int32)) - with self.test_session() as sess: + with self.cached_session() as sess: activations_masked, labels_masked = sess.run( [activations_masked_t, labels_masked_t]) diff --git a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py index 442247409d..06c61554fa 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py @@ -53,7 +53,7 @@ class PrepareInputsForRnnTest(test.TestCase): sequence_feature_columns, num_unroll) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) sess.run(lookup_ops.tables_initializer()) features_val = sess.run(features_by_time) @@ -314,7 +314,7 @@ class StateSavingRnnEstimatorTest(test.TestCase): else: self.assertAllEqual(v, got[k]) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) sess.run(lookup_ops.tables_initializer()) actual_sequence, actual_context = sess.run( |