diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/partitioned_variables_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/partitioned_variables_test.py | 40 |
1 files changed, 20 insertions, 20 deletions
diff --git a/tensorflow/python/kernel_tests/partitioned_variables_test.py b/tensorflow/python/kernel_tests/partitioned_variables_test.py index 15d5702252..b34d30f5c0 100644 --- a/tensorflow/python/kernel_tests/partitioned_variables_test.py +++ b/tensorflow/python/kernel_tests/partitioned_variables_test.py @@ -39,7 +39,7 @@ from tensorflow.python.training import saver as saver_lib class PartitionerCreatorsTest(test.TestCase): def testFixedSizePartitioner(self): - with self.test_session(): + with self.cached_session(): partitioner = partitioned_variables.fixed_size_partitioner(5, axis=0) with variable_scope.variable_scope("root", partitioner=partitioner): v0 = variable_scope.get_variable( @@ -50,7 +50,7 @@ class PartitionerCreatorsTest(test.TestCase): self.assertAllEqual(v0_part, (5, 1)) def testFixedSizePartitionerInt64(self): - with self.test_session(): + with self.cached_session(): partitioner = partitioned_variables.fixed_size_partitioner(4, axis=0) with variable_scope.variable_scope("root", partitioner=partitioner): v0 = variable_scope.get_variable("v0", dtype=dtypes.int64, shape=[20]) @@ -58,7 +58,7 @@ class PartitionerCreatorsTest(test.TestCase): self.assertEqual(len(v0_list), 4) def testResourceFixedSizePartitioner(self): - with self.test_session(): + with self.cached_session(): partitioner = partitioned_variables.fixed_size_partitioner(5, axis=0) with variable_scope.variable_scope( "root", partitioner=partitioner, use_resource=True): @@ -88,7 +88,7 @@ class PartitionerCreatorsTest(test.TestCase): self.assertAllEqual(v0_part, expected_partitions) def testVariableAxisSizePartitioner(self): - with self.test_session(): + with self.cached_session(): # Create a partitioned variable of shape (4, 8, 16, 32) type float32 # Bytes per slice along the given axes: @@ -210,7 +210,7 @@ class PartitionerCreatorsTest(test.TestCase): self.assertAllEqual(v0_part, expected_partitions) def testMinMaxVariablePartitioner(self): - with self.test_session(): + with self.cached_session(): # Partitioning a variable of shape=[2048] with a minimum of 2K per slice. self._testMinMaxVariablePartitioner( max_partitions=100, @@ -323,7 +323,7 @@ class PartitionedVariablesTestCase(test.TestCase): self.assertEquals(expected_specs[i], slices[i]._save_slice_info.spec) def testVecConstantInit(self): - with self.test_session(): + with self.cached_session(): rnd_par = constant_op.constant([1, 2, 3, 4]) vs = partitioned_variables.create_partitioned_variables([4], [4], rnd_par) variables.global_variables_initializer().run() @@ -334,7 +334,7 @@ class PartitionedVariablesTestCase(test.TestCase): self._TestSaveSpec(vs, ["4 0,1", "4 1,1", "4 2,1", "4 3,1"]) def testConstantInit(self): - with self.test_session(): + with self.cached_session(): rnd_par = constant_op.constant([[1, 2, 3, 4], [5, 6, 7, 8]]) vs = partitioned_variables.create_partitioned_variables([2, 4], [1, 2], rnd_par) @@ -346,7 +346,7 @@ class PartitionedVariablesTestCase(test.TestCase): self._TestSaveSpec(vs, ["2 4 0,2:0,2", "2 4 0,2:2,2"]) def _testNameHelper(self, use_resource=False): - with self.test_session(): + with self.cached_session(): rnd_par = constant_op.constant([[1, 2, 3, 4], [5, 6, 7, 8]]) with variable_scope.variable_scope("hi", use_resource=use_resource): vs1 = partitioned_variables.create_partitioned_variables([2, 4], [1, 2], @@ -363,7 +363,7 @@ class PartitionedVariablesTestCase(test.TestCase): self.assertEqual(var2_name + "/part_0:0", vs2[0].name) self.assertEqual(var2_name + "/part_1:0", vs2[1].name) # Test same variable. - with self.test_session(): + with self.cached_session(): rnd_par = constant_op.constant([[1, 2, 3, 4], [5, 6, 7, 8]]) with variable_scope.variable_scope( "hola", use_resource=use_resource) as vs: @@ -383,7 +383,7 @@ class PartitionedVariablesTestCase(test.TestCase): self.assertEqual(var2_name + "/part_0:0", vs2[0].name) self.assertEqual(var2_name + "/part_1:0", vs2[1].name) # Test name_scope - with self.test_session(): + with self.cached_session(): rnd_par = constant_op.constant([[1, 2, 3, 4], [5, 6, 7, 8]]) with ops.name_scope("ola"): vs1 = partitioned_variables.create_partitioned_variables([2, 4], [1, 2], @@ -408,7 +408,7 @@ class PartitionedVariablesTestCase(test.TestCase): self._testNameHelper(use_resource=True) def testRandomInitValue(self): - with self.test_session(): + with self.cached_session(): rnd = variables.Variable(random_ops.random_uniform([200, 40])) vs = partitioned_variables.create_partitioned_variables( rnd.get_shape(), [1, 10], rnd.initialized_value()) @@ -425,7 +425,7 @@ class PartitionedVariablesTestCase(test.TestCase): ]) def testRandomInitUnevenPartitions(self): - with self.test_session(): + with self.cached_session(): rnd = variables.Variable( random_ops.random_uniform([20, 43], dtype=dtypes.float64)) var_lists = [ @@ -463,7 +463,7 @@ class PartitionedVariablesTestCase(test.TestCase): self._TestSaveSpec(vs, save_specs[i]) def testDegenerate(self): - with self.test_session(): + with self.cached_session(): rnd = variables.Variable(random_ops.random_uniform([10, 43])) vs = partitioned_variables.create_partitioned_variables( rnd.get_shape(), [1, 1], rnd.initialized_value()) @@ -474,7 +474,7 @@ class PartitionedVariablesTestCase(test.TestCase): self._TestSaveSpec(vs, ["10 43 0,10:0,43"]) def testSliceSizeOne(self): - with self.test_session(): + with self.cached_session(): rnd = variables.Variable(random_ops.random_uniform([10, 43])) vs = partitioned_variables.create_partitioned_variables( rnd.get_shape(), [10, 1], rnd.initialized_value()) @@ -492,7 +492,7 @@ class PartitionedVariablesTestCase(test.TestCase): self.assertAllClose([0., 1., 2., 3.], _IotaInitializer([4])) self.assertAllClose([[0., 1.], [0., 10.], [0., 100.], [0., 1000.]], _IotaInitializer([4, 2])) - with self.test_session(): + with self.cached_session(): vs = partitioned_variables.create_partitioned_variables([13, 5], [3, 1], _IotaInitializer) variables.global_variables_initializer().run() @@ -506,7 +506,7 @@ class PartitionedVariablesTestCase(test.TestCase): def testRandomInitializer(self): # Sanity check that the slices uses a different seed when using a random # initializer function. - with self.test_session(): + with self.cached_session(): var0, var1 = partitioned_variables.create_partitioned_variables( [20, 12], [1, 2], init_ops.random_uniform_initializer()) variables.global_variables_initializer().run() @@ -514,7 +514,7 @@ class PartitionedVariablesTestCase(test.TestCase): self.assertTrue(np.linalg.norm(val0 - val1) > 1e-6) # Negative test that proves that slices have the same values if # the random initializer uses a seed. - with self.test_session(): + with self.cached_session(): var0, var1 = partitioned_variables.create_partitioned_variables( [20, 12], [1, 2], init_ops.random_uniform_initializer(seed=201)) variables.global_variables_initializer().run() @@ -522,7 +522,7 @@ class PartitionedVariablesTestCase(test.TestCase): self.assertAllClose(val0, val1) def testSomeErrors(self): - with self.test_session(): + with self.cached_session(): rnd = variables.Variable(random_ops.random_uniform([10, 43])) with self.assertRaises(ValueError): partitioned_variables.create_partitioned_variables( @@ -547,7 +547,7 @@ class PartitionedVariablesTestCase(test.TestCase): [10, 43], [1, 50], rnd.initialized_value()) def testControlDepsNone(self): - with self.test_session() as session: + with self.cached_session() as session: c = constant_op.constant(1.0) with ops.control_dependencies([c]): # d get the control dependency. @@ -573,7 +573,7 @@ class PartitionedVariablesTestCase(test.TestCase): self.assertEqual([], op.control_inputs) def testConcat(self): - with self.test_session() as session: + with self.cached_session() as session: var_x = variable_scope.get_variable( "x", initializer=constant_op.constant([1., 2.]), |