author     A. Unique TensorFlower <gardener@tensorflow.org>   2018-08-22 15:16:27 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>    2018-08-22 15:35:53 -0700
commit     b2530e2b40b3c55c7121508b224ee1d9ed1bad27 (patch)
tree       fa4d861ecc90da1f10339c297126c9b82f3ad1b8 /tensorflow/contrib/training
parent     0714726b47d0f9f5cace70b3db6578aa62ed394c (diff)
Move from deprecated self.test_session() to self.cached_session().
self.test_session() has been deprecated in 9962eb5e84b15e309410071b06c2ed2d6148ed44 because its name confuses readers of the test. Moving to cached_session() instead, which is more explicit about:
* the fact that the session may be reused.
* the fact that the session is not closed even when doing a "with self.test_session()" statement.

PiperOrigin-RevId: 209839032
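For context, here is a minimal, illustrative sketch of the before/after pattern this change applies mechanically across the tests below. It assumes TensorFlow 1.x graph mode, as used by these contrib tests; the test class and method names are hypothetical and not part of this commit.

import tensorflow as tf


class CachedSessionExampleTest(tf.test.TestCase):

  def testAdd(self):
    total = tf.add(2, 3)
    # Deprecated form removed by this change:
    #   with self.test_session() as sess:
    # Replacement used by this change. The returned session is cached per
    # test and is not closed when the `with` block exits, so later calls to
    # cached_session() in the same test reuse it.
    with self.cached_session() as sess:
      self.assertEqual(5, sess.run(total))


if __name__ == "__main__":
  tf.test.main()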
Diffstat (limited to 'tensorflow/contrib/training')
-rw-r--r--  tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py      4
-rw-r--r--  tensorflow/contrib/training/python/training/bucket_ops_test.py                      10
-rw-r--r--  tensorflow/contrib/training/python/training/evaluation_test.py                       2
-rw-r--r--  tensorflow/contrib/training/python/training/resample_test.py                         8
-rw-r--r--  tensorflow/contrib/training/python/training/sampling_ops_test.py                    18
-rw-r--r--  tensorflow/contrib/training/python/training/sampling_ops_threading_test.py           2
-rw-r--r--  tensorflow/contrib/training/python/training/sequence_queueing_state_saver_test.py   18
-rw-r--r--  tensorflow/contrib/training/python/training/sgdr_learning_rate_decay_test.py         8
-rw-r--r--  tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py            18
-rw-r--r--  tensorflow/contrib/training/python/training/training_test.py                        14
10 files changed, 51 insertions, 51 deletions
diff --git a/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py b/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py
index 81278ea82c..afeef978f3 100644
--- a/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py
+++ b/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py
@@ -108,7 +108,7 @@ class BatchSequencesWithStatesTest(test.TestCase):
expected_seq4_batch1, expected_seq4_batch2,
key=None, make_keys_unique=False):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
next_batch = sqss.batch_sequences_with_states(
input_key=key if key is not None else self.key,
input_sequences=self.sequences,
@@ -332,7 +332,7 @@ class BatchSequencesWithStatesTest(test.TestCase):
"seq4": self.sequences["seq4"],
}
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
".*should be a multiple of: 3, but saw "
"value: 4. Consider setting pad=True."):
diff --git a/tensorflow/contrib/training/python/training/bucket_ops_test.py b/tensorflow/contrib/training/python/training/bucket_ops_test.py
index 504f1fcd41..b259e0ee83 100644
--- a/tensorflow/contrib/training/python/training/bucket_ops_test.py
+++ b/tensorflow/contrib/training/python/training/bucket_ops_test.py
@@ -112,7 +112,7 @@ class BucketTest(test.TestCase):
self.assertAllEqual(
[[32], [32, None], [32, 3], [None, None]],
[out.get_shape().as_list() for out in bucketed_dynamic[1]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for v in range(32):
self.enqueue_inputs(sess, {
self.scalar_int_feed: v,
@@ -162,7 +162,7 @@ class BucketTest(test.TestCase):
self.assertAllEqual(
[[None], [None, None], [None, 3], [None, None]],
[out.get_shape().as_list() for out in bucketed_dynamic[1]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for v in range(15):
self.enqueue_inputs(sess, {
self.scalar_int_feed: v,
@@ -204,7 +204,7 @@ class BucketTest(test.TestCase):
self.assertAllEqual(
[[32], [32, None], [32, 3], [None, None]],
[out.get_shape().as_list() for out in bucketed_dynamic[1]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for v in range(64):
self.enqueue_inputs(sess, {
self.scalar_int_feed: v,
@@ -286,7 +286,7 @@ class BucketTest(test.TestCase):
self.assertAllEqual(
[[32], [32, None], [32, 3]],
[out.get_shape().as_list() for out in bucketed_dynamic[1]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for v in range(128):
self.enqueue_inputs(sess, {
self.scalar_int_feed: v,
@@ -405,7 +405,7 @@ class BucketBySequenceLengthTest(test.TestCase):
num_pairs_to_enqueue - (batch_size - 1) * num_buckets,
num_pairs_dequeued)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
coord = coordinator.Coordinator()
# Feed the inputs, then close the input thread.
diff --git a/tensorflow/contrib/training/python/training/evaluation_test.py b/tensorflow/contrib/training/python/training/evaluation_test.py
index c36d00e842..ec47fe5d97 100644
--- a/tensorflow/contrib/training/python/training/evaluation_test.py
+++ b/tensorflow/contrib/training/python/training/evaluation_test.py
@@ -67,7 +67,7 @@ class CheckpointIteratorTest(test.TestCase):
global_step = variables.get_or_create_global_step()
saver = saver_lib.Saver() # Saves the global step.
- with self.test_session() as session:
+ with self.cached_session() as session:
session.run(variables_lib.global_variables_initializer())
save_path = os.path.join(checkpoint_dir, 'model.ckpt')
saver.save(session, save_path, global_step=global_step)
diff --git a/tensorflow/contrib/training/python/training/resample_test.py b/tensorflow/contrib/training/python/training/resample_test.py
index 774241a816..8665a24883 100644
--- a/tensorflow/contrib/training/python/training/resample_test.py
+++ b/tensorflow/contrib/training/python/training/resample_test.py
@@ -44,7 +44,7 @@ class ResampleTest(test.TestCase):
([3], [0, 0, 0]),
([0, 1, 2, 3], [1, 2, 2, 3, 3, 3]),
]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for inputs, expected in cases:
array_inputs = numpy.array(inputs, dtype=numpy.int32)
actual = sess.run(resample._repeat_range(array_inputs))
@@ -65,7 +65,7 @@ class ResampleTest(test.TestCase):
init = control_flow_ops.group(variables.local_variables_initializer(),
variables.global_variables_initializer())
- with self.test_session() as s:
+ with self.cached_session() as s:
s.run(init) # initialize
# outputs
@@ -112,7 +112,7 @@ class ResampleTest(test.TestCase):
init = control_flow_ops.group(variables.local_variables_initializer(),
variables.global_variables_initializer())
expected_sum_op = math_ops.reduce_sum(vals)
- with self.test_session() as s:
+ with self.cached_session() as s:
s.run(init)
expected_sum = n * s.run(expected_sum_op)
@@ -147,7 +147,7 @@ class ResampleTest(test.TestCase):
resampled = resample.resample_at_rate([vals], rates)
- with self.test_session() as s:
+ with self.cached_session() as s:
rs, = s.run(resampled, {
vals: list(range(count)),
rates: numpy.zeros(
diff --git a/tensorflow/contrib/training/python/training/sampling_ops_test.py b/tensorflow/contrib/training/python/training/sampling_ops_test.py
index bf7fb4fd48..1aeff7dc80 100644
--- a/tensorflow/contrib/training/python/training/sampling_ops_test.py
+++ b/tensorflow/contrib/training/python/training/sampling_ops_test.py
@@ -146,7 +146,7 @@ class StratifiedSampleTest(test.TestCase):
for illegal_label in illegal_labels:
# Run session that should fail.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors_impl.InvalidArgumentError):
sess.run([val_tf, lbl_tf],
feed_dict={label_ph: illegal_label,
@@ -154,7 +154,7 @@ class StratifiedSampleTest(test.TestCase):
for illegal_prob in illegal_probs:
# Run session that should fail.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors_impl.InvalidArgumentError):
sess.run([prob_tf],
feed_dict={label_ph: valid_labels,
@@ -172,7 +172,7 @@ class StratifiedSampleTest(test.TestCase):
summary_op = logging_ops.merge_summary(
ops.get_collection(ops.GraphKeys.SUMMARIES))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(coord=coord)
@@ -197,7 +197,7 @@ class StratifiedSampleTest(test.TestCase):
batch_size,
init_probs=[0, .3, 0, .7, 0],
enqueue_many=True)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(coord=coord)
@@ -228,7 +228,7 @@ class StratifiedSampleTest(test.TestCase):
# Run graph to make sure there are no shape-related runtime errors.
for vals, labels in legal_input_pairs:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run([val_tf, labels_tf],
feed_dict={vals_ph: vals,
labels_ph: labels})
@@ -253,7 +253,7 @@ class StratifiedSampleTest(test.TestCase):
self.assertEqual(len(val_list), len(val_input_batch))
self.assertTrue(isinstance(lbls, ops.Tensor))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(coord=coord)
@@ -283,7 +283,7 @@ class StratifiedSampleTest(test.TestCase):
# Run session and keep track of how frequently the labels and values appear.
data_l = []
label_l = []
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Need to initialize variables that keep running total of classes seen.
variables.global_variables_initializer().run()
@@ -374,7 +374,7 @@ class RejectionSampleTest(test.TestCase):
'rejection_sample/prob_with_checks:0')
# Run session that should fail.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for illegal_prob in [-0.1, 1.1]:
with self.assertRaises(errors_impl.InvalidArgumentError):
sess.run(prob_tensor, feed_dict={prob_ph: illegal_prob})
@@ -393,7 +393,7 @@ class RejectionSampleTest(test.TestCase):
sample = sampling_ops.rejection_sample(tensor_list, accept_prob_fn,
batch_size)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(coord=coord)
diff --git a/tensorflow/contrib/training/python/training/sampling_ops_threading_test.py b/tensorflow/contrib/training/python/training/sampling_ops_threading_test.py
index ca78c0029e..73ad859ab3 100644
--- a/tensorflow/contrib/training/python/training/sampling_ops_threading_test.py
+++ b/tensorflow/contrib/training/python/training/sampling_ops_threading_test.py
@@ -59,7 +59,7 @@ class SamplingOpsThreadingTest(test.TestCase):
out_tensor = queue.dequeue()
# Run the multi-threaded session.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Need to initialize variables that keep running total of classes seen.
variables.global_variables_initializer().run()
diff --git a/tensorflow/contrib/training/python/training/sequence_queueing_state_saver_test.py b/tensorflow/contrib/training/python/training/sequence_queueing_state_saver_test.py
index 7aebd9d9fe..8932b905c9 100644
--- a/tensorflow/contrib/training/python/training/sequence_queueing_state_saver_test.py
+++ b/tensorflow/contrib/training/python/training/sequence_queueing_state_saver_test.py
@@ -36,7 +36,7 @@ from tensorflow.python.platform import test
class SequenceQueueingStateSaverTest(test.TestCase):
def testSequenceInputWrapper(self):
- with self.test_session():
+ with self.cached_session():
length = 3
key = "key"
padded_length = 4
@@ -54,7 +54,7 @@ class SequenceQueueingStateSaverTest(test.TestCase):
self.assertTrue(isinstance(input_wrapper.context["context1"], ops.Tensor))
def testStateSaverWithTwoSimpleSteps(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size_value = 2
batch_size = constant_op.constant(batch_size_value)
num_unroll = 2
@@ -159,7 +159,7 @@ class SequenceQueueingStateSaverTest(test.TestCase):
self.assertEqual(0, state_saver.barrier.ready_size().eval())
def testStateSaverFailsIfPaddedLengthIsNotMultipleOfNumUnroll(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = constant_op.constant(32)
num_unroll = 17
bad_padded_length = 3
@@ -194,7 +194,7 @@ class SequenceQueueingStateSaverTest(test.TestCase):
})
def _testStateSaverFailsIfCapacityTooSmall(self, batch_size):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_unroll = 2
length = array_ops.placeholder(dtypes.int32)
key = array_ops.placeholder(dtypes.string)
@@ -243,7 +243,7 @@ class SequenceQueueingStateSaverTest(test.TestCase):
self._testStateSaverFailsIfCapacityTooSmall(batch_size)
def testStateSaverFailsIfInconsistentPaddedLength(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = constant_op.constant(32)
num_unroll = 17
length = array_ops.placeholder(dtypes.int32)
@@ -282,7 +282,7 @@ class SequenceQueueingStateSaverTest(test.TestCase):
def testStateSaverFailsIfInconsistentWriteState(self):
# TODO(b/26910386): Identify why this infrequently causes timeouts.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = constant_op.constant(1)
num_unroll = 17
length = array_ops.placeholder(dtypes.int32)
@@ -326,7 +326,7 @@ class SequenceQueueingStateSaverTest(test.TestCase):
def testStateSaverWithManyInputsReadWriteThread(self):
batch_size_value = 32
num_proc_threads = 100
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = constant_op.constant(batch_size_value)
num_unroll = 17
length = array_ops.placeholder(dtypes.int32)
@@ -490,7 +490,7 @@ class SequenceQueueingStateSaverTest(test.TestCase):
self.assertGreater(processed_count[0], 2 * 20 * batch_size_value)
def testStateSaverProcessesExamplesInOrder(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size_value = 32
batch_size = constant_op.constant(batch_size_value)
num_unroll = 17
@@ -563,7 +563,7 @@ class SequenceQueueingStateSaverTest(test.TestCase):
self.assertEqual(get_ready_size.eval(), 0)
def testStateSaverCanHandleVariableBatchsize(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = array_ops.placeholder(dtypes.int32)
num_unroll = 17
length = array_ops.placeholder(dtypes.int32)
diff --git a/tensorflow/contrib/training/python/training/sgdr_learning_rate_decay_test.py b/tensorflow/contrib/training/python/training/sgdr_learning_rate_decay_test.py
index 4a46e9a49e..3269d5fef2 100644
--- a/tensorflow/contrib/training/python/training/sgdr_learning_rate_decay_test.py
+++ b/tensorflow/contrib/training/python/training/sgdr_learning_rate_decay_test.py
@@ -62,7 +62,7 @@ class SGDRDecayTest(test_util.TensorFlowTestCase):
def get_sgdr_values(self, lr, initial_period_steps, t_mul, iters):
"""Get an array with learning rate values from the consecutive steps
using current tensorflow implementation."""
- with self.test_session():
+ with self.cached_session():
step = placeholder(dtypes.int32)
decay = sgdr_decay(lr, step, initial_period_steps, t_mul)
@@ -76,7 +76,7 @@ class SGDRDecayTest(test_util.TensorFlowTestCase):
"""Compare values generated by tensorflow implementation to the values
generated by the original implementation
(https://github.com/loshchil/SGDR/blob/master/SGDR_WRNs.py)."""
- with self.test_session():
+ with self.cached_session():
lr = 10.0
init_steps = 2
t_mul = 3
@@ -92,7 +92,7 @@ class SGDRDecayTest(test_util.TensorFlowTestCase):
def testMDecay(self):
"""Test m_mul argument. Check values for learning rate at the beginning
of the first, second, third and fourth period. """
- with self.test_session():
+ with self.cached_session():
step = placeholder(dtypes.int32)
lr = 0.1
@@ -121,7 +121,7 @@ class SGDRDecayTest(test_util.TensorFlowTestCase):
def testCos(self):
"""Check learning rate values at the beginning, in the middle
and at the end of the period."""
- with self.test_session():
+ with self.cached_session():
step = placeholder(dtypes.int32)
lr = 0.2
t_e = 1000
diff --git a/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py b/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py
index df0a186f4f..d9b0511a98 100644
--- a/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py
+++ b/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py
@@ -79,7 +79,7 @@ class PrependFromQueueAndPaddedBatchDatasetTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
queue_handle, value = iterator.get_next()
enqueue_negative = tqd.enqueue_in_queue_dataset(queue_handle, -value)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual([[0, 0, 0]], sess.run(value))
value_1, _ = sess.run([value, enqueue_negative])
self.assertAllEqual([[1, 0, 0]], value_1)
@@ -101,7 +101,7 @@ class PrependFromQueueAndPaddedBatchDatasetTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
queue_handle, value = iterator.get_next()
enqueue_negative = tqd.enqueue_in_queue_dataset(queue_handle, -value)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual([0], sess.run(value))
value_1, _ = sess.run([value, enqueue_negative])
self.assertEqual([1], value_1)
@@ -126,7 +126,7 @@ class PrependFromQueueAndPaddedBatchDatasetTest(test.TestCase):
enqueue_zeroth = tqd.enqueue_in_queue_dataset([queue_handle[0]],
array_ops.expand_dims(
value[0], axis=0))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
value_0, _ = sess.run([value, enqueue_negative])
self.assertAllEqual([0, 1], value_0)
value_1, _ = sess.run([value, enqueue_zeroth])
@@ -147,7 +147,7 @@ class PrependFromQueueAndPaddedBatchDatasetTest(test.TestCase):
tqd.enqueue_in_queue_dataset(queue_handle, value + 100 + i)
for i in range(1000)
]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
value_0, _ = sess.run((value, enqueue_many_more))
self.assertEqual([0], value_0)
rest = []
@@ -174,7 +174,7 @@ class PrependFromQueueAndPaddedBatchDatasetTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
queue_handle, value = iterator.get_next()
enqueue = tqd.enqueue_in_queue_dataset(queue_handle, value + 1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
i = 0
while i < 4:
received, _ = sess.run((value, enqueue))
@@ -199,7 +199,7 @@ class PrependFromQueueAndPaddedBatchDatasetTest(test.TestCase):
batch_size=1, padded_shapes=[2]))
iterator = dataset.make_one_shot_iterator()
_, value = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaisesOpError(
r"Incompatible input shapes at component 0 between "
r"input dataset this dataset: \[3\] vs. \[2\]"):
@@ -224,7 +224,7 @@ class PrependFromQueueAndPaddedBatchDatasetTest(test.TestCase):
np.array(
[[1]], dtype=np.int32))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaisesOpError(
"mismatched number of tensors. Queue expects 1 tensors but "
"tried to insert 2"):
@@ -274,7 +274,7 @@ class PrependFromQueueAndPaddedBatchDatasetTest(test.TestCase):
with ops.control_dependencies([enqueue_rest_op]):
calc = array_ops.identity(value_head)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual([[0, 0], [2, 2], [4, 4]], sess.run(calc))
self.assertAllEqual([[4, 4], [6, 6]], sess.run(calc))
self.assertAllEqual([[6, 6]], sess.run(calc))
@@ -304,7 +304,7 @@ class PrependFromQueueAndPaddedBatchDatasetTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
_, (unused_count, padded_value) = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual([[-1, -1, -1, -1], [2, 2, -1, -1], [4, 4, 4, 4]],
sess.run(padded_value))
self.assertAllEqual([[6] * 6], sess.run(padded_value))
diff --git a/tensorflow/contrib/training/python/training/training_test.py b/tensorflow/contrib/training/python/training/training_test.py
index 94cf7788b2..3b524ac8c7 100644
--- a/tensorflow/contrib/training/python/training/training_test.py
+++ b/tensorflow/contrib/training/python/training/training_test.py
@@ -62,7 +62,7 @@ class ClipGradsTest(test.TestCase):
clipped_gradients_to_variables = training.clip_gradient_norms(
gradients_to_variables, 3.0)
- with self.test_session() as session:
+ with self.cached_session() as session:
session.run(variables_lib2.global_variables_initializer())
self.assertAlmostEqual(4.0, gradients_to_variables[0][0].eval())
self.assertAlmostEqual(3.0, clipped_gradients_to_variables[0][0].eval())
@@ -75,7 +75,7 @@ class ClipGradsTest(test.TestCase):
clipped_gradients_to_variables = training.clip_gradient_norms_fn(3.0)(
gradients_to_variables)
- with self.test_session() as session:
+ with self.cached_session() as session:
session.run(variables_lib2.global_variables_initializer())
self.assertAlmostEqual(4.0, gradients_to_variables[0][0].eval())
self.assertAlmostEqual(3.0, clipped_gradients_to_variables[0][0].eval())
@@ -122,7 +122,7 @@ class CreateTrainOpTest(test.TestCase):
moving_variance = variables_lib.get_variables_by_name('moving_variance')[
0]
- with self.test_session() as session:
+ with self.cached_session() as session:
# Initialize all variables
session.run(variables_lib2.global_variables_initializer())
mean, variance = session.run([moving_mean, moving_variance])
@@ -155,7 +155,7 @@ class CreateTrainOpTest(test.TestCase):
moving_variance = variables_lib.get_variables_by_name('moving_variance')[
0]
- with self.test_session() as session:
+ with self.cached_session() as session:
# Initialize all variables
session.run(variables_lib2.global_variables_initializer())
mean, variance = session.run([moving_mean, moving_variance])
@@ -186,7 +186,7 @@ class CreateTrainOpTest(test.TestCase):
global_step = variables_lib.get_or_create_global_step()
- with self.test_session() as session:
+ with self.cached_session() as session:
# Initialize all variables
session.run(variables_lib2.global_variables_initializer())
@@ -209,7 +209,7 @@ class CreateTrainOpTest(test.TestCase):
global_step = variables_lib.get_or_create_global_step()
- with self.test_session() as session:
+ with self.cached_session() as session:
# Initialize all variables
session.run(variables_lib2.global_variables_initializer())
@@ -535,7 +535,7 @@ class TrainTest(test.TestCase):
train_biases = training.create_train_op(
total_loss, optimizer, variables_to_train=[biases])
- with self.test_session() as session:
+ with self.cached_session() as session:
# Initialize the variables.
session.run(variables_lib2.global_variables_initializer())