aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/variables_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/variables_test.py')
-rw-r--r--tensorflow/python/kernel_tests/variables_test.py58
1 files changed, 29 insertions, 29 deletions
diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py
index 2b9c62ad6f..2e7975667c 100644
--- a/tensorflow/python/kernel_tests/variables_test.py
+++ b/tensorflow/python/kernel_tests/variables_test.py
@@ -42,7 +42,7 @@ from tensorflow.python.util import compat
class VariablesTestCase(test.TestCase):
def testInitialization(self):
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable(0.0)
self.assertEqual("Variable:0", var0.name)
self.assertEqual("Variable", var0._shared_name)
@@ -69,7 +69,7 @@ class VariablesTestCase(test.TestCase):
self.assertAllClose(1.1, var1.eval())
def testInitializationOrder(self):
- with self.test_session():
+ with self.cached_session():
rnd = variables.Variable(random_ops.random_uniform([3, 6]), name="rnd")
self.assertEqual("rnd:0", rnd.name)
self.assertEqual([3, 6], rnd.get_shape())
@@ -106,7 +106,7 @@ class VariablesTestCase(test.TestCase):
pass
def testAssignments(self):
- with self.test_session():
+ with self.cached_session():
var = variables.Variable(0.0)
plus_one = var.assign_add(1.0)
minus_one = var.assign_sub(2.0)
@@ -142,7 +142,7 @@ class VariablesTestCase(test.TestCase):
self.assertAllClose(4.0, var.eval())
def testZeroSizeStringAssign(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
array = variables.Variable(
initial_value=array_ops.zeros((0,), dtype=dtypes.string),
name="foo",
@@ -154,7 +154,7 @@ class VariablesTestCase(test.TestCase):
self.assertEqual([], list(sess.run(copy_op)))
def _countUpToTest(self, dtype):
- with self.test_session():
+ with self.cached_session():
zero = constant_op.constant(0, dtype=dtype)
var = variables.Variable(zero)
count_up_to = var.count_up_to(3)
@@ -186,7 +186,7 @@ class VariablesTestCase(test.TestCase):
self._countUpToTest(dtypes.int64)
def testControlDepsNone(self):
- with self.test_session():
+ with self.cached_session():
c = constant_op.constant(1.0)
with ops.control_dependencies([c]):
# d get the control dep.
@@ -199,7 +199,7 @@ class VariablesTestCase(test.TestCase):
self.assertEqual([], var_x._ref().op.control_inputs) # pylint: disable=protected-access
def testControlFlow(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
v0 = variables.Variable(0, name="v0")
var_dict = {}
@@ -248,7 +248,7 @@ class VariablesTestCase(test.TestCase):
control_flow_ops.while_loop(cond, body, [0, 0])
def testUseVariableAsTensor(self):
- with self.test_session():
+ with self.cached_session():
var_x = variables.Variable(2.0)
var_y = variables.Variable(3.0)
variables.global_variables_initializer().run()
@@ -257,7 +257,7 @@ class VariablesTestCase(test.TestCase):
self.assertAllClose(5.0, math_ops.add(var_x, var_y).eval())
def testZeroSizeVarSameAsConst(self):
- with self.test_session():
+ with self.cached_session():
zero_size_var = variables.Variable(array_ops.zeros([0, 2]))
zero_size_const = array_ops.ones([2, 0])
variable_mul = math_ops.matmul(zero_size_const, zero_size_var)
@@ -269,7 +269,7 @@ class VariablesTestCase(test.TestCase):
self.assertAllClose([[0., 0.], [0., 0.]], variable_output)
def testCachingDevice(self):
- with self.test_session():
+ with self.cached_session():
var = variables.Variable(2.0)
self.assertEqual(var.device, var.value().device)
self.assertEqual(var.device, var.initialized_value().device)
@@ -279,7 +279,7 @@ class VariablesTestCase(test.TestCase):
self.assertTrue(var_cached.value().device.startswith("/job:foo"))
def testCollections(self):
- with self.test_session():
+ with self.cached_session():
var_x = variables.Variable(2.0)
var_y = variables.Variable(2.0, trainable=False)
var_z = variables.Variable(2.0, trainable=True)
@@ -294,7 +294,7 @@ class VariablesTestCase(test.TestCase):
self.assertEqual([var_x, var_z, var_t], variables.trainable_variables())
def testCollectionsWithScope(self):
- with self.test_session():
+ with self.cached_session():
with ops.name_scope("scope_1"):
var_x = variables.Variable(2.0)
with ops.name_scope("scope_2"):
@@ -309,7 +309,7 @@ class VariablesTestCase(test.TestCase):
self.assertEqual([var_y], variables.trainable_variables("scope_2"))
def testOperators(self):
- with self.test_session():
+ with self.cached_session():
var_f = variables.Variable([2.0])
add = var_f + 0.0
radd = 1.0 + var_f
@@ -382,13 +382,13 @@ class VariablesTestCase(test.TestCase):
self.assertAllClose([[20.0, 30.0], [40.0, 60.0]], rmatmul.eval())
def testSession(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
var = variables.Variable([1, 12])
variables.global_variables_initializer().run()
self.assertAllClose([1, 12], sess.run(var))
def testDevicePlacement(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with ops.device("/cpu:0"):
var = variables.Variable([1, 12])
init_value = var.initialized_value()
@@ -408,7 +408,7 @@ class VariablesTestCase(test.TestCase):
def testInitializerFunction(self):
value = [[-42], [133.7]]
shape = [2, 1]
- with self.test_session():
+ with self.cached_session():
initializer = lambda: constant_op.constant(value)
v1 = variables.Variable(initializer, dtype=dtypes.float32)
@@ -443,7 +443,7 @@ class VariablesTestCase(test.TestCase):
constraint=constraint)
def testNoRefDataRace(self):
- with self.test_session():
+ with self.cached_session():
a = variables.Variable([1, 2, 3], dtype=dtypes.float32)
b = variables.Variable(a.initialized_value() + 2)
c = variables.Variable(b.initialized_value() + 2)
@@ -453,7 +453,7 @@ class VariablesTestCase(test.TestCase):
self.assertAllEqual(c.eval(), [5, 6, 7])
def testInitializerFunctionDevicePlacement(self):
- with self.test_session():
+ with self.cached_session():
initializer = lambda: constant_op.constant(42.0)
with ops.device("/cpu:100"):
v1 = variables.Variable(initializer, dtype=dtypes.float32, name="v1")
@@ -471,11 +471,11 @@ class VariablesTestCase(test.TestCase):
self.assertEqual(expected_group_v2, i.op.colocation_groups())
def testVariableDefInitializedInstances(self):
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
v_def = variables.Variable(
initial_value=constant_op.constant(3.0)).to_proto()
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
# v describes a VariableDef-based variable without an initial value.
v = variables.Variable(variable_def=v_def)
self.assertEqual(3.0, sess.run(v.initialized_value()))
@@ -486,7 +486,7 @@ class VariablesTestCase(test.TestCase):
self.assertEqual(1.0, v.initialized_value().eval())
v_def.ClearField("initial_value_name")
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
# Restoring a legacy VariableDef proto that does not have
# initial_value_name set should still work.
v = variables.Variable(variable_def=v_def)
@@ -514,7 +514,7 @@ class VariablesTestCase(test.TestCase):
.trainable)
def testLoad(self):
- with self.test_session():
+ with self.cached_session():
var = variables.Variable(np.zeros((5, 5), np.float32))
variables.global_variables_initializer().run()
var.load(np.ones((5, 5), np.float32))
@@ -540,12 +540,12 @@ class VariablesTestCase(test.TestCase):
class IsInitializedTest(test.TestCase):
def testNoVars(self):
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
uninited = variables.report_uninitialized_variables()
self.assertEqual(0, sess.run(uninited).size)
def testAssertVariablesInitialized(self):
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
v = variables.Variable([1, 2], name="v")
w = variables.Variable([3, 4], name="w")
_ = v, w
@@ -555,7 +555,7 @@ class IsInitializedTest(test.TestCase):
self.assertEqual(0, sess.run(uninited).size)
def testVariableList(self):
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
v = variables.Variable([1, 2], name="v")
w = variables.Variable([3, 4], name="w")
uninited = variables.report_uninitialized_variables()
@@ -566,14 +566,14 @@ class IsInitializedTest(test.TestCase):
self.assertEqual(0, sess.run(uninited).size)
def testZeroSizeVarInitialized(self):
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
v = variables.Variable(array_ops.zeros([0, 2]), name="v")
uninited = variables.report_uninitialized_variables()
v.initializer.run() # not strictly necessary
self.assertEqual(0, sess.run(uninited).size)
def testTrainingWithZeroSizeVar(self):
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
a = variables.Variable(array_ops.zeros([0, 2]))
b = variables.Variable(array_ops.ones([2, 2]))
objective = math_ops.reduce_sum(b + math_ops.matmul(
@@ -592,7 +592,7 @@ class ObsoleteIsInitializedTest(test.TestCase):
self.assertEqual(None, variables.assert_variables_initialized())
def testVariables(self):
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
v = variables.Variable([1, 2])
w = variables.Variable([3, 4])
_ = v, w
@@ -603,7 +603,7 @@ class ObsoleteIsInitializedTest(test.TestCase):
sess.run(inited)
def testVariableList(self):
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
v = variables.Variable([1, 2])
w = variables.Variable([3, 4])
inited = variables.assert_variables_initialized([v])