diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-21 19:09:49 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-21 19:13:28 -0700 |
commit | 754fffb399efa6204bb8aae51ce99042cb2ab18e (patch) | |
tree | 3f3a3ecd5e25bac3a4babd9ca330f63d21fb2918 /tensorflow/contrib/framework | |
parent | 34f07dc58afcbddf3c4387cdf7c49ebb5aacf4dd (diff) |
Move from deprecated self.test_session() to self.cached_session().
self.test_session() has been deprecated in 9962eb5e84b15e309410071b06c2ed2d6148ed44 as its name confuses readers of the test. Moving to cached_session() instead which is more explicit about:
* the fact that the session may be reused.
* the session is not closed even when doing a "with self.test_session()" statement.
PiperOrigin-RevId: 209700634
Diffstat (limited to 'tensorflow/contrib/framework')
5 files changed, 93 insertions, 93 deletions
diff --git a/tensorflow/contrib/framework/python/ops/arg_scope_test.py b/tensorflow/contrib/framework/python/ops/arg_scope_test.py index bcafc1a328..0e6c6f0e2f 100644 --- a/tensorflow/contrib/framework/python/ops/arg_scope_test.py +++ b/tensorflow/contrib/framework/python/ops/arg_scope_test.py @@ -52,7 +52,7 @@ def _key_op(op): class ArgScopeTest(test.TestCase): def testEmptyArgScope(self): - with self.test_session(): + with self.cached_session(): with arg_scope([]) as sc: self.assertEqual(sc, {}) @@ -60,7 +60,7 @@ class ArgScopeTest(test.TestCase): func1_kwargs = {'a': 1, 'b': None, 'c': [1]} key_op = _key_op(func1) func1_scope = {key_op: func1_kwargs.copy()} - with self.test_session(): + with self.cached_session(): with arg_scope([func1], a=1, b=None, c=[1]) as sc1: self.assertEqual(sc1, func1_scope) with arg_scope({}) as sc2: @@ -86,7 +86,7 @@ class ArgScopeTest(test.TestCase): func1_kwargs = {'a': 1, 'b': None, 'c': [1]} key_op = _key_op(func1) current_scope = {key_op: func1_kwargs.copy()} - with self.test_session(): + with self.cached_session(): with arg_scope([func1], a=1, b=None, c=[1]) as scope: self.assertDictEqual(scope, current_scope) @@ -102,7 +102,7 @@ class ArgScopeTest(test.TestCase): key(func1): func1_kwargs.copy(), key(func2): func2_kwargs.copy() } - with self.test_session(): + with self.cached_session(): with arg_scope([func1], a=1, b=None, c=[1]): with arg_scope([func2], b=2, d=[2]) as scope: self.assertDictEqual(scope, current_scope) @@ -111,7 +111,7 @@ class ArgScopeTest(test.TestCase): func1_kwargs = {'a': 1, 'b': None, 'c': [1]} key_op = _key_op(func1) current_scope = {key_op: func1_kwargs.copy()} - with self.test_session(): + with self.cached_session(): with arg_scope([func1], a=1, b=None, c=[1]) as scope1: pass with arg_scope(scope1) as scope: @@ -126,7 +126,7 @@ class ArgScopeTest(test.TestCase): key(func1): func1_kwargs.copy(), key(func2): func2_kwargs.copy() } - with self.test_session(): + with self.cached_session(): with arg_scope([func1], a=1, b=None, c=[1]) as scope1: with arg_scope([func2], b=2, d=[2]) as scope2: pass @@ -140,7 +140,7 @@ class ArgScopeTest(test.TestCase): def testSimpleArgScope(self): func1_args = (0,) func1_kwargs = {'a': 1, 'b': None, 'c': [1]} - with self.test_session(): + with self.cached_session(): with arg_scope([func1], a=1, b=None, c=[1]): args, kwargs = func1(0) self.assertTupleEqual(args, func1_args) @@ -149,7 +149,7 @@ class ArgScopeTest(test.TestCase): def testSimpleArgScopeWithTuple(self): func1_args = (0,) func1_kwargs = {'a': 1, 'b': None, 'c': [1]} - with self.test_session(): + with self.cached_session(): with arg_scope((func1,), a=1, b=None, c=[1]): args, kwargs = func1(0) self.assertTupleEqual(args, func1_args) @@ -240,7 +240,7 @@ class ArgScopeTest(test.TestCase): def testAddArgScopeRaceCondition(self): func4_kwargs = ('a', 'b', 'c', 'd', 'e', 'f', 'g', 'h') for i in range(4): - # redefine the function with different args + # redefine the function with different args @add_arg_scope def func4(a=1, b=2, c=3, d=4, e=5, f=6, g=7, h=8): pass diff --git a/tensorflow/contrib/framework/python/ops/checkpoint_ops_test.py b/tensorflow/contrib/framework/python/ops/checkpoint_ops_test.py index b7b9f5c59e..4036c87b6d 100644 --- a/tensorflow/contrib/framework/python/ops/checkpoint_ops_test.py +++ b/tensorflow/contrib/framework/python/ops/checkpoint_ops_test.py @@ -50,7 +50,7 @@ class LoadMulticlassBiasTest(test.TestCase): bias = variables.Variable( array_ops.reshape(flat_data, (num, dim)), name='bias') save = saver.Saver([bias]) - with self.test_session() as sess: + with self.cached_session() as sess: variables.global_variables_initializer().run() self.bundle_file = os.path.join(test.get_temp_dir(), 'bias_checkpoint') save.save(sess, self.bundle_file) @@ -90,7 +90,7 @@ class LoadMulticlassBiasTest(test.TestCase): initializer=bias_loading_initializer, partitioner=partitioned_variables.fixed_size_partitioner(3)) - with self.test_session(): + with self.cached_session(): variables.global_variables_initializer().run() self.assertAllClose(expected_remapped_bias_vector, remapped_bias_vector.as_tensor().eval()) @@ -109,7 +109,7 @@ class LoadVariableSlotTest(test.TestCase): accum = variables.Variable( array_ops.reshape(flat_data, (num, dim)), name='accum') save = saver.Saver([accum]) - with self.test_session() as sess: + with self.cached_session() as sess: variables.global_variables_initializer().run() self.bundle_file = os.path.join(test.get_temp_dir(), 'accum_checkpoint') save.save(sess, self.bundle_file) @@ -179,7 +179,7 @@ class LoadVariableSlotTest(test.TestCase): shape=[2, 1], initializer=variable_slot_initializer_part_1) - with self.test_session(): + with self.cached_session(): variables.global_variables_initializer().run() self.assertAllClose(expected_remapped_accum_vector_part_0, remapped_accum_vector_part_0.eval()) diff --git a/tensorflow/contrib/framework/python/ops/prettyprint_ops_test.py b/tensorflow/contrib/framework/python/ops/prettyprint_ops_test.py index 50bcbe625d..c104c51fef 100644 --- a/tensorflow/contrib/framework/python/ops/prettyprint_ops_test.py +++ b/tensorflow/contrib/framework/python/ops/prettyprint_ops_test.py @@ -34,7 +34,7 @@ class PrettyPrintOpsTest(test.TestCase): def testPrintTensorPassthrough(self): a = constant_op.constant([1]) a = prettyprint_ops.print_op(a) - with self.test_session(): + with self.cached_session(): self.assertEqual(a.eval(), constant_op.constant([1]).eval()) def testPrintSparseTensorPassthrough(self): @@ -43,7 +43,7 @@ class PrettyPrintOpsTest(test.TestCase): b = sparse_tensor.SparseTensor( indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4]) a = prettyprint_ops.print_op(a) - with self.test_session(): + with self.cached_session(): self.assertAllEqual( sparse_ops.sparse_tensor_to_dense(a).eval(), sparse_ops.sparse_tensor_to_dense(b).eval()) @@ -54,13 +54,13 @@ class PrettyPrintOpsTest(test.TestCase): a = a.write(1, 1) a = a.write(0, 0) a = prettyprint_ops.print_op(a) - with self.test_session(): + with self.cached_session(): self.assertAllEqual(a.stack().eval(), constant_op.constant([0, 1]).eval()) def testPrintVariable(self): a = variables.Variable(1.0) a = prettyprint_ops.print_op(a) - with self.test_session(): + with self.cached_session(): variables.global_variables_initializer().run() a.eval() diff --git a/tensorflow/contrib/framework/python/ops/sort_ops_test.py b/tensorflow/contrib/framework/python/ops/sort_ops_test.py index a8fb94b245..791b32cd1e 100644 --- a/tensorflow/contrib/framework/python/ops/sort_ops_test.py +++ b/tensorflow/contrib/framework/python/ops/sort_ops_test.py @@ -48,7 +48,7 @@ class SortTest(test.TestCase): sort_axis = np.random.choice(rank) if negative_axis: sort_axis = -1 - sort_axis - with self.test_session(): + with self.cached_session(): self.assertAllEqual( np.sort(arr, axis=sort_axis), sort_ops.sort(constant_op.constant(arr), axis=sort_axis).eval()) @@ -60,7 +60,7 @@ class SortTest(test.TestCase): shape = [np.random.randint(1, 4) for _ in range(rank)] arr = np.random.random(shape) sort_axis = np.random.choice(rank) - with self.test_session(): + with self.cached_session(): self.assertAllEqual( np.sort(arr, axis=sort_axis), sort_ops.sort(constant_op.constant(arr), axis=sort_axis).eval()) @@ -73,7 +73,7 @@ class SortTest(test.TestCase): scalar = array_ops.zeros(zeros_length_1) sort = sort_ops.sort(scalar) - with self.test_session(): + with self.cached_session(): with self.assertRaises(errors.InvalidArgumentError): sort.eval() @@ -84,7 +84,7 @@ class SortTest(test.TestCase): def testDescending(self): arr = np.random.random((10, 5, 5)) - with self.test_session(): + with self.cached_session(): self.assertAllEqual( np.sort(arr, axis=0)[::-1], sort_ops.sort( @@ -111,7 +111,7 @@ class SortTest(test.TestCase): def testArgsort_1d(self): arr = np.random.random(42) - with self.test_session(): + with self.cached_session(): self.assertAllEqual( np.sort(arr), array_ops.gather(arr, sort_ops.argsort(arr)).eval()) @@ -119,7 +119,7 @@ class SortTest(test.TestCase): def testArgsort(self): arr = np.random.random((5, 6, 7, 8)) for axis in range(4): - with self.test_session(): + with self.cached_session(): self.assertAllEqual( np.argsort(arr, axis=axis), sort_ops.argsort(arr, axis=axis).eval()) diff --git a/tensorflow/contrib/framework/python/ops/variables_test.py b/tensorflow/contrib/framework/python/ops/variables_test.py index 3c44630a51..f9b0efd1da 100644 --- a/tensorflow/contrib/framework/python/ops/variables_test.py +++ b/tensorflow/contrib/framework/python/ops/variables_test.py @@ -45,7 +45,7 @@ from tensorflow.python.training import saver as saver_lib class LocalVariableTest(test.TestCase): def test_local_variable(self): - with self.test_session() as sess: + with self.cached_session() as sess: self.assertEquals([], variables_lib.local_variables()) value0 = 42 variables_lib2.local_variable(value0) @@ -58,7 +58,7 @@ class LocalVariableTest(test.TestCase): self.assertAllEqual(set([value0, value1]), set(sess.run(variables))) def testLocalVariableNameAndShape(self): - with self.test_session(): + with self.cached_session(): with variable_scope.variable_scope('A'): a = variables_lib2.local_variable([1, 1, 1, 1, 1], name='a') self.assertEquals(a.op.name, 'A/a') @@ -66,21 +66,21 @@ class LocalVariableTest(test.TestCase): self.assertListEqual([a], variables_lib2.get_local_variables()) def testLocalVariableNotInAllVariables(self): - with self.test_session(): + with self.cached_session(): with variable_scope.variable_scope('A'): a = variables_lib2.local_variable(0) self.assertFalse(a in variables_lib.global_variables()) self.assertTrue(a in variables_lib.local_variables()) def testLocalVariableNotInVariablesToRestore(self): - with self.test_session(): + with self.cached_session(): with variable_scope.variable_scope('A'): a = variables_lib2.local_variable(0) self.assertFalse(a in variables_lib2.get_variables_to_restore()) self.assertTrue(a in variables_lib.local_variables()) def testGetVariablesDontReturnsTransients(self): - with self.test_session(): + with self.cached_session(): with variable_scope.variable_scope('A'): variables_lib2.local_variable(0) with variable_scope.variable_scope('B'): @@ -89,7 +89,7 @@ class LocalVariableTest(test.TestCase): self.assertEquals([], variables_lib2.get_variables('B')) def testGetLocalVariablesReturnsTransients(self): - with self.test_session(): + with self.cached_session(): with variable_scope.variable_scope('A'): a = variables_lib2.local_variable(0) with variable_scope.variable_scope('B'): @@ -98,7 +98,7 @@ class LocalVariableTest(test.TestCase): self.assertEquals([b], variables_lib2.get_local_variables('B')) def testInitializedVariableValue(self): - with self.test_session() as sess: + with self.cached_session() as sess: a = variables_lib2.local_variable([0, 0, 0, 0, 0], name='a') sess.run(variables_lib.local_variables_initializer()) self.assertAllEqual(a.eval(), [0] * 5) @@ -114,7 +114,7 @@ class LocalVariableTest(test.TestCase): class GlobalVariableTest(test.TestCase): def test_global_variable(self): - with self.test_session() as sess: + with self.cached_session() as sess: self.assertEquals([], variables_lib.global_variables()) value0 = 42 variables_lib2.global_variable(value0) @@ -129,7 +129,7 @@ class GlobalVariableTest(test.TestCase): self.assertAllEqual(set([value0, value1]), set(sess.run(variables))) def testVariableNameAndShape(self): - with self.test_session(): + with self.cached_session(): with variable_scope.variable_scope('A'): a = variables_lib2.global_variable([1, 1, 1, 1, 1], name='a') self.assertEquals(a.op.name, 'A/a') @@ -137,21 +137,21 @@ class GlobalVariableTest(test.TestCase): self.assertListEqual([a], variables_lib.global_variables()) def testGlobalVariableNotInLocalVariables(self): - with self.test_session(): + with self.cached_session(): with variable_scope.variable_scope('A'): a = variables_lib2.global_variable(0) self.assertFalse(a in variables_lib.local_variables()) self.assertTrue(a in variables_lib.global_variables()) def testGlobalVariableInVariablesToRestore(self): - with self.test_session(): + with self.cached_session(): with variable_scope.variable_scope('A'): a = variables_lib2.global_variable(0) self.assertFalse(a in variables_lib.local_variables()) self.assertTrue(a in variables_lib2.get_variables_to_restore()) def testGetVariablesReturnsThem(self): - with self.test_session(): + with self.cached_session(): with variable_scope.variable_scope('A'): a = variables_lib2.global_variable(0) with variable_scope.variable_scope('B'): @@ -160,7 +160,7 @@ class GlobalVariableTest(test.TestCase): self.assertEquals([b], variables_lib2.get_variables('B')) def testGetLocalVariablesDontReturnsThem(self): - with self.test_session(): + with self.cached_session(): with variable_scope.variable_scope('A'): variables_lib2.global_variable(0) with variable_scope.variable_scope('B'): @@ -169,7 +169,7 @@ class GlobalVariableTest(test.TestCase): self.assertEquals([], variables_lib2.get_local_variables('B')) def testInitializedVariableValue(self): - with self.test_session() as sess: + with self.cached_session() as sess: a = variables_lib2.global_variable([0, 0, 0, 0, 0], name='a') sess.run(variables_lib.global_variables_initializer()) self.assertAllEqual(a.eval(), [0] * 5) @@ -249,7 +249,7 @@ class GlobalStepTest(test.TestCase): class VariablesTest(test.TestCase): def testCreateVariable(self): - with self.test_session(): + with self.cached_session(): with variable_scope.variable_scope('A'): a = variables_lib2.variable('a', [5]) self.assertEquals(a.op.name, 'A/a') @@ -259,7 +259,7 @@ class VariablesTest(test.TestCase): self.assertFalse(a in variables_lib.local_variables()) def testGetVariables(self): - with self.test_session(): + with self.cached_session(): with variable_scope.variable_scope('A'): a = variables_lib2.variable('a', [5]) with variable_scope.variable_scope('B'): @@ -269,7 +269,7 @@ class VariablesTest(test.TestCase): self.assertEquals([b], variables_lib2.get_variables('B')) def testGetVariablesWithScope(self): - with self.test_session(): + with self.cached_session(): with variable_scope.variable_scope('A') as var_scope: a = variables_lib2.variable('a', [5]) b = variables_lib2.variable('b', [5]) @@ -277,7 +277,7 @@ class VariablesTest(test.TestCase): set([a, b]), set(variables_lib2.get_variables(var_scope))) def testGetVariablesSuffix(self): - with self.test_session(): + with self.cached_session(): with variable_scope.variable_scope('A'): a = variables_lib2.variable('a', [5]) with variable_scope.variable_scope('A'): @@ -286,13 +286,13 @@ class VariablesTest(test.TestCase): self.assertEquals([b], variables_lib2.get_variables(suffix='b')) def testGetVariableWithSingleVar(self): - with self.test_session(): + with self.cached_session(): with variable_scope.variable_scope('parent'): a = variables_lib2.variable('child', [5]) self.assertEquals(a, variables_lib2.get_unique_variable('parent/child')) def testGetVariableWithDistractors(self): - with self.test_session(): + with self.cached_session(): with variable_scope.variable_scope('parent'): a = variables_lib2.variable('child', [5]) with variable_scope.variable_scope('child'): @@ -302,13 +302,13 @@ class VariablesTest(test.TestCase): def testGetVariableThrowsExceptionWithNoMatch(self): var_name = 'cant_find_me' - with self.test_session(): + with self.cached_session(): with self.assertRaises(ValueError): variables_lib2.get_unique_variable(var_name) def testGetThrowsExceptionWithChildrenButNoMatch(self): var_name = 'parent/child' - with self.test_session(): + with self.cached_session(): with variable_scope.variable_scope(var_name): variables_lib2.variable('grandchild1', [7]) variables_lib2.variable('grandchild2', [9]) @@ -316,7 +316,7 @@ class VariablesTest(test.TestCase): variables_lib2.get_unique_variable(var_name) def testGetVariablesToRestore(self): - with self.test_session(): + with self.cached_session(): with variable_scope.variable_scope('A'): a = variables_lib2.variable('a', [5]) with variable_scope.variable_scope('B'): @@ -324,7 +324,7 @@ class VariablesTest(test.TestCase): self.assertEquals([a, b], variables_lib2.get_variables_to_restore()) def testIncludeGetVariablesToRestore(self): - with self.test_session(): + with self.cached_session(): with variable_scope.variable_scope('A'): a = variables_lib2.variable('a', [5]) with variable_scope.variable_scope('B'): @@ -333,7 +333,7 @@ class VariablesTest(test.TestCase): self.assertEquals([a], variables_lib2.get_variables_to_restore(['A'])) def testExcludeGetVariablesToRestore(self): - with self.test_session(): + with self.cached_session(): with variable_scope.variable_scope('A'): a = variables_lib2.variable('a', [5]) with variable_scope.variable_scope('B'): @@ -343,7 +343,7 @@ class VariablesTest(test.TestCase): [a], variables_lib2.get_variables_to_restore(exclude=['B'])) def testWrongIncludeGetVariablesToRestore(self): - with self.test_session(): + with self.cached_session(): with variable_scope.variable_scope('A'): a = variables_lib2.variable('a', [5]) with variable_scope.variable_scope('B'): @@ -352,7 +352,7 @@ class VariablesTest(test.TestCase): self.assertEquals([], variables_lib2.get_variables_to_restore(['a'])) def testGetMixedVariablesToRestore(self): - with self.test_session(): + with self.cached_session(): with variable_scope.variable_scope('A'): a = variables_lib2.variable('a', [5]) b = variables_lib2.variable('b', [5]) @@ -365,7 +365,7 @@ class VariablesTest(test.TestCase): variables_lib2.get_variables_to_restore(include=['A/a', 'B/c'])) def testExcludeGetMixedVariablesToRestore(self): - with self.test_session(): + with self.cached_session(): with variable_scope.variable_scope('A'): a = variables_lib2.variable('a', [5]) b = variables_lib2.variable('b', [5]) @@ -378,7 +378,7 @@ class VariablesTest(test.TestCase): variables_lib2.get_variables_to_restore(exclude=['A/a', 'B/c'])) def testReuseVariable(self): - with self.test_session(): + with self.cached_session(): with variable_scope.variable_scope('A'): a = variables_lib2.variable('a', []) with variable_scope.variable_scope('A', reuse=True): @@ -387,14 +387,14 @@ class VariablesTest(test.TestCase): self.assertListEqual([a], variables_lib2.get_variables()) def testVariableWithRegularizer(self): - with self.test_session(): + with self.cached_session(): with variable_scope.variable_scope('A'): a = variables_lib2.variable('a', [], regularizer=nn_ops.l2_loss) loss = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)[0] self.assertDeviceEqual(loss.device, a.device) def testVariableWithRegularizerColocate(self): - with self.test_session(): + with self.cached_session(): with variable_scope.variable_scope('A'): a = variables_lib2.variable( 'a', [], device='gpu:0', regularizer=nn_ops.l2_loss) @@ -402,7 +402,7 @@ class VariablesTest(test.TestCase): self.assertDeviceEqual(loss.device, a.device) def testVariableWithDevice(self): - with self.test_session(): + with self.cached_session(): with variable_scope.variable_scope('A'): a = variables_lib2.variable('a', [], device='cpu:0') b = variables_lib2.variable('b', [], device='cpu:1') @@ -410,7 +410,7 @@ class VariablesTest(test.TestCase): self.assertDeviceEqual(b.device, 'cpu:1') def testVariableWithDeviceFromScope(self): - with self.test_session(): + with self.cached_session(): with ops.device('/cpu:0'): a = variables_lib2.variable('a', []) b = variables_lib2.variable('b', [], device='cpu:1') @@ -428,7 +428,7 @@ class VariablesTest(test.TestCase): self.counter += 1 return 'cpu:%d' % self.counter - with self.test_session(): + with self.cached_session(): with arg_scope([variables_lib2.variable], device=DevFn()): a = variables_lib2.variable('a', []) b = variables_lib2.variable('b', []) @@ -453,7 +453,7 @@ class VariablesTest(test.TestCase): self.assertDeviceEqual(e.initial_value.device, 'cpu:99') def testVariableWithReplicaDeviceSetter(self): - with self.test_session(): + with self.cached_session(): with ops.device(device_setter.replica_device_setter(ps_tasks=2)): a = variables_lib2.variable('a', []) b = variables_lib2.variable('b', []) @@ -570,7 +570,7 @@ class VariablesTest(test.TestCase): class ModelVariablesTest(test.TestCase): def testNameAndShape(self): - with self.test_session(): + with self.cached_session(): with variable_scope.variable_scope('A'): a = variables_lib2.model_variable('a', [5]) self.assertEquals(a.op.name, 'A/a') @@ -578,7 +578,7 @@ class ModelVariablesTest(test.TestCase): self.assertListEqual([a], variables_lib2.get_model_variables('A')) def testNotInLocalVariables(self): - with self.test_session(): + with self.cached_session(): with variable_scope.variable_scope('A'): a = variables_lib2.model_variable('a', [5]) self.assertTrue(a in variables_lib.global_variables()) @@ -586,7 +586,7 @@ class ModelVariablesTest(test.TestCase): self.assertFalse(a in variables_lib.local_variables()) def testGetVariablesReturns(self): - with self.test_session(): + with self.cached_session(): with variable_scope.variable_scope('A'): a = variables_lib2.model_variable('a', [5]) with variable_scope.variable_scope('B'): @@ -595,7 +595,7 @@ class ModelVariablesTest(test.TestCase): self.assertEquals([b], variables_lib2.get_variables('B')) def testGetModelVariables(self): - with self.test_session(): + with self.cached_session(): with variable_scope.variable_scope('A'): a = variables_lib2.model_variable('a', [5]) with variable_scope.variable_scope('B'): @@ -604,7 +604,7 @@ class ModelVariablesTest(test.TestCase): self.assertEquals([b], variables_lib2.get_model_variables('B')) def testGetTrainableVariables(self): - with self.test_session(): + with self.cached_session(): with variable_scope.variable_scope('A'): variables_lib2.local_variable([5]) a = variables_lib.Variable([5]) @@ -615,7 +615,7 @@ class ModelVariablesTest(test.TestCase): self.assertEquals([b], variables_lib2.get_trainable_variables('B')) def testGetLocalVariables(self): - with self.test_session(): + with self.cached_session(): with variable_scope.variable_scope('A'): _ = variables_lib2.model_variable('a', [5]) with variable_scope.variable_scope('B'): @@ -624,7 +624,7 @@ class ModelVariablesTest(test.TestCase): self.assertEquals([], variables_lib2.get_local_variables('B')) def testInitializedVariableValue(self): - with self.test_session() as sess: + with self.cached_session() as sess: a = variables_lib2.model_variable( 'a', [5], initializer=init_ops.ones_initializer()) sess.run(variables_lib.global_variables_initializer()) @@ -670,14 +670,14 @@ class ModelVariablesTest(test.TestCase): class GetVariablesCollections(test.TestCase): def testVariableCollection(self): - with self.test_session(): + with self.cached_session(): a = variables_lib2.variable('a', [], collections='A') b = variables_lib2.variable('b', [], collections='B') self.assertEquals(a, ops.get_collection('A')[0]) self.assertEquals(b, ops.get_collection('B')[0]) def testVariableCollections(self): - with self.test_session(): + with self.cached_session(): a = variables_lib2.variable('a', [], collections=['A', 'C']) b = variables_lib2.variable('b', [], collections=['B', 'C']) self.assertEquals(a, ops.get_collection('A')[0]) @@ -685,14 +685,14 @@ class GetVariablesCollections(test.TestCase): self.assertListEqual([a, b], ops.get_collection('C')) def testVariableCollectionsWithArgScope(self): - with self.test_session(): + with self.cached_session(): with arg_scope([variables_lib2.variable], collections='A'): a = variables_lib2.variable('a', []) b = variables_lib2.variable('b', []) self.assertListEqual([a, b], ops.get_collection('A')) def testVariableCollectionsWithArgScopeNested(self): - with self.test_session(): + with self.cached_session(): with arg_scope([variables_lib2.variable], collections='A'): a = variables_lib2.variable('a', []) with arg_scope([variables_lib2.variable], collections='B'): @@ -701,7 +701,7 @@ class GetVariablesCollections(test.TestCase): self.assertEquals(b, ops.get_collection('B')[0]) def testVariableCollectionsWithArgScopeNonNested(self): - with self.test_session(): + with self.cached_session(): with arg_scope([variables_lib2.variable], collections='A'): a = variables_lib2.variable('a', []) with arg_scope([variables_lib2.variable], collections='B'): @@ -711,7 +711,7 @@ class GetVariablesCollections(test.TestCase): self.assertListEqual([b], ops.get_collection('B')) def testVariableRestoreWithArgScopeNested(self): - with self.test_session(): + with self.cached_session(): a = variables_lib2.variable('a', []) with arg_scope( [variables_lib2.variable], trainable=False, collections=['A', 'B']): @@ -726,7 +726,7 @@ class GetVariablesCollections(test.TestCase): class GetVariablesBySuffixTest(test.TestCase): def testGetVariableGivenNameScoped(self): - with self.test_session(): + with self.cached_session(): with variable_scope.variable_scope('A'): a = variables_lib2.variable('a', [5]) b = variables_lib2.variable('b', [5]) @@ -734,7 +734,7 @@ class GetVariablesBySuffixTest(test.TestCase): self.assertEquals([b], variables_lib2.get_variables_by_suffix('b')) def testGetVariableWithScope(self): - with self.test_session(): + with self.cached_session(): with variable_scope.variable_scope('A'): a = variables_lib2.variable('a', [5]) fooa = variables_lib2.variable('fooa', [5]) @@ -748,7 +748,7 @@ class GetVariablesBySuffixTest(test.TestCase): self.assertEquals([a, fooa], matched_variables) def testGetVariableWithoutScope(self): - with self.test_session(): + with self.cached_session(): a = variables_lib2.variable('a', [5]) fooa = variables_lib2.variable('fooa', [5]) b_a = variables_lib2.variable('B/a', [5]) @@ -761,7 +761,7 @@ class GetVariablesBySuffixTest(test.TestCase): class GetVariablesByNameTest(test.TestCase): def testGetVariableGivenNameScoped(self): - with self.test_session(): + with self.cached_session(): with variable_scope.variable_scope('A'): a = variables_lib2.variable('a', [5]) b = variables_lib2.variable('b', [5]) @@ -769,7 +769,7 @@ class GetVariablesByNameTest(test.TestCase): self.assertEquals([b], variables_lib2.get_variables_by_name('b')) def testGetVariableWithScope(self): - with self.test_session(): + with self.cached_session(): with variable_scope.variable_scope('A'): a = variables_lib2.variable('a', [5]) fooa = variables_lib2.variable('fooa', [5]) @@ -785,7 +785,7 @@ class GetVariablesByNameTest(test.TestCase): self.assertEquals([a], matched_variables) def testGetVariableWithoutScope(self): - with self.test_session(): + with self.cached_session(): a = variables_lib2.variable('a', [5]) fooa = variables_lib2.variable('fooa', [5]) b_a = variables_lib2.variable('B/a', [5]) @@ -818,7 +818,7 @@ class AssignFromValuesTest(test.TestCase): init_value0 = np.asarray([1.0, 3.0, 9.0]).reshape((1, 3, 1)) init_value1 = np.asarray([2.0, 4.0, 6.0, 8.0]).reshape((2, 1, 2)) - with self.test_session() as sess: + with self.cached_session() as sess: initializer = init_ops.truncated_normal_initializer(stddev=.1) var0 = variables_lib2.variable( 'my_var0', shape=[1, 3, 1], initializer=initializer) @@ -844,7 +844,7 @@ class AssignFromValuesTest(test.TestCase): init_value0 = np.asarray([1.0, 3.0, 9.0]).reshape((1, 3, 1)) init_value1 = np.asarray([2.0, 4.0, 6.0, 8.0]).reshape((2, 1, 2)) - with self.test_session() as sess: + with self.cached_session() as sess: initializer = init_ops.truncated_normal_initializer(stddev=.1) with variable_scope.variable_scope('my_model/my_layer0'): @@ -879,7 +879,7 @@ class AssignFromValuesFnTest(test.TestCase): init_value0 = np.asarray([1.0, 3.0, 9.0]).reshape((1, 3, 1)) init_value1 = np.asarray([2.0, 4.0, 6.0, 8.0]).reshape((2, 1, 2)) - with self.test_session() as sess: + with self.cached_session() as sess: initializer = init_ops.truncated_normal_initializer(stddev=.1) var0 = variables_lib2.variable( 'my_var0', shape=[1, 3, 1], initializer=initializer) @@ -904,7 +904,7 @@ class AssignFromValuesFnTest(test.TestCase): init_value0 = np.asarray([1.0, 3.0, 9.0]).reshape((1, 3, 1)) init_value1 = np.asarray([2.0, 4.0, 6.0, 8.0]).reshape((2, 1, 2)) - with self.test_session() as sess: + with self.cached_session() as sess: initializer = init_ops.truncated_normal_initializer(stddev=.1) with variable_scope.variable_scope('my_model/my_layer0'): @@ -968,7 +968,7 @@ class AssignFromCheckpointTest(test.TestCase): init_value1 = 20.0 var_names_to_values = {'v0': init_value0, 'v1': init_value1} - with self.test_session() as sess: + with self.cached_session() as sess: model_path = self.create_checkpoint_from_values(var_names_to_values, model_dir) var0 = variables_lib2.variable('my_var0', shape=[]) @@ -998,7 +998,7 @@ class AssignFromCheckpointTest(test.TestCase): init_value1 = np.array([20.0]) # Partitioned into 1 part, edge case. var_names_to_values = {'var0': init_value0, 'var1': init_value1} - with self.test_session() as sess: + with self.cached_session() as sess: model_path = self.create_checkpoint_from_values(var_names_to_values, model_dir) # var0 and var1 are PartitionedVariables. @@ -1039,7 +1039,7 @@ class AssignFromCheckpointTest(test.TestCase): init_value1 = 20.0 var_names_to_values = {'v0': init_value0, 'v1': init_value1} - with self.test_session(): + with self.cached_session(): model_path = self.create_checkpoint_from_values(var_names_to_values, model_dir) var0 = variables_lib2.variable('my_var0', shape=[]) @@ -1062,7 +1062,7 @@ class AssignFromCheckpointTest(test.TestCase): var_names_to_values = {'layer0/v0': init_value0, 'layer1/v1': init_value1} - with self.test_session() as sess: + with self.cached_session() as sess: model_path = self.create_checkpoint_from_values(var_names_to_values, model_dir) with variable_scope.variable_scope('my_model/my_layer0'): @@ -1123,7 +1123,7 @@ class AssignFromCheckpointFnTest(test.TestCase): init_value1 = 20.0 var_names_to_values = {'v0': init_value0, 'v1': init_value1} - with self.test_session() as sess: + with self.cached_session() as sess: model_path = self.create_checkpoint_from_values(var_names_to_values, model_dir) var0 = variables_lib2.variable('my_var0', shape=[]) @@ -1154,7 +1154,7 @@ class AssignFromCheckpointFnTest(test.TestCase): init_value1 = 20.0 var_names_to_values = {'v0': init_value0, 'v1': init_value1} - with self.test_session() as sess: + with self.cached_session() as sess: model_path = self.create_checkpoint_from_values(var_names_to_values, model_dir) var0 = variables_lib2.variable('my_var0', shape=[2, 1]) @@ -1183,7 +1183,7 @@ class AssignFromCheckpointFnTest(test.TestCase): init_value1 = 20.0 var_names_to_values = {'v0': init_value0, 'v1': init_value1} - with self.test_session() as sess: + with self.cached_session() as sess: model_path = self.create_checkpoint_from_values(var_names_to_values, model_dir) var0 = variables_lib2.variable('my_var0', shape=[2, 1]) @@ -1213,7 +1213,7 @@ class AssignFromCheckpointFnTest(test.TestCase): init_value1 = 20.0 var_names_to_values = {'v0': init_value0, 'v1': init_value1} - with self.test_session() as sess: + with self.cached_session() as sess: model_path = self.create_checkpoint_from_values(var_names_to_values, model_dir) var0 = variables_lib2.variable('my_var0', shape=[]) @@ -1241,7 +1241,7 @@ class AssignFromCheckpointFnTest(test.TestCase): init_value1 = 20.0 var_names_to_values = {'v0': init_value0, 'v1': init_value1} - with self.test_session() as sess: + with self.cached_session() as sess: model_path = self.create_checkpoint_from_values(var_names_to_values, model_dir) var0 = variables_lib2.variable('v0', shape=[]) @@ -1272,7 +1272,7 @@ class AssignFromCheckpointFnTest(test.TestCase): init_value1 = 20.0 var_names_to_values = {'v0': init_value0, 'v1': init_value1} - with self.test_session() as sess: + with self.cached_session() as sess: model_path = self.create_checkpoint_from_values(var_names_to_values, model_dir) var0 = variables_lib2.variable('my_var0', shape=[]) @@ -1299,7 +1299,7 @@ class ZeroInitializerOpTest(test.TestCase): def _testZeroInitializer(self, shape, initializer, use_init): var = variables_lib.Variable(initializer) var_zero = variables_lib2.zero_initializer(var) - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaisesOpError('Attempting to use uninitialized value'): var.eval() if use_init: @@ -1324,7 +1324,7 @@ class ZeroVarInitializerOpTest(test.TestCase): var = resource_variable_ops.ResourceVariable(initializer) var_zero = variables_lib2.zero_initializer(var) - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaisesOpError('Error while reading resource variable'): var.eval() if use_init: |