diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/variables_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/variables_test.py | 58 |
1 files changed, 29 insertions, 29 deletions
diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py index 2b9c62ad6f..2e7975667c 100644 --- a/tensorflow/python/kernel_tests/variables_test.py +++ b/tensorflow/python/kernel_tests/variables_test.py @@ -42,7 +42,7 @@ from tensorflow.python.util import compat class VariablesTestCase(test.TestCase): def testInitialization(self): - with self.test_session(): + with self.cached_session(): var0 = variables.Variable(0.0) self.assertEqual("Variable:0", var0.name) self.assertEqual("Variable", var0._shared_name) @@ -69,7 +69,7 @@ class VariablesTestCase(test.TestCase): self.assertAllClose(1.1, var1.eval()) def testInitializationOrder(self): - with self.test_session(): + with self.cached_session(): rnd = variables.Variable(random_ops.random_uniform([3, 6]), name="rnd") self.assertEqual("rnd:0", rnd.name) self.assertEqual([3, 6], rnd.get_shape()) @@ -106,7 +106,7 @@ class VariablesTestCase(test.TestCase): pass def testAssignments(self): - with self.test_session(): + with self.cached_session(): var = variables.Variable(0.0) plus_one = var.assign_add(1.0) minus_one = var.assign_sub(2.0) @@ -142,7 +142,7 @@ class VariablesTestCase(test.TestCase): self.assertAllClose(4.0, var.eval()) def testZeroSizeStringAssign(self): - with self.test_session() as sess: + with self.cached_session() as sess: array = variables.Variable( initial_value=array_ops.zeros((0,), dtype=dtypes.string), name="foo", @@ -154,7 +154,7 @@ class VariablesTestCase(test.TestCase): self.assertEqual([], list(sess.run(copy_op))) def _countUpToTest(self, dtype): - with self.test_session(): + with self.cached_session(): zero = constant_op.constant(0, dtype=dtype) var = variables.Variable(zero) count_up_to = var.count_up_to(3) @@ -186,7 +186,7 @@ class VariablesTestCase(test.TestCase): self._countUpToTest(dtypes.int64) def testControlDepsNone(self): - with self.test_session(): + with self.cached_session(): c = constant_op.constant(1.0) with ops.control_dependencies([c]): # d get the control dep. @@ -199,7 +199,7 @@ class VariablesTestCase(test.TestCase): self.assertEqual([], var_x._ref().op.control_inputs) # pylint: disable=protected-access def testControlFlow(self): - with self.test_session() as sess: + with self.cached_session() as sess: v0 = variables.Variable(0, name="v0") var_dict = {} @@ -248,7 +248,7 @@ class VariablesTestCase(test.TestCase): control_flow_ops.while_loop(cond, body, [0, 0]) def testUseVariableAsTensor(self): - with self.test_session(): + with self.cached_session(): var_x = variables.Variable(2.0) var_y = variables.Variable(3.0) variables.global_variables_initializer().run() @@ -257,7 +257,7 @@ class VariablesTestCase(test.TestCase): self.assertAllClose(5.0, math_ops.add(var_x, var_y).eval()) def testZeroSizeVarSameAsConst(self): - with self.test_session(): + with self.cached_session(): zero_size_var = variables.Variable(array_ops.zeros([0, 2])) zero_size_const = array_ops.ones([2, 0]) variable_mul = math_ops.matmul(zero_size_const, zero_size_var) @@ -269,7 +269,7 @@ class VariablesTestCase(test.TestCase): self.assertAllClose([[0., 0.], [0., 0.]], variable_output) def testCachingDevice(self): - with self.test_session(): + with self.cached_session(): var = variables.Variable(2.0) self.assertEqual(var.device, var.value().device) self.assertEqual(var.device, var.initialized_value().device) @@ -279,7 +279,7 @@ class VariablesTestCase(test.TestCase): self.assertTrue(var_cached.value().device.startswith("/job:foo")) def testCollections(self): - with self.test_session(): + with self.cached_session(): var_x = variables.Variable(2.0) var_y = variables.Variable(2.0, trainable=False) var_z = variables.Variable(2.0, trainable=True) @@ -294,7 +294,7 @@ class VariablesTestCase(test.TestCase): self.assertEqual([var_x, var_z, var_t], variables.trainable_variables()) def testCollectionsWithScope(self): - with self.test_session(): + with self.cached_session(): with ops.name_scope("scope_1"): var_x = variables.Variable(2.0) with ops.name_scope("scope_2"): @@ -309,7 +309,7 @@ class VariablesTestCase(test.TestCase): self.assertEqual([var_y], variables.trainable_variables("scope_2")) def testOperators(self): - with self.test_session(): + with self.cached_session(): var_f = variables.Variable([2.0]) add = var_f + 0.0 radd = 1.0 + var_f @@ -382,13 +382,13 @@ class VariablesTestCase(test.TestCase): self.assertAllClose([[20.0, 30.0], [40.0, 60.0]], rmatmul.eval()) def testSession(self): - with self.test_session() as sess: + with self.cached_session() as sess: var = variables.Variable([1, 12]) variables.global_variables_initializer().run() self.assertAllClose([1, 12], sess.run(var)) def testDevicePlacement(self): - with self.test_session() as sess: + with self.cached_session() as sess: with ops.device("/cpu:0"): var = variables.Variable([1, 12]) init_value = var.initialized_value() @@ -408,7 +408,7 @@ class VariablesTestCase(test.TestCase): def testInitializerFunction(self): value = [[-42], [133.7]] shape = [2, 1] - with self.test_session(): + with self.cached_session(): initializer = lambda: constant_op.constant(value) v1 = variables.Variable(initializer, dtype=dtypes.float32) @@ -443,7 +443,7 @@ class VariablesTestCase(test.TestCase): constraint=constraint) def testNoRefDataRace(self): - with self.test_session(): + with self.cached_session(): a = variables.Variable([1, 2, 3], dtype=dtypes.float32) b = variables.Variable(a.initialized_value() + 2) c = variables.Variable(b.initialized_value() + 2) @@ -453,7 +453,7 @@ class VariablesTestCase(test.TestCase): self.assertAllEqual(c.eval(), [5, 6, 7]) def testInitializerFunctionDevicePlacement(self): - with self.test_session(): + with self.cached_session(): initializer = lambda: constant_op.constant(42.0) with ops.device("/cpu:100"): v1 = variables.Variable(initializer, dtype=dtypes.float32, name="v1") @@ -471,11 +471,11 @@ class VariablesTestCase(test.TestCase): self.assertEqual(expected_group_v2, i.op.colocation_groups()) def testVariableDefInitializedInstances(self): - with ops.Graph().as_default(), self.test_session() as sess: + with ops.Graph().as_default(), self.cached_session() as sess: v_def = variables.Variable( initial_value=constant_op.constant(3.0)).to_proto() - with ops.Graph().as_default(), self.test_session() as sess: + with ops.Graph().as_default(), self.cached_session() as sess: # v describes a VariableDef-based variable without an initial value. v = variables.Variable(variable_def=v_def) self.assertEqual(3.0, sess.run(v.initialized_value())) @@ -486,7 +486,7 @@ class VariablesTestCase(test.TestCase): self.assertEqual(1.0, v.initialized_value().eval()) v_def.ClearField("initial_value_name") - with ops.Graph().as_default(), self.test_session() as sess: + with ops.Graph().as_default(), self.cached_session() as sess: # Restoring a legacy VariableDef proto that does not have # initial_value_name set should still work. v = variables.Variable(variable_def=v_def) @@ -514,7 +514,7 @@ class VariablesTestCase(test.TestCase): .trainable) def testLoad(self): - with self.test_session(): + with self.cached_session(): var = variables.Variable(np.zeros((5, 5), np.float32)) variables.global_variables_initializer().run() var.load(np.ones((5, 5), np.float32)) @@ -540,12 +540,12 @@ class VariablesTestCase(test.TestCase): class IsInitializedTest(test.TestCase): def testNoVars(self): - with ops.Graph().as_default(), self.test_session() as sess: + with ops.Graph().as_default(), self.cached_session() as sess: uninited = variables.report_uninitialized_variables() self.assertEqual(0, sess.run(uninited).size) def testAssertVariablesInitialized(self): - with ops.Graph().as_default(), self.test_session() as sess: + with ops.Graph().as_default(), self.cached_session() as sess: v = variables.Variable([1, 2], name="v") w = variables.Variable([3, 4], name="w") _ = v, w @@ -555,7 +555,7 @@ class IsInitializedTest(test.TestCase): self.assertEqual(0, sess.run(uninited).size) def testVariableList(self): - with ops.Graph().as_default(), self.test_session() as sess: + with ops.Graph().as_default(), self.cached_session() as sess: v = variables.Variable([1, 2], name="v") w = variables.Variable([3, 4], name="w") uninited = variables.report_uninitialized_variables() @@ -566,14 +566,14 @@ class IsInitializedTest(test.TestCase): self.assertEqual(0, sess.run(uninited).size) def testZeroSizeVarInitialized(self): - with ops.Graph().as_default(), self.test_session() as sess: + with ops.Graph().as_default(), self.cached_session() as sess: v = variables.Variable(array_ops.zeros([0, 2]), name="v") uninited = variables.report_uninitialized_variables() v.initializer.run() # not strictly necessary self.assertEqual(0, sess.run(uninited).size) def testTrainingWithZeroSizeVar(self): - with ops.Graph().as_default(), self.test_session() as sess: + with ops.Graph().as_default(), self.cached_session() as sess: a = variables.Variable(array_ops.zeros([0, 2])) b = variables.Variable(array_ops.ones([2, 2])) objective = math_ops.reduce_sum(b + math_ops.matmul( @@ -592,7 +592,7 @@ class ObsoleteIsInitializedTest(test.TestCase): self.assertEqual(None, variables.assert_variables_initialized()) def testVariables(self): - with ops.Graph().as_default(), self.test_session() as sess: + with ops.Graph().as_default(), self.cached_session() as sess: v = variables.Variable([1, 2]) w = variables.Variable([3, 4]) _ = v, w @@ -603,7 +603,7 @@ class ObsoleteIsInitializedTest(test.TestCase): sess.run(inited) def testVariableList(self): - with ops.Graph().as_default(), self.test_session() as sess: + with ops.Graph().as_default(), self.cached_session() as sess: v = variables.Variable([1, 2]) w = variables.Variable([3, 4]) inited = variables.assert_variables_initialized([v]) |