aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/framework
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-21 19:09:49 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-21 19:13:28 -0700
commit754fffb399efa6204bb8aae51ce99042cb2ab18e (patch)
tree3f3a3ecd5e25bac3a4babd9ca330f63d21fb2918 /tensorflow/contrib/framework
parent34f07dc58afcbddf3c4387cdf7c49ebb5aacf4dd (diff)
Move from deprecated self.test_session() to self.cached_session().
self.test_session() has been deprecated in 9962eb5e84b15e309410071b06c2ed2d6148ed44 as its name confuses readers of the test. Moving to cached_session() instead which is more explicit about: * the fact that the session may be reused. * the session is not closed even when doing a "with self.test_session()" statement. PiperOrigin-RevId: 209700634
Diffstat (limited to 'tensorflow/contrib/framework')
-rw-r--r--tensorflow/contrib/framework/python/ops/arg_scope_test.py18
-rw-r--r--tensorflow/contrib/framework/python/ops/checkpoint_ops_test.py8
-rw-r--r--tensorflow/contrib/framework/python/ops/prettyprint_ops_test.py8
-rw-r--r--tensorflow/contrib/framework/python/ops/sort_ops_test.py12
-rw-r--r--tensorflow/contrib/framework/python/ops/variables_test.py140
5 files changed, 93 insertions, 93 deletions
diff --git a/tensorflow/contrib/framework/python/ops/arg_scope_test.py b/tensorflow/contrib/framework/python/ops/arg_scope_test.py
index bcafc1a328..0e6c6f0e2f 100644
--- a/tensorflow/contrib/framework/python/ops/arg_scope_test.py
+++ b/tensorflow/contrib/framework/python/ops/arg_scope_test.py
@@ -52,7 +52,7 @@ def _key_op(op):
class ArgScopeTest(test.TestCase):
def testEmptyArgScope(self):
- with self.test_session():
+ with self.cached_session():
with arg_scope([]) as sc:
self.assertEqual(sc, {})
@@ -60,7 +60,7 @@ class ArgScopeTest(test.TestCase):
func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
key_op = _key_op(func1)
func1_scope = {key_op: func1_kwargs.copy()}
- with self.test_session():
+ with self.cached_session():
with arg_scope([func1], a=1, b=None, c=[1]) as sc1:
self.assertEqual(sc1, func1_scope)
with arg_scope({}) as sc2:
@@ -86,7 +86,7 @@ class ArgScopeTest(test.TestCase):
func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
key_op = _key_op(func1)
current_scope = {key_op: func1_kwargs.copy()}
- with self.test_session():
+ with self.cached_session():
with arg_scope([func1], a=1, b=None, c=[1]) as scope:
self.assertDictEqual(scope, current_scope)
@@ -102,7 +102,7 @@ class ArgScopeTest(test.TestCase):
key(func1): func1_kwargs.copy(),
key(func2): func2_kwargs.copy()
}
- with self.test_session():
+ with self.cached_session():
with arg_scope([func1], a=1, b=None, c=[1]):
with arg_scope([func2], b=2, d=[2]) as scope:
self.assertDictEqual(scope, current_scope)
@@ -111,7 +111,7 @@ class ArgScopeTest(test.TestCase):
func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
key_op = _key_op(func1)
current_scope = {key_op: func1_kwargs.copy()}
- with self.test_session():
+ with self.cached_session():
with arg_scope([func1], a=1, b=None, c=[1]) as scope1:
pass
with arg_scope(scope1) as scope:
@@ -126,7 +126,7 @@ class ArgScopeTest(test.TestCase):
key(func1): func1_kwargs.copy(),
key(func2): func2_kwargs.copy()
}
- with self.test_session():
+ with self.cached_session():
with arg_scope([func1], a=1, b=None, c=[1]) as scope1:
with arg_scope([func2], b=2, d=[2]) as scope2:
pass
@@ -140,7 +140,7 @@ class ArgScopeTest(test.TestCase):
def testSimpleArgScope(self):
func1_args = (0,)
func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
- with self.test_session():
+ with self.cached_session():
with arg_scope([func1], a=1, b=None, c=[1]):
args, kwargs = func1(0)
self.assertTupleEqual(args, func1_args)
@@ -149,7 +149,7 @@ class ArgScopeTest(test.TestCase):
def testSimpleArgScopeWithTuple(self):
func1_args = (0,)
func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
- with self.test_session():
+ with self.cached_session():
with arg_scope((func1,), a=1, b=None, c=[1]):
args, kwargs = func1(0)
self.assertTupleEqual(args, func1_args)
@@ -240,7 +240,7 @@ class ArgScopeTest(test.TestCase):
def testAddArgScopeRaceCondition(self):
func4_kwargs = ('a', 'b', 'c', 'd', 'e', 'f', 'g', 'h')
for i in range(4):
- # redefine the function with different args
+ # redefine the function with different args
@add_arg_scope
def func4(a=1, b=2, c=3, d=4, e=5, f=6, g=7, h=8):
pass
diff --git a/tensorflow/contrib/framework/python/ops/checkpoint_ops_test.py b/tensorflow/contrib/framework/python/ops/checkpoint_ops_test.py
index b7b9f5c59e..4036c87b6d 100644
--- a/tensorflow/contrib/framework/python/ops/checkpoint_ops_test.py
+++ b/tensorflow/contrib/framework/python/ops/checkpoint_ops_test.py
@@ -50,7 +50,7 @@ class LoadMulticlassBiasTest(test.TestCase):
bias = variables.Variable(
array_ops.reshape(flat_data, (num, dim)), name='bias')
save = saver.Saver([bias])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables.global_variables_initializer().run()
self.bundle_file = os.path.join(test.get_temp_dir(), 'bias_checkpoint')
save.save(sess, self.bundle_file)
@@ -90,7 +90,7 @@ class LoadMulticlassBiasTest(test.TestCase):
initializer=bias_loading_initializer,
partitioner=partitioned_variables.fixed_size_partitioner(3))
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
self.assertAllClose(expected_remapped_bias_vector,
remapped_bias_vector.as_tensor().eval())
@@ -109,7 +109,7 @@ class LoadVariableSlotTest(test.TestCase):
accum = variables.Variable(
array_ops.reshape(flat_data, (num, dim)), name='accum')
save = saver.Saver([accum])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables.global_variables_initializer().run()
self.bundle_file = os.path.join(test.get_temp_dir(), 'accum_checkpoint')
save.save(sess, self.bundle_file)
@@ -179,7 +179,7 @@ class LoadVariableSlotTest(test.TestCase):
shape=[2, 1],
initializer=variable_slot_initializer_part_1)
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
self.assertAllClose(expected_remapped_accum_vector_part_0,
remapped_accum_vector_part_0.eval())
diff --git a/tensorflow/contrib/framework/python/ops/prettyprint_ops_test.py b/tensorflow/contrib/framework/python/ops/prettyprint_ops_test.py
index 50bcbe625d..c104c51fef 100644
--- a/tensorflow/contrib/framework/python/ops/prettyprint_ops_test.py
+++ b/tensorflow/contrib/framework/python/ops/prettyprint_ops_test.py
@@ -34,7 +34,7 @@ class PrettyPrintOpsTest(test.TestCase):
def testPrintTensorPassthrough(self):
a = constant_op.constant([1])
a = prettyprint_ops.print_op(a)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(a.eval(), constant_op.constant([1]).eval())
def testPrintSparseTensorPassthrough(self):
@@ -43,7 +43,7 @@ class PrettyPrintOpsTest(test.TestCase):
b = sparse_tensor.SparseTensor(
indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
a = prettyprint_ops.print_op(a)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(
sparse_ops.sparse_tensor_to_dense(a).eval(),
sparse_ops.sparse_tensor_to_dense(b).eval())
@@ -54,13 +54,13 @@ class PrettyPrintOpsTest(test.TestCase):
a = a.write(1, 1)
a = a.write(0, 0)
a = prettyprint_ops.print_op(a)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(a.stack().eval(), constant_op.constant([0, 1]).eval())
def testPrintVariable(self):
a = variables.Variable(1.0)
a = prettyprint_ops.print_op(a)
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
a.eval()
diff --git a/tensorflow/contrib/framework/python/ops/sort_ops_test.py b/tensorflow/contrib/framework/python/ops/sort_ops_test.py
index a8fb94b245..791b32cd1e 100644
--- a/tensorflow/contrib/framework/python/ops/sort_ops_test.py
+++ b/tensorflow/contrib/framework/python/ops/sort_ops_test.py
@@ -48,7 +48,7 @@ class SortTest(test.TestCase):
sort_axis = np.random.choice(rank)
if negative_axis:
sort_axis = -1 - sort_axis
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(
np.sort(arr, axis=sort_axis),
sort_ops.sort(constant_op.constant(arr), axis=sort_axis).eval())
@@ -60,7 +60,7 @@ class SortTest(test.TestCase):
shape = [np.random.randint(1, 4) for _ in range(rank)]
arr = np.random.random(shape)
sort_axis = np.random.choice(rank)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(
np.sort(arr, axis=sort_axis),
sort_ops.sort(constant_op.constant(arr), axis=sort_axis).eval())
@@ -73,7 +73,7 @@ class SortTest(test.TestCase):
scalar = array_ops.zeros(zeros_length_1)
sort = sort_ops.sort(scalar)
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(errors.InvalidArgumentError):
sort.eval()
@@ -84,7 +84,7 @@ class SortTest(test.TestCase):
def testDescending(self):
arr = np.random.random((10, 5, 5))
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(
np.sort(arr, axis=0)[::-1],
sort_ops.sort(
@@ -111,7 +111,7 @@ class SortTest(test.TestCase):
def testArgsort_1d(self):
arr = np.random.random(42)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(
np.sort(arr),
array_ops.gather(arr, sort_ops.argsort(arr)).eval())
@@ -119,7 +119,7 @@ class SortTest(test.TestCase):
def testArgsort(self):
arr = np.random.random((5, 6, 7, 8))
for axis in range(4):
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(
np.argsort(arr, axis=axis),
sort_ops.argsort(arr, axis=axis).eval())
diff --git a/tensorflow/contrib/framework/python/ops/variables_test.py b/tensorflow/contrib/framework/python/ops/variables_test.py
index 3c44630a51..f9b0efd1da 100644
--- a/tensorflow/contrib/framework/python/ops/variables_test.py
+++ b/tensorflow/contrib/framework/python/ops/variables_test.py
@@ -45,7 +45,7 @@ from tensorflow.python.training import saver as saver_lib
class LocalVariableTest(test.TestCase):
def test_local_variable(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEquals([], variables_lib.local_variables())
value0 = 42
variables_lib2.local_variable(value0)
@@ -58,7 +58,7 @@ class LocalVariableTest(test.TestCase):
self.assertAllEqual(set([value0, value1]), set(sess.run(variables)))
def testLocalVariableNameAndShape(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.local_variable([1, 1, 1, 1, 1], name='a')
self.assertEquals(a.op.name, 'A/a')
@@ -66,21 +66,21 @@ class LocalVariableTest(test.TestCase):
self.assertListEqual([a], variables_lib2.get_local_variables())
def testLocalVariableNotInAllVariables(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.local_variable(0)
self.assertFalse(a in variables_lib.global_variables())
self.assertTrue(a in variables_lib.local_variables())
def testLocalVariableNotInVariablesToRestore(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.local_variable(0)
self.assertFalse(a in variables_lib2.get_variables_to_restore())
self.assertTrue(a in variables_lib.local_variables())
def testGetVariablesDontReturnsTransients(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
variables_lib2.local_variable(0)
with variable_scope.variable_scope('B'):
@@ -89,7 +89,7 @@ class LocalVariableTest(test.TestCase):
self.assertEquals([], variables_lib2.get_variables('B'))
def testGetLocalVariablesReturnsTransients(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.local_variable(0)
with variable_scope.variable_scope('B'):
@@ -98,7 +98,7 @@ class LocalVariableTest(test.TestCase):
self.assertEquals([b], variables_lib2.get_local_variables('B'))
def testInitializedVariableValue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
a = variables_lib2.local_variable([0, 0, 0, 0, 0], name='a')
sess.run(variables_lib.local_variables_initializer())
self.assertAllEqual(a.eval(), [0] * 5)
@@ -114,7 +114,7 @@ class LocalVariableTest(test.TestCase):
class GlobalVariableTest(test.TestCase):
def test_global_variable(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEquals([], variables_lib.global_variables())
value0 = 42
variables_lib2.global_variable(value0)
@@ -129,7 +129,7 @@ class GlobalVariableTest(test.TestCase):
self.assertAllEqual(set([value0, value1]), set(sess.run(variables)))
def testVariableNameAndShape(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.global_variable([1, 1, 1, 1, 1], name='a')
self.assertEquals(a.op.name, 'A/a')
@@ -137,21 +137,21 @@ class GlobalVariableTest(test.TestCase):
self.assertListEqual([a], variables_lib.global_variables())
def testGlobalVariableNotInLocalVariables(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.global_variable(0)
self.assertFalse(a in variables_lib.local_variables())
self.assertTrue(a in variables_lib.global_variables())
def testGlobalVariableInVariablesToRestore(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.global_variable(0)
self.assertFalse(a in variables_lib.local_variables())
self.assertTrue(a in variables_lib2.get_variables_to_restore())
def testGetVariablesReturnsThem(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.global_variable(0)
with variable_scope.variable_scope('B'):
@@ -160,7 +160,7 @@ class GlobalVariableTest(test.TestCase):
self.assertEquals([b], variables_lib2.get_variables('B'))
def testGetLocalVariablesDontReturnsThem(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
variables_lib2.global_variable(0)
with variable_scope.variable_scope('B'):
@@ -169,7 +169,7 @@ class GlobalVariableTest(test.TestCase):
self.assertEquals([], variables_lib2.get_local_variables('B'))
def testInitializedVariableValue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
a = variables_lib2.global_variable([0, 0, 0, 0, 0], name='a')
sess.run(variables_lib.global_variables_initializer())
self.assertAllEqual(a.eval(), [0] * 5)
@@ -249,7 +249,7 @@ class GlobalStepTest(test.TestCase):
class VariablesTest(test.TestCase):
def testCreateVariable(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.variable('a', [5])
self.assertEquals(a.op.name, 'A/a')
@@ -259,7 +259,7 @@ class VariablesTest(test.TestCase):
self.assertFalse(a in variables_lib.local_variables())
def testGetVariables(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.variable('a', [5])
with variable_scope.variable_scope('B'):
@@ -269,7 +269,7 @@ class VariablesTest(test.TestCase):
self.assertEquals([b], variables_lib2.get_variables('B'))
def testGetVariablesWithScope(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A') as var_scope:
a = variables_lib2.variable('a', [5])
b = variables_lib2.variable('b', [5])
@@ -277,7 +277,7 @@ class VariablesTest(test.TestCase):
set([a, b]), set(variables_lib2.get_variables(var_scope)))
def testGetVariablesSuffix(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.variable('a', [5])
with variable_scope.variable_scope('A'):
@@ -286,13 +286,13 @@ class VariablesTest(test.TestCase):
self.assertEquals([b], variables_lib2.get_variables(suffix='b'))
def testGetVariableWithSingleVar(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('parent'):
a = variables_lib2.variable('child', [5])
self.assertEquals(a, variables_lib2.get_unique_variable('parent/child'))
def testGetVariableWithDistractors(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('parent'):
a = variables_lib2.variable('child', [5])
with variable_scope.variable_scope('child'):
@@ -302,13 +302,13 @@ class VariablesTest(test.TestCase):
def testGetVariableThrowsExceptionWithNoMatch(self):
var_name = 'cant_find_me'
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
variables_lib2.get_unique_variable(var_name)
def testGetThrowsExceptionWithChildrenButNoMatch(self):
var_name = 'parent/child'
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope(var_name):
variables_lib2.variable('grandchild1', [7])
variables_lib2.variable('grandchild2', [9])
@@ -316,7 +316,7 @@ class VariablesTest(test.TestCase):
variables_lib2.get_unique_variable(var_name)
def testGetVariablesToRestore(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.variable('a', [5])
with variable_scope.variable_scope('B'):
@@ -324,7 +324,7 @@ class VariablesTest(test.TestCase):
self.assertEquals([a, b], variables_lib2.get_variables_to_restore())
def testIncludeGetVariablesToRestore(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.variable('a', [5])
with variable_scope.variable_scope('B'):
@@ -333,7 +333,7 @@ class VariablesTest(test.TestCase):
self.assertEquals([a], variables_lib2.get_variables_to_restore(['A']))
def testExcludeGetVariablesToRestore(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.variable('a', [5])
with variable_scope.variable_scope('B'):
@@ -343,7 +343,7 @@ class VariablesTest(test.TestCase):
[a], variables_lib2.get_variables_to_restore(exclude=['B']))
def testWrongIncludeGetVariablesToRestore(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.variable('a', [5])
with variable_scope.variable_scope('B'):
@@ -352,7 +352,7 @@ class VariablesTest(test.TestCase):
self.assertEquals([], variables_lib2.get_variables_to_restore(['a']))
def testGetMixedVariablesToRestore(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.variable('a', [5])
b = variables_lib2.variable('b', [5])
@@ -365,7 +365,7 @@ class VariablesTest(test.TestCase):
variables_lib2.get_variables_to_restore(include=['A/a', 'B/c']))
def testExcludeGetMixedVariablesToRestore(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.variable('a', [5])
b = variables_lib2.variable('b', [5])
@@ -378,7 +378,7 @@ class VariablesTest(test.TestCase):
variables_lib2.get_variables_to_restore(exclude=['A/a', 'B/c']))
def testReuseVariable(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.variable('a', [])
with variable_scope.variable_scope('A', reuse=True):
@@ -387,14 +387,14 @@ class VariablesTest(test.TestCase):
self.assertListEqual([a], variables_lib2.get_variables())
def testVariableWithRegularizer(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.variable('a', [], regularizer=nn_ops.l2_loss)
loss = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)[0]
self.assertDeviceEqual(loss.device, a.device)
def testVariableWithRegularizerColocate(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.variable(
'a', [], device='gpu:0', regularizer=nn_ops.l2_loss)
@@ -402,7 +402,7 @@ class VariablesTest(test.TestCase):
self.assertDeviceEqual(loss.device, a.device)
def testVariableWithDevice(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.variable('a', [], device='cpu:0')
b = variables_lib2.variable('b', [], device='cpu:1')
@@ -410,7 +410,7 @@ class VariablesTest(test.TestCase):
self.assertDeviceEqual(b.device, 'cpu:1')
def testVariableWithDeviceFromScope(self):
- with self.test_session():
+ with self.cached_session():
with ops.device('/cpu:0'):
a = variables_lib2.variable('a', [])
b = variables_lib2.variable('b', [], device='cpu:1')
@@ -428,7 +428,7 @@ class VariablesTest(test.TestCase):
self.counter += 1
return 'cpu:%d' % self.counter
- with self.test_session():
+ with self.cached_session():
with arg_scope([variables_lib2.variable], device=DevFn()):
a = variables_lib2.variable('a', [])
b = variables_lib2.variable('b', [])
@@ -453,7 +453,7 @@ class VariablesTest(test.TestCase):
self.assertDeviceEqual(e.initial_value.device, 'cpu:99')
def testVariableWithReplicaDeviceSetter(self):
- with self.test_session():
+ with self.cached_session():
with ops.device(device_setter.replica_device_setter(ps_tasks=2)):
a = variables_lib2.variable('a', [])
b = variables_lib2.variable('b', [])
@@ -570,7 +570,7 @@ class VariablesTest(test.TestCase):
class ModelVariablesTest(test.TestCase):
def testNameAndShape(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.model_variable('a', [5])
self.assertEquals(a.op.name, 'A/a')
@@ -578,7 +578,7 @@ class ModelVariablesTest(test.TestCase):
self.assertListEqual([a], variables_lib2.get_model_variables('A'))
def testNotInLocalVariables(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.model_variable('a', [5])
self.assertTrue(a in variables_lib.global_variables())
@@ -586,7 +586,7 @@ class ModelVariablesTest(test.TestCase):
self.assertFalse(a in variables_lib.local_variables())
def testGetVariablesReturns(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.model_variable('a', [5])
with variable_scope.variable_scope('B'):
@@ -595,7 +595,7 @@ class ModelVariablesTest(test.TestCase):
self.assertEquals([b], variables_lib2.get_variables('B'))
def testGetModelVariables(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.model_variable('a', [5])
with variable_scope.variable_scope('B'):
@@ -604,7 +604,7 @@ class ModelVariablesTest(test.TestCase):
self.assertEquals([b], variables_lib2.get_model_variables('B'))
def testGetTrainableVariables(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
variables_lib2.local_variable([5])
a = variables_lib.Variable([5])
@@ -615,7 +615,7 @@ class ModelVariablesTest(test.TestCase):
self.assertEquals([b], variables_lib2.get_trainable_variables('B'))
def testGetLocalVariables(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
_ = variables_lib2.model_variable('a', [5])
with variable_scope.variable_scope('B'):
@@ -624,7 +624,7 @@ class ModelVariablesTest(test.TestCase):
self.assertEquals([], variables_lib2.get_local_variables('B'))
def testInitializedVariableValue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
a = variables_lib2.model_variable(
'a', [5], initializer=init_ops.ones_initializer())
sess.run(variables_lib.global_variables_initializer())
@@ -670,14 +670,14 @@ class ModelVariablesTest(test.TestCase):
class GetVariablesCollections(test.TestCase):
def testVariableCollection(self):
- with self.test_session():
+ with self.cached_session():
a = variables_lib2.variable('a', [], collections='A')
b = variables_lib2.variable('b', [], collections='B')
self.assertEquals(a, ops.get_collection('A')[0])
self.assertEquals(b, ops.get_collection('B')[0])
def testVariableCollections(self):
- with self.test_session():
+ with self.cached_session():
a = variables_lib2.variable('a', [], collections=['A', 'C'])
b = variables_lib2.variable('b', [], collections=['B', 'C'])
self.assertEquals(a, ops.get_collection('A')[0])
@@ -685,14 +685,14 @@ class GetVariablesCollections(test.TestCase):
self.assertListEqual([a, b], ops.get_collection('C'))
def testVariableCollectionsWithArgScope(self):
- with self.test_session():
+ with self.cached_session():
with arg_scope([variables_lib2.variable], collections='A'):
a = variables_lib2.variable('a', [])
b = variables_lib2.variable('b', [])
self.assertListEqual([a, b], ops.get_collection('A'))
def testVariableCollectionsWithArgScopeNested(self):
- with self.test_session():
+ with self.cached_session():
with arg_scope([variables_lib2.variable], collections='A'):
a = variables_lib2.variable('a', [])
with arg_scope([variables_lib2.variable], collections='B'):
@@ -701,7 +701,7 @@ class GetVariablesCollections(test.TestCase):
self.assertEquals(b, ops.get_collection('B')[0])
def testVariableCollectionsWithArgScopeNonNested(self):
- with self.test_session():
+ with self.cached_session():
with arg_scope([variables_lib2.variable], collections='A'):
a = variables_lib2.variable('a', [])
with arg_scope([variables_lib2.variable], collections='B'):
@@ -711,7 +711,7 @@ class GetVariablesCollections(test.TestCase):
self.assertListEqual([b], ops.get_collection('B'))
def testVariableRestoreWithArgScopeNested(self):
- with self.test_session():
+ with self.cached_session():
a = variables_lib2.variable('a', [])
with arg_scope(
[variables_lib2.variable], trainable=False, collections=['A', 'B']):
@@ -726,7 +726,7 @@ class GetVariablesCollections(test.TestCase):
class GetVariablesBySuffixTest(test.TestCase):
def testGetVariableGivenNameScoped(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.variable('a', [5])
b = variables_lib2.variable('b', [5])
@@ -734,7 +734,7 @@ class GetVariablesBySuffixTest(test.TestCase):
self.assertEquals([b], variables_lib2.get_variables_by_suffix('b'))
def testGetVariableWithScope(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.variable('a', [5])
fooa = variables_lib2.variable('fooa', [5])
@@ -748,7 +748,7 @@ class GetVariablesBySuffixTest(test.TestCase):
self.assertEquals([a, fooa], matched_variables)
def testGetVariableWithoutScope(self):
- with self.test_session():
+ with self.cached_session():
a = variables_lib2.variable('a', [5])
fooa = variables_lib2.variable('fooa', [5])
b_a = variables_lib2.variable('B/a', [5])
@@ -761,7 +761,7 @@ class GetVariablesBySuffixTest(test.TestCase):
class GetVariablesByNameTest(test.TestCase):
def testGetVariableGivenNameScoped(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.variable('a', [5])
b = variables_lib2.variable('b', [5])
@@ -769,7 +769,7 @@ class GetVariablesByNameTest(test.TestCase):
self.assertEquals([b], variables_lib2.get_variables_by_name('b'))
def testGetVariableWithScope(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.variable('a', [5])
fooa = variables_lib2.variable('fooa', [5])
@@ -785,7 +785,7 @@ class GetVariablesByNameTest(test.TestCase):
self.assertEquals([a], matched_variables)
def testGetVariableWithoutScope(self):
- with self.test_session():
+ with self.cached_session():
a = variables_lib2.variable('a', [5])
fooa = variables_lib2.variable('fooa', [5])
b_a = variables_lib2.variable('B/a', [5])
@@ -818,7 +818,7 @@ class AssignFromValuesTest(test.TestCase):
init_value0 = np.asarray([1.0, 3.0, 9.0]).reshape((1, 3, 1))
init_value1 = np.asarray([2.0, 4.0, 6.0, 8.0]).reshape((2, 1, 2))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
initializer = init_ops.truncated_normal_initializer(stddev=.1)
var0 = variables_lib2.variable(
'my_var0', shape=[1, 3, 1], initializer=initializer)
@@ -844,7 +844,7 @@ class AssignFromValuesTest(test.TestCase):
init_value0 = np.asarray([1.0, 3.0, 9.0]).reshape((1, 3, 1))
init_value1 = np.asarray([2.0, 4.0, 6.0, 8.0]).reshape((2, 1, 2))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
initializer = init_ops.truncated_normal_initializer(stddev=.1)
with variable_scope.variable_scope('my_model/my_layer0'):
@@ -879,7 +879,7 @@ class AssignFromValuesFnTest(test.TestCase):
init_value0 = np.asarray([1.0, 3.0, 9.0]).reshape((1, 3, 1))
init_value1 = np.asarray([2.0, 4.0, 6.0, 8.0]).reshape((2, 1, 2))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
initializer = init_ops.truncated_normal_initializer(stddev=.1)
var0 = variables_lib2.variable(
'my_var0', shape=[1, 3, 1], initializer=initializer)
@@ -904,7 +904,7 @@ class AssignFromValuesFnTest(test.TestCase):
init_value0 = np.asarray([1.0, 3.0, 9.0]).reshape((1, 3, 1))
init_value1 = np.asarray([2.0, 4.0, 6.0, 8.0]).reshape((2, 1, 2))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
initializer = init_ops.truncated_normal_initializer(stddev=.1)
with variable_scope.variable_scope('my_model/my_layer0'):
@@ -968,7 +968,7 @@ class AssignFromCheckpointTest(test.TestCase):
init_value1 = 20.0
var_names_to_values = {'v0': init_value0, 'v1': init_value1}
- with self.test_session() as sess:
+ with self.cached_session() as sess:
model_path = self.create_checkpoint_from_values(var_names_to_values,
model_dir)
var0 = variables_lib2.variable('my_var0', shape=[])
@@ -998,7 +998,7 @@ class AssignFromCheckpointTest(test.TestCase):
init_value1 = np.array([20.0]) # Partitioned into 1 part, edge case.
var_names_to_values = {'var0': init_value0, 'var1': init_value1}
- with self.test_session() as sess:
+ with self.cached_session() as sess:
model_path = self.create_checkpoint_from_values(var_names_to_values,
model_dir)
# var0 and var1 are PartitionedVariables.
@@ -1039,7 +1039,7 @@ class AssignFromCheckpointTest(test.TestCase):
init_value1 = 20.0
var_names_to_values = {'v0': init_value0, 'v1': init_value1}
- with self.test_session():
+ with self.cached_session():
model_path = self.create_checkpoint_from_values(var_names_to_values,
model_dir)
var0 = variables_lib2.variable('my_var0', shape=[])
@@ -1062,7 +1062,7 @@ class AssignFromCheckpointTest(test.TestCase):
var_names_to_values = {'layer0/v0': init_value0, 'layer1/v1': init_value1}
- with self.test_session() as sess:
+ with self.cached_session() as sess:
model_path = self.create_checkpoint_from_values(var_names_to_values,
model_dir)
with variable_scope.variable_scope('my_model/my_layer0'):
@@ -1123,7 +1123,7 @@ class AssignFromCheckpointFnTest(test.TestCase):
init_value1 = 20.0
var_names_to_values = {'v0': init_value0, 'v1': init_value1}
- with self.test_session() as sess:
+ with self.cached_session() as sess:
model_path = self.create_checkpoint_from_values(var_names_to_values,
model_dir)
var0 = variables_lib2.variable('my_var0', shape=[])
@@ -1154,7 +1154,7 @@ class AssignFromCheckpointFnTest(test.TestCase):
init_value1 = 20.0
var_names_to_values = {'v0': init_value0, 'v1': init_value1}
- with self.test_session() as sess:
+ with self.cached_session() as sess:
model_path = self.create_checkpoint_from_values(var_names_to_values,
model_dir)
var0 = variables_lib2.variable('my_var0', shape=[2, 1])
@@ -1183,7 +1183,7 @@ class AssignFromCheckpointFnTest(test.TestCase):
init_value1 = 20.0
var_names_to_values = {'v0': init_value0, 'v1': init_value1}
- with self.test_session() as sess:
+ with self.cached_session() as sess:
model_path = self.create_checkpoint_from_values(var_names_to_values,
model_dir)
var0 = variables_lib2.variable('my_var0', shape=[2, 1])
@@ -1213,7 +1213,7 @@ class AssignFromCheckpointFnTest(test.TestCase):
init_value1 = 20.0
var_names_to_values = {'v0': init_value0, 'v1': init_value1}
- with self.test_session() as sess:
+ with self.cached_session() as sess:
model_path = self.create_checkpoint_from_values(var_names_to_values,
model_dir)
var0 = variables_lib2.variable('my_var0', shape=[])
@@ -1241,7 +1241,7 @@ class AssignFromCheckpointFnTest(test.TestCase):
init_value1 = 20.0
var_names_to_values = {'v0': init_value0, 'v1': init_value1}
- with self.test_session() as sess:
+ with self.cached_session() as sess:
model_path = self.create_checkpoint_from_values(var_names_to_values,
model_dir)
var0 = variables_lib2.variable('v0', shape=[])
@@ -1272,7 +1272,7 @@ class AssignFromCheckpointFnTest(test.TestCase):
init_value1 = 20.0
var_names_to_values = {'v0': init_value0, 'v1': init_value1}
- with self.test_session() as sess:
+ with self.cached_session() as sess:
model_path = self.create_checkpoint_from_values(var_names_to_values,
model_dir)
var0 = variables_lib2.variable('my_var0', shape=[])
@@ -1299,7 +1299,7 @@ class ZeroInitializerOpTest(test.TestCase):
def _testZeroInitializer(self, shape, initializer, use_init):
var = variables_lib.Variable(initializer)
var_zero = variables_lib2.zero_initializer(var)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaisesOpError('Attempting to use uninitialized value'):
var.eval()
if use_init:
@@ -1324,7 +1324,7 @@ class ZeroVarInitializerOpTest(test.TestCase):
var = resource_variable_ops.ResourceVariable(initializer)
var_zero = variables_lib2.zero_initializer(var)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaisesOpError('Error while reading resource variable'):
var.eval()
if use_init: