aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/partitioned_variables_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/partitioned_variables_test.py')
-rw-r--r--tensorflow/python/kernel_tests/partitioned_variables_test.py40
1 files changed, 20 insertions, 20 deletions
diff --git a/tensorflow/python/kernel_tests/partitioned_variables_test.py b/tensorflow/python/kernel_tests/partitioned_variables_test.py
index 15d5702252..b34d30f5c0 100644
--- a/tensorflow/python/kernel_tests/partitioned_variables_test.py
+++ b/tensorflow/python/kernel_tests/partitioned_variables_test.py
@@ -39,7 +39,7 @@ from tensorflow.python.training import saver as saver_lib
class PartitionerCreatorsTest(test.TestCase):
def testFixedSizePartitioner(self):
- with self.test_session():
+ with self.cached_session():
partitioner = partitioned_variables.fixed_size_partitioner(5, axis=0)
with variable_scope.variable_scope("root", partitioner=partitioner):
v0 = variable_scope.get_variable(
@@ -50,7 +50,7 @@ class PartitionerCreatorsTest(test.TestCase):
self.assertAllEqual(v0_part, (5, 1))
def testFixedSizePartitionerInt64(self):
- with self.test_session():
+ with self.cached_session():
partitioner = partitioned_variables.fixed_size_partitioner(4, axis=0)
with variable_scope.variable_scope("root", partitioner=partitioner):
v0 = variable_scope.get_variable("v0", dtype=dtypes.int64, shape=[20])
@@ -58,7 +58,7 @@ class PartitionerCreatorsTest(test.TestCase):
self.assertEqual(len(v0_list), 4)
def testResourceFixedSizePartitioner(self):
- with self.test_session():
+ with self.cached_session():
partitioner = partitioned_variables.fixed_size_partitioner(5, axis=0)
with variable_scope.variable_scope(
"root", partitioner=partitioner, use_resource=True):
@@ -88,7 +88,7 @@ class PartitionerCreatorsTest(test.TestCase):
self.assertAllEqual(v0_part, expected_partitions)
def testVariableAxisSizePartitioner(self):
- with self.test_session():
+ with self.cached_session():
# Create a partitioned variable of shape (4, 8, 16, 32) type float32
# Bytes per slice along the given axes:
@@ -210,7 +210,7 @@ class PartitionerCreatorsTest(test.TestCase):
self.assertAllEqual(v0_part, expected_partitions)
def testMinMaxVariablePartitioner(self):
- with self.test_session():
+ with self.cached_session():
# Partitioning a variable of shape=[2048] with a minimum of 2K per slice.
self._testMinMaxVariablePartitioner(
max_partitions=100,
@@ -323,7 +323,7 @@ class PartitionedVariablesTestCase(test.TestCase):
self.assertEquals(expected_specs[i], slices[i]._save_slice_info.spec)
def testVecConstantInit(self):
- with self.test_session():
+ with self.cached_session():
rnd_par = constant_op.constant([1, 2, 3, 4])
vs = partitioned_variables.create_partitioned_variables([4], [4], rnd_par)
variables.global_variables_initializer().run()
@@ -334,7 +334,7 @@ class PartitionedVariablesTestCase(test.TestCase):
self._TestSaveSpec(vs, ["4 0,1", "4 1,1", "4 2,1", "4 3,1"])
def testConstantInit(self):
- with self.test_session():
+ with self.cached_session():
rnd_par = constant_op.constant([[1, 2, 3, 4], [5, 6, 7, 8]])
vs = partitioned_variables.create_partitioned_variables([2, 4], [1, 2],
rnd_par)
@@ -346,7 +346,7 @@ class PartitionedVariablesTestCase(test.TestCase):
self._TestSaveSpec(vs, ["2 4 0,2:0,2", "2 4 0,2:2,2"])
def _testNameHelper(self, use_resource=False):
- with self.test_session():
+ with self.cached_session():
rnd_par = constant_op.constant([[1, 2, 3, 4], [5, 6, 7, 8]])
with variable_scope.variable_scope("hi", use_resource=use_resource):
vs1 = partitioned_variables.create_partitioned_variables([2, 4], [1, 2],
@@ -363,7 +363,7 @@ class PartitionedVariablesTestCase(test.TestCase):
self.assertEqual(var2_name + "/part_0:0", vs2[0].name)
self.assertEqual(var2_name + "/part_1:0", vs2[1].name)
# Test same variable.
- with self.test_session():
+ with self.cached_session():
rnd_par = constant_op.constant([[1, 2, 3, 4], [5, 6, 7, 8]])
with variable_scope.variable_scope(
"hola", use_resource=use_resource) as vs:
@@ -383,7 +383,7 @@ class PartitionedVariablesTestCase(test.TestCase):
self.assertEqual(var2_name + "/part_0:0", vs2[0].name)
self.assertEqual(var2_name + "/part_1:0", vs2[1].name)
# Test name_scope
- with self.test_session():
+ with self.cached_session():
rnd_par = constant_op.constant([[1, 2, 3, 4], [5, 6, 7, 8]])
with ops.name_scope("ola"):
vs1 = partitioned_variables.create_partitioned_variables([2, 4], [1, 2],
@@ -408,7 +408,7 @@ class PartitionedVariablesTestCase(test.TestCase):
self._testNameHelper(use_resource=True)
def testRandomInitValue(self):
- with self.test_session():
+ with self.cached_session():
rnd = variables.Variable(random_ops.random_uniform([200, 40]))
vs = partitioned_variables.create_partitioned_variables(
rnd.get_shape(), [1, 10], rnd.initialized_value())
@@ -425,7 +425,7 @@ class PartitionedVariablesTestCase(test.TestCase):
])
def testRandomInitUnevenPartitions(self):
- with self.test_session():
+ with self.cached_session():
rnd = variables.Variable(
random_ops.random_uniform([20, 43], dtype=dtypes.float64))
var_lists = [
@@ -463,7 +463,7 @@ class PartitionedVariablesTestCase(test.TestCase):
self._TestSaveSpec(vs, save_specs[i])
def testDegenerate(self):
- with self.test_session():
+ with self.cached_session():
rnd = variables.Variable(random_ops.random_uniform([10, 43]))
vs = partitioned_variables.create_partitioned_variables(
rnd.get_shape(), [1, 1], rnd.initialized_value())
@@ -474,7 +474,7 @@ class PartitionedVariablesTestCase(test.TestCase):
self._TestSaveSpec(vs, ["10 43 0,10:0,43"])
def testSliceSizeOne(self):
- with self.test_session():
+ with self.cached_session():
rnd = variables.Variable(random_ops.random_uniform([10, 43]))
vs = partitioned_variables.create_partitioned_variables(
rnd.get_shape(), [10, 1], rnd.initialized_value())
@@ -492,7 +492,7 @@ class PartitionedVariablesTestCase(test.TestCase):
self.assertAllClose([0., 1., 2., 3.], _IotaInitializer([4]))
self.assertAllClose([[0., 1.], [0., 10.], [0., 100.], [0., 1000.]],
_IotaInitializer([4, 2]))
- with self.test_session():
+ with self.cached_session():
vs = partitioned_variables.create_partitioned_variables([13, 5], [3, 1],
_IotaInitializer)
variables.global_variables_initializer().run()
@@ -506,7 +506,7 @@ class PartitionedVariablesTestCase(test.TestCase):
def testRandomInitializer(self):
# Sanity check that the slices uses a different seed when using a random
# initializer function.
- with self.test_session():
+ with self.cached_session():
var0, var1 = partitioned_variables.create_partitioned_variables(
[20, 12], [1, 2], init_ops.random_uniform_initializer())
variables.global_variables_initializer().run()
@@ -514,7 +514,7 @@ class PartitionedVariablesTestCase(test.TestCase):
self.assertTrue(np.linalg.norm(val0 - val1) > 1e-6)
# Negative test that proves that slices have the same values if
# the random initializer uses a seed.
- with self.test_session():
+ with self.cached_session():
var0, var1 = partitioned_variables.create_partitioned_variables(
[20, 12], [1, 2], init_ops.random_uniform_initializer(seed=201))
variables.global_variables_initializer().run()
@@ -522,7 +522,7 @@ class PartitionedVariablesTestCase(test.TestCase):
self.assertAllClose(val0, val1)
def testSomeErrors(self):
- with self.test_session():
+ with self.cached_session():
rnd = variables.Variable(random_ops.random_uniform([10, 43]))
with self.assertRaises(ValueError):
partitioned_variables.create_partitioned_variables(
@@ -547,7 +547,7 @@ class PartitionedVariablesTestCase(test.TestCase):
[10, 43], [1, 50], rnd.initialized_value())
def testControlDepsNone(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
c = constant_op.constant(1.0)
with ops.control_dependencies([c]):
# d get the control dependency.
@@ -573,7 +573,7 @@ class PartitionedVariablesTestCase(test.TestCase):
self.assertEqual([], op.control_inputs)
def testConcat(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
var_x = variable_scope.get_variable(
"x",
initializer=constant_op.constant([1., 2.]),