aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/checkpoint_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/checkpoint_ops_test.py')
-rw-r--r--tensorflow/python/kernel_tests/checkpoint_ops_test.py32
1 files changed, 16 insertions, 16 deletions
diff --git a/tensorflow/python/kernel_tests/checkpoint_ops_test.py b/tensorflow/python/kernel_tests/checkpoint_ops_test.py
index 7f147ba53a..51611b75af 100644
--- a/tensorflow/python/kernel_tests/checkpoint_ops_test.py
+++ b/tensorflow/python/kernel_tests/checkpoint_ops_test.py
@@ -57,7 +57,7 @@ class GenerateVocabRemappingTest(test.TestCase):
new_vocab_offset=0)
expected_remapping = range(0, 3)
expected_num_present = 3
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_remapping, remapping.eval())
self.assertAllEqual(expected_num_present, num_present.eval())
@@ -70,7 +70,7 @@ class GenerateVocabRemappingTest(test.TestCase):
new_vocab_offset=0)
expected_remapping = [2, 0, 1]
expected_num_present = 3
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_remapping, remapping.eval())
self.assertAllEqual(expected_num_present, num_present.eval())
@@ -83,7 +83,7 @@ class GenerateVocabRemappingTest(test.TestCase):
new_vocab_offset=1)
expected_remapping = [0]
expected_num_present = 1
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_remapping, remapping.eval())
self.assertAllEqual(expected_num_present, num_present.eval())
@@ -98,7 +98,7 @@ class GenerateVocabRemappingTest(test.TestCase):
old_vocab_size=2)
expected_remapping = [-1, 0, 1]
expected_num_present = 2
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_remapping, remapping.eval())
self.assertAllEqual(expected_num_present, num_present.eval())
@@ -122,7 +122,7 @@ class LoadAndRemapMatrixTest(test.TestCase):
self.old_tensor_name = 'some_scope/matrix'
save = saver.Saver([matrix])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables.global_variables_initializer().run()
self.bundle_file = os.path.join(test.get_temp_dir(), 'bundle_checkpoint')
save.save(sess, self.bundle_file)
@@ -140,7 +140,7 @@ class LoadAndRemapMatrixTest(test.TestCase):
initializing_values=[],
num_rows=2,
num_cols=self.old_num_cols)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(self.matrix_value[row_remapping],
remapped_matrix.eval())
@@ -155,7 +155,7 @@ class LoadAndRemapMatrixTest(test.TestCase):
initializing_values=[],
num_rows=len(row_remapping),
num_cols=len(col_remapping))
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(self.matrix_value[row_remapping][:, col_remapping],
remapped_matrix.eval())
@@ -170,7 +170,7 @@ class LoadAndRemapMatrixTest(test.TestCase):
initializing_values=[],
num_rows=len(row_remapping),
num_cols=len(col_remapping))
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(self.matrix_value[row_remapping][:, col_remapping],
remapped_matrix.eval())
@@ -189,7 +189,7 @@ class LoadAndRemapMatrixTest(test.TestCase):
expected_remapped_matrix = np.reshape(
[33, init_val, init_val, init_val, 1, init_val], [3, 2])
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(expected_remapped_matrix, remapped_matrix.eval())
def test_load_and_remap_all_missing_rows(self):
@@ -204,7 +204,7 @@ class LoadAndRemapMatrixTest(test.TestCase):
initializing_values=initializing_values,
num_rows=num_rows,
num_cols=self.old_num_cols)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(
np.reshape(initializing_values, (num_rows, self.old_num_cols)),
remapped_matrix.eval())
@@ -222,7 +222,7 @@ class LoadAndRemapMatrixTest(test.TestCase):
initializing_values=initializing_values,
num_rows=num_rows,
num_cols=num_cols)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(
np.reshape(initializing_values, (num_rows, num_cols)),
remapped_matrix.eval())
@@ -243,7 +243,7 @@ class LoadAndRemapMatrixTest(test.TestCase):
initializing_values=[],
num_rows=len(invalid_remapping),
num_cols=self.old_num_cols)
- with self.test_session(), self.assertRaises(errors.UnimplementedError):
+ with self.cached_session(), self.assertRaises(errors.UnimplementedError):
remapped_matrix.eval()
# Invalid column remapping.
@@ -255,7 +255,7 @@ class LoadAndRemapMatrixTest(test.TestCase):
initializing_values=[],
num_rows=self.old_num_rows,
num_cols=len(invalid_remapping))
- with self.test_session(), self.assertRaises(errors.UnimplementedError):
+ with self.cached_session(), self.assertRaises(errors.UnimplementedError):
remapped_matrix.eval()
def test_load_and_remap_incorrect_initializing_values(self):
@@ -272,7 +272,7 @@ class LoadAndRemapMatrixTest(test.TestCase):
initializing_values=[],
num_rows=3,
num_cols=2)
- with self.test_session(), self.assertRaises(errors.InvalidArgumentError):
+ with self.cached_session(), self.assertRaises(errors.InvalidArgumentError):
remapped_matrix.eval()
remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
@@ -284,7 +284,7 @@ class LoadAndRemapMatrixTest(test.TestCase):
initializing_values=[0] * 5,
num_rows=3,
num_cols=2)
- with self.test_session(), self.assertRaises(errors.InvalidArgumentError):
+ with self.cached_session(), self.assertRaises(errors.InvalidArgumentError):
remapped_matrix.eval()
@@ -306,7 +306,7 @@ class LoadAndRemapMatrixWithMaxRowsTest(test.TestCase):
initializer=constant_op.constant(np_value, dtype=dtypes.float32),
partitioner=partitioner)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
ckpt_path = os.path.join(test.get_temp_dir(), 'temp_ckpt')
save = saver.Saver([matrix])
variables.global_variables_initializer().run()