diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/checkpoint_ops_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/checkpoint_ops_test.py | 32 |
1 files changed, 16 insertions, 16 deletions
diff --git a/tensorflow/python/kernel_tests/checkpoint_ops_test.py b/tensorflow/python/kernel_tests/checkpoint_ops_test.py index 7f147ba53a..51611b75af 100644 --- a/tensorflow/python/kernel_tests/checkpoint_ops_test.py +++ b/tensorflow/python/kernel_tests/checkpoint_ops_test.py @@ -57,7 +57,7 @@ class GenerateVocabRemappingTest(test.TestCase): new_vocab_offset=0) expected_remapping = range(0, 3) expected_num_present = 3 - with self.test_session(): + with self.cached_session(): self.assertAllEqual(expected_remapping, remapping.eval()) self.assertAllEqual(expected_num_present, num_present.eval()) @@ -70,7 +70,7 @@ class GenerateVocabRemappingTest(test.TestCase): new_vocab_offset=0) expected_remapping = [2, 0, 1] expected_num_present = 3 - with self.test_session(): + with self.cached_session(): self.assertAllEqual(expected_remapping, remapping.eval()) self.assertAllEqual(expected_num_present, num_present.eval()) @@ -83,7 +83,7 @@ class GenerateVocabRemappingTest(test.TestCase): new_vocab_offset=1) expected_remapping = [0] expected_num_present = 1 - with self.test_session(): + with self.cached_session(): self.assertAllEqual(expected_remapping, remapping.eval()) self.assertAllEqual(expected_num_present, num_present.eval()) @@ -98,7 +98,7 @@ class GenerateVocabRemappingTest(test.TestCase): old_vocab_size=2) expected_remapping = [-1, 0, 1] expected_num_present = 2 - with self.test_session(): + with self.cached_session(): self.assertAllEqual(expected_remapping, remapping.eval()) self.assertAllEqual(expected_num_present, num_present.eval()) @@ -122,7 +122,7 @@ class LoadAndRemapMatrixTest(test.TestCase): self.old_tensor_name = 'some_scope/matrix' save = saver.Saver([matrix]) - with self.test_session() as sess: + with self.cached_session() as sess: variables.global_variables_initializer().run() self.bundle_file = os.path.join(test.get_temp_dir(), 'bundle_checkpoint') save.save(sess, self.bundle_file) @@ -140,7 +140,7 @@ class LoadAndRemapMatrixTest(test.TestCase): initializing_values=[], num_rows=2, num_cols=self.old_num_cols) - with self.test_session(): + with self.cached_session(): self.assertAllClose(self.matrix_value[row_remapping], remapped_matrix.eval()) @@ -155,7 +155,7 @@ class LoadAndRemapMatrixTest(test.TestCase): initializing_values=[], num_rows=len(row_remapping), num_cols=len(col_remapping)) - with self.test_session(): + with self.cached_session(): self.assertAllClose(self.matrix_value[row_remapping][:, col_remapping], remapped_matrix.eval()) @@ -170,7 +170,7 @@ class LoadAndRemapMatrixTest(test.TestCase): initializing_values=[], num_rows=len(row_remapping), num_cols=len(col_remapping)) - with self.test_session(): + with self.cached_session(): self.assertAllClose(self.matrix_value[row_remapping][:, col_remapping], remapped_matrix.eval()) @@ -189,7 +189,7 @@ class LoadAndRemapMatrixTest(test.TestCase): expected_remapped_matrix = np.reshape( [33, init_val, init_val, init_val, 1, init_val], [3, 2]) - with self.test_session(): + with self.cached_session(): self.assertAllClose(expected_remapped_matrix, remapped_matrix.eval()) def test_load_and_remap_all_missing_rows(self): @@ -204,7 +204,7 @@ class LoadAndRemapMatrixTest(test.TestCase): initializing_values=initializing_values, num_rows=num_rows, num_cols=self.old_num_cols) - with self.test_session(): + with self.cached_session(): self.assertAllClose( np.reshape(initializing_values, (num_rows, self.old_num_cols)), remapped_matrix.eval()) @@ -222,7 +222,7 @@ class LoadAndRemapMatrixTest(test.TestCase): initializing_values=initializing_values, num_rows=num_rows, num_cols=num_cols) - with self.test_session(): + with self.cached_session(): self.assertAllClose( np.reshape(initializing_values, (num_rows, num_cols)), remapped_matrix.eval()) @@ -243,7 +243,7 @@ class LoadAndRemapMatrixTest(test.TestCase): initializing_values=[], num_rows=len(invalid_remapping), num_cols=self.old_num_cols) - with self.test_session(), self.assertRaises(errors.UnimplementedError): + with self.cached_session(), self.assertRaises(errors.UnimplementedError): remapped_matrix.eval() # Invalid column remapping. @@ -255,7 +255,7 @@ class LoadAndRemapMatrixTest(test.TestCase): initializing_values=[], num_rows=self.old_num_rows, num_cols=len(invalid_remapping)) - with self.test_session(), self.assertRaises(errors.UnimplementedError): + with self.cached_session(), self.assertRaises(errors.UnimplementedError): remapped_matrix.eval() def test_load_and_remap_incorrect_initializing_values(self): @@ -272,7 +272,7 @@ class LoadAndRemapMatrixTest(test.TestCase): initializing_values=[], num_rows=3, num_cols=2) - with self.test_session(), self.assertRaises(errors.InvalidArgumentError): + with self.cached_session(), self.assertRaises(errors.InvalidArgumentError): remapped_matrix.eval() remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix( @@ -284,7 +284,7 @@ class LoadAndRemapMatrixTest(test.TestCase): initializing_values=[0] * 5, num_rows=3, num_cols=2) - with self.test_session(), self.assertRaises(errors.InvalidArgumentError): + with self.cached_session(), self.assertRaises(errors.InvalidArgumentError): remapped_matrix.eval() @@ -306,7 +306,7 @@ class LoadAndRemapMatrixWithMaxRowsTest(test.TestCase): initializer=constant_op.constant(np_value, dtype=dtypes.float32), partitioner=partitioner) - with self.test_session() as sess: + with self.cached_session() as sess: ckpt_path = os.path.join(test.get_temp_dir(), 'temp_ckpt') save = saver.Saver([matrix]) variables.global_variables_initializer().run() |