diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-10 14:36:26 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-10 14:49:41 -0700 |
commit | 890e16594a005fe703a5556530b0dc3e6527fa47 (patch) | |
tree | 99140efb13f392ae13a58f08c08754c61bf66f13 /tensorflow/contrib/lookup | |
parent | 132babebf5b1026cb33cad7c4eb7e03810c2acdf (diff) |
Move from deprecated self.test_session() to self.cached_session().
self.test_session() has been deprecated in 9962eb5e84b15e309410071b06c2ed2d6148ed44 as its name confuses readers of the test. Moving to cached_session() instead which is more explicit about:
* the fact that the session may be reused.
* the session is not closed even when doing a "with self.test_session()" statement.
PiperOrigin-RevId: 212336321
Diffstat (limited to 'tensorflow/contrib/lookup')
-rw-r--r-- | tensorflow/contrib/lookup/lookup_ops_test.py | 206 |
1 files changed, 103 insertions, 103 deletions
diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py index 0a54bb1f5e..89b538d1ba 100644 --- a/tensorflow/contrib/lookup/lookup_ops_test.py +++ b/tensorflow/contrib/lookup/lookup_ops_test.py @@ -44,7 +44,7 @@ from tensorflow.python.training.checkpointable import util as checkpointable class HashTableOpTest(test.TestCase): def testHashTable(self): - with self.test_session(): + with self.cached_session(): default_val = -1 keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([0, 1, 2], dtypes.int64) @@ -68,7 +68,7 @@ class HashTableOpTest(test.TestCase): self.assertItemsEqual([0, 1, 2], exported_values_tensor.eval()) def testHashTableFindHighRank(self): - with self.test_session(): + with self.cached_session(): default_val = -1 keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([0, 1, 2], dtypes.int64) @@ -86,7 +86,7 @@ class HashTableOpTest(test.TestCase): self.assertAllEqual([[0, 1], [-1, -1]], result) def testHashTableInitWithPythonArrays(self): - with self.test_session(): + with self.cached_session(): default_val = -1 keys = ["brain", "salad", "surgery"] values = [0, 1, 2] @@ -105,7 +105,7 @@ class HashTableOpTest(test.TestCase): self.assertAllEqual([0, 1, -1], result) def testHashTableInitWithNumPyArrays(self): - with self.test_session(): + with self.cached_session(): default_val = -1 keys = np.array(["brain", "salad", "surgery"], dtype=np.str) values = np.array([0, 1, 2], dtype=np.int64) @@ -122,7 +122,7 @@ class HashTableOpTest(test.TestCase): self.assertAllEqual([0, 1, -1], result) def testMultipleHashTables(self): - with self.test_session() as sess: + with self.cached_session() as sess: default_val = -1 keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([0, 1, 2], dtypes.int64) @@ -150,7 +150,7 @@ class HashTableOpTest(test.TestCase): self.assertAllEqual([0, 1, -1], out3) def testHashTableWithTensorDefault(self): - with self.test_session(): + with self.cached_session(): default_val = constant_op.constant(-1, dtypes.int64) keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([0, 1, 2], dtypes.int64) @@ -165,7 +165,7 @@ class HashTableOpTest(test.TestCase): self.assertAllEqual([0, 1, -1], result) def testHashTableWithSparseTensorInput(self): - with self.test_session() as sess: + with self.cached_session() as sess: default_val = constant_op.constant(-1, dtypes.int64) keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([0, 1, 2], dtypes.int64) @@ -188,7 +188,7 @@ class HashTableOpTest(test.TestCase): self.assertAllEqual(sp_shape, out_shape) def testSignatureMismatch(self): - with self.test_session(): + with self.cached_session(): default_val = -1 keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([0, 1, 2], dtypes.int64) @@ -210,7 +210,7 @@ class HashTableOpTest(test.TestCase): lookup.KeyValueTensorInitializer(keys, values), "UNK") def testDTypes(self): - with self.test_session(): + with self.cached_session(): default_val = -1 with self.assertRaises(TypeError): lookup.HashTable( @@ -218,7 +218,7 @@ class HashTableOpTest(test.TestCase): dtypes.int64), default_val) def testNotInitialized(self): - with self.test_session(): + with self.cached_session(): default_val = -1 table = lookup.HashTable( lookup.KeyValueTensorInitializer( @@ -232,7 +232,7 @@ class HashTableOpTest(test.TestCase): output.eval() def testInitializeTwice(self): - with self.test_session(): + with self.cached_session(): default_val = -1 keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([0, 1, 2], dtypes.int64) @@ -244,7 +244,7 @@ class HashTableOpTest(test.TestCase): table.init.run() def testInitializationWithInvalidDimensions(self): - with self.test_session(): + with self.cached_session(): default_val = -1 keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([0, 1, 2, 3, 4], dtypes.int64) @@ -283,7 +283,7 @@ class HashTableOpTest(test.TestCase): self.assertAllEqual(3, table.size().eval()) def testHashTableInt32String(self): - with self.test_session(): + with self.cached_session(): default_val = "n/a" keys = constant_op.constant([0, 1, 2], dtypes.int32) values = constant_op.constant(["brain", "salad", "surgery"]) @@ -301,7 +301,7 @@ class HashTableOpTest(test.TestCase): class MutableHashTableOpTest(test.TestCase): def testMutableHashTable(self): - with self.test_session(): + with self.cached_session(): default_val = -1 keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([0, 1, 2], dtypes.int64) @@ -470,7 +470,7 @@ class MutableHashTableOpTest(test.TestCase): self.assertAllEqual([b"-", b"a", b"b"], output.eval()) def testMutableHashTableOfTensors(self): - with self.test_session(): + with self.cached_session(): default_val = constant_op.constant([-1, -1], dtypes.int64) keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64) @@ -500,7 +500,7 @@ class MutableHashTableOpTest(test.TestCase): self.assertAllEqual([[4, 5], [2, 3], [0, 1]], sorted_values) def testMutableHashTableExportInsert(self): - with self.test_session(): + with self.cached_session(): default_val = constant_op.constant([-1, -1], dtypes.int64) keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64) @@ -531,7 +531,7 @@ class MutableHashTableOpTest(test.TestCase): self.assertAllEqual(expected_output, output2.eval()) def testMutableHashTableOfTensorsInvalidShape(self): - with self.test_session(): + with self.cached_session(): default_val = constant_op.constant([-1, -1], dtypes.int64) keys = constant_op.constant(["brain", "salad", "surgery"]) table = lookup.MutableHashTable(dtypes.string, dtypes.int64, @@ -563,7 +563,7 @@ class MutableHashTableOpTest(test.TestCase): self.assertAllEqual(3, table.size().eval()) def testMutableHashTableInvalidDefaultValue(self): - with self.test_session(): + with self.cached_session(): default_val = constant_op.constant([[-1, -1]], dtypes.int64) table = lookup.MutableHashTable(dtypes.string, dtypes.int64, default_val) @@ -571,7 +571,7 @@ class MutableHashTableOpTest(test.TestCase): self.assertAllEqual(0, table.size().eval()) def testMutableHashTableDuplicateInsert(self): - with self.test_session(): + with self.cached_session(): default_val = -1 keys = constant_op.constant(["brain", "salad", "surgery", "brain"]) values = constant_op.constant([0, 1, 2, 3], dtypes.int64) @@ -589,7 +589,7 @@ class MutableHashTableOpTest(test.TestCase): self.assertAllEqual([3, 1, -1], result) def testMutableHashTableFindHighRank(self): - with self.test_session(): + with self.cached_session(): default_val = -1 keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([0, 1, 2], dtypes.int64) @@ -608,7 +608,7 @@ class MutableHashTableOpTest(test.TestCase): self.assertAllEqual([[0, 1], [-1, -1]], result) def testMutableHashTableInsertHighRank(self): - with self.test_session(): + with self.cached_session(): default_val = -1 keys = constant_op.constant([["brain", "salad"], ["surgery", "tank"]]) values = constant_op.constant([[0, 1], [2, 3]], dtypes.int64) @@ -625,7 +625,7 @@ class MutableHashTableOpTest(test.TestCase): self.assertAllEqual([0, 1, 3, -1], result) def testMutableHashTableOfTensorsFindHighRank(self): - with self.test_session(): + with self.cached_session(): default_val = constant_op.constant([-1, -1, -1], dtypes.int64) keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([[0, 1, 2], [2, 3, 4], [4, 5, 6]], @@ -646,7 +646,7 @@ class MutableHashTableOpTest(test.TestCase): [[[0, 1, 2], [2, 3, 4]], [[-1, -1, -1], [-1, -1, -1]]], result) def testMultipleMutableHashTables(self): - with self.test_session() as sess: + with self.cached_session() as sess: default_val = -1 keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([0, 1, 2], dtypes.int64) @@ -676,7 +676,7 @@ class MutableHashTableOpTest(test.TestCase): self.assertAllEqual([0, 1, -1], out3) def testMutableHashTableWithTensorDefault(self): - with self.test_session(): + with self.cached_session(): default_val = constant_op.constant(-1, dtypes.int64) keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([0, 1, 2], dtypes.int64) @@ -693,7 +693,7 @@ class MutableHashTableOpTest(test.TestCase): self.assertAllEqual([0, 1, -1], result) def testSignatureMismatch(self): - with self.test_session(): + with self.cached_session(): default_val = -1 keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([0, 1, 2], dtypes.int64) @@ -734,7 +734,7 @@ class MutableHashTableOpTest(test.TestCase): lookup.MutableHashTable(dtypes.string, dtypes.int64, "UNK") def testMutableHashTableStringFloat(self): - with self.test_session(): + with self.cached_session(): default_val = -1.5 keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([0, 1.1, 2.2], dtypes.float32) @@ -752,7 +752,7 @@ class MutableHashTableOpTest(test.TestCase): self.assertAllClose([0, 1.1, default_val], result) def testMutableHashTableIntFloat(self): - with self.test_session(): + with self.cached_session(): default_val = -1.0 keys = constant_op.constant([3, 7, 0], dtypes.int64) values = constant_op.constant([7.5, -1.2, 9.9], dtypes.float32) @@ -770,7 +770,7 @@ class MutableHashTableOpTest(test.TestCase): self.assertAllClose([-1.2, 9.9, default_val], result) def testMutableHashTableInt64String(self): - with self.test_session(): + with self.cached_session(): default_val = "n/a" keys = constant_op.constant([0, 1, 2], dtypes.int64) values = constant_op.constant(["brain", "salad", "surgery"]) @@ -791,7 +791,7 @@ class MutableHashTableOpTest(test.TestCase): class MutableDenseHashTableOpTest(test.TestCase): def testBasic(self): - with self.test_session(): + with self.cached_session(): keys = constant_op.constant([11, 12, 13], dtypes.int64) values = constant_op.constant([0, 1, 2], dtypes.int64) table = lookup.MutableDenseHashTable( @@ -809,7 +809,7 @@ class MutableDenseHashTableOpTest(test.TestCase): self.assertAllEqual([0, 1, -1], result) def testBasicBool(self): - with self.test_session(): + with self.cached_session(): keys = constant_op.constant([11, 12, 13], dtypes.int64) values = constant_op.constant([True, True, True], dtypes.bool) table = lookup.MutableDenseHashTable( @@ -827,7 +827,7 @@ class MutableDenseHashTableOpTest(test.TestCase): self.assertAllEqual([True, True, False], result) def testLookupUnknownShape(self): - with self.test_session(): + with self.cached_session(): keys = constant_op.constant([11, 12, 13], dtypes.int64) values = constant_op.constant([0, 1, 2], dtypes.int64) table = lookup.MutableDenseHashTable( @@ -843,7 +843,7 @@ class MutableDenseHashTableOpTest(test.TestCase): self.assertAllEqual([0, 1, -1], result) def testMapStringToFloat(self): - with self.test_session(): + with self.cached_session(): keys = constant_op.constant(["a", "b", "c"], dtypes.string) values = constant_op.constant([0.0, 1.1, 2.2], dtypes.float32) default_value = constant_op.constant(-1.5, dtypes.float32) @@ -866,7 +866,7 @@ class MutableDenseHashTableOpTest(test.TestCase): def testMapInt64ToFloat(self): for float_dtype in [dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): keys = constant_op.constant([11, 12, 13], dtypes.int64) values = constant_op.constant([0.0, 1.1, 2.2], float_dtype) default_value = constant_op.constant(-1.5, float_dtype) @@ -885,7 +885,7 @@ class MutableDenseHashTableOpTest(test.TestCase): self.assertAllClose([0, 1.1, -1.5], result) def testVectorValues(self): - with self.test_session(): + with self.cached_session(): keys = constant_op.constant([11, 12, 13], dtypes.int64) values = constant_op.constant([[0, 1, 2, 3], [3, 4, 5, 6], [6, 7, 8, 9]], dtypes.int64) @@ -918,7 +918,7 @@ class MutableDenseHashTableOpTest(test.TestCase): result) def testVectorKeys(self): - with self.test_session(): + with self.cached_session(): keys = constant_op.constant([[0, 1], [1, 2], [1, 3]], dtypes.int64) values = constant_op.constant([10, 11, 12], dtypes.int64) empty_key = constant_op.constant([0, 3], dtypes.int64) @@ -949,7 +949,7 @@ class MutableDenseHashTableOpTest(test.TestCase): self.assertAllEqual([10, 11, -1], result) def testResize(self): - with self.test_session(): + with self.cached_session(): keys = constant_op.constant([11, 12, 13], dtypes.int64) values = constant_op.constant([0, 1, 2], dtypes.int64) table = lookup.MutableDenseHashTable( @@ -977,7 +977,7 @@ class MutableDenseHashTableOpTest(test.TestCase): self.assertAllEqual([-1, 0, 1, 3, 4, 5, 6, 7, -1], output.eval()) def testExport(self): - with self.test_session(): + with self.cached_session(): keys = constant_op.constant([11, 12, 13], dtypes.int64) values = constant_op.constant([1, 2, 3], dtypes.int64) table = lookup.MutableDenseHashTable( @@ -1238,7 +1238,7 @@ class MutableDenseHashTableOpTest(test.TestCase): self.assertAllEqual([0, 1, -1, 2, -1], output.eval()) def testReprobe(self): - with self.test_session(): + with self.cached_session(): # Insert 6 keys into a table with 8 buckets. # The values are chosen to make sure collisions occur when using GCC STL keys = constant_op.constant([11, 12, 13, 19, 20, 21], dtypes.int64) @@ -1263,7 +1263,7 @@ class MutableDenseHashTableOpTest(test.TestCase): self.assertAllEqual([-1, 51, 52, 53, -1, 54, 55, 56, -1], result) def testCustomEmptyKey(self): - with self.test_session(): + with self.cached_session(): keys = constant_op.constant([11, 0, 13], dtypes.int64) values = constant_op.constant([0, 1, 2], dtypes.int64) table = lookup.MutableDenseHashTable( @@ -1281,7 +1281,7 @@ class MutableDenseHashTableOpTest(test.TestCase): self.assertAllEqual([0, 1, -1], result) def testErrors(self): - with self.test_session(): + with self.cached_session(): table = lookup.MutableDenseHashTable( dtypes.int64, dtypes.int64, default_value=-1, empty_key=0) @@ -1328,7 +1328,7 @@ class IndexTableFromFile(test.TestCase): def test_string_index_table_from_file(self): vocabulary_file = self._createVocabFile("f2i_vocab1.txt") - with self.test_session(): + with self.cached_session(): table = lookup.index_table_from_file( vocabulary_file=vocabulary_file, num_oov_buckets=1) ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) @@ -1339,7 +1339,7 @@ class IndexTableFromFile(test.TestCase): def test_string_index_table_from_file_tensor_filename(self): vocabulary_file = self._createVocabFile("f2i_vocab1.txt") - with self.test_session(): + with self.cached_session(): vocabulary_file = constant_op.constant(vocabulary_file) table = lookup.index_table_from_file( vocabulary_file=vocabulary_file, num_oov_buckets=1) @@ -1353,7 +1353,7 @@ class IndexTableFromFile(test.TestCase): def test_string_index_table_from_file_placeholder_filename(self): vocabulary_file = self._createVocabFile("f2i_vocab1.txt") - with self.test_session(): + with self.cached_session(): vocabulary_placeholder = array_ops.placeholder(dtypes.string, []) table = lookup.index_table_from_file( vocabulary_file=vocabulary_placeholder, num_oov_buckets=1) @@ -1370,7 +1370,7 @@ class IndexTableFromFile(test.TestCase): def test_int32_index_table_from_file(self): vocabulary_file = self._createVocabFile( "f2i_vocab2.txt", values=("42", "1", "-1000")) - with self.test_session(): + with self.cached_session(): table = lookup.index_table_from_file( vocabulary_file=vocabulary_file, num_oov_buckets=1, key_dtype=dtypes.int32) @@ -1384,7 +1384,7 @@ class IndexTableFromFile(test.TestCase): def test_int64_index_table_from_file(self): vocabulary_file = self._createVocabFile( "f2i_vocab3.txt", values=("42", "1", "-1000")) - with self.test_session(): + with self.cached_session(): table = lookup.index_table_from_file( vocabulary_file=vocabulary_file, num_oov_buckets=1, key_dtype=dtypes.int64) @@ -1398,7 +1398,7 @@ class IndexTableFromFile(test.TestCase): def test_index_table_from_file_with_default_value(self): default_value = -42 vocabulary_file = self._createVocabFile("f2i_vocab4.txt") - with self.test_session(): + with self.cached_session(): table = lookup.index_table_from_file( vocabulary_file=vocabulary_file, default_value=default_value) ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) @@ -1409,7 +1409,7 @@ class IndexTableFromFile(test.TestCase): def test_index_table_from_file_with_oov_buckets(self): vocabulary_file = self._createVocabFile("f2i_vocab5.txt") - with self.test_session(): + with self.cached_session(): table = lookup.index_table_from_file( vocabulary_file=vocabulary_file, num_oov_buckets=1000) ids = table.lookup( @@ -1439,7 +1439,7 @@ class IndexTableFromFile(test.TestCase): def test_index_table_from_file_with_vocab_size_too_small(self): vocabulary_file = self._createVocabFile("f2i_vocab6.txt") - with self.test_session(): + with self.cached_session(): table = lookup.index_table_from_file( vocabulary_file=vocabulary_file, vocab_size=2) ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) @@ -1451,7 +1451,7 @@ class IndexTableFromFile(test.TestCase): def test_index_table_from_file_with_vocab_size_too_large(self): vocabulary_file = self._createVocabFile("f2i_vocab7.txt") - with self.test_session(): + with self.cached_session(): table = lookup.index_table_from_file( vocabulary_file=vocabulary_file, vocab_size=4) self.assertRaisesRegexp(errors_impl.InvalidArgumentError, @@ -1466,7 +1466,7 @@ class IndexTableFromFile(test.TestCase): vocabulary_file=vocabulary_file, vocab_size=0) - with self.test_session(): + with self.cached_session(): table = lookup.index_table_from_file( vocabulary_file=vocabulary_file, vocab_size=3) ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) @@ -1478,7 +1478,7 @@ class IndexTableFromFile(test.TestCase): def test_index_table_from_file_with_invalid_hashers(self): vocabulary_file = self._createVocabFile("invalid_hasher.txt") - with self.test_session(): + with self.cached_session(): with self.assertRaises(TypeError): lookup.index_table_from_file( vocabulary_file=vocabulary_file, @@ -1499,21 +1499,21 @@ class IndexTableFromFile(test.TestCase): class KeyValueTensorInitializerTest(test.TestCase): def test_string(self): - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): init = lookup.KeyValueTensorInitializer( ("brain", "salad", "surgery"), (0, 1, 2), dtypes.string, dtypes.int64) table = lookup.HashTable(init, default_value=-1) table.init.run() def test_int64(self): - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): init = lookup.KeyValueTensorInitializer( (42, 1, -1000), (0, 1, 2), dtypes.int64, dtypes.int64) table = lookup.HashTable(init, default_value=-1) table.init.run() def test_int32(self): - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): init = lookup.KeyValueTensorInitializer( (42, 1, -1000), (0, 1, 2), dtypes.int32, dtypes.int64) table = lookup.HashTable(init, default_value=-1) @@ -1542,7 +1542,7 @@ class IndexTableFromTensor(test.TestCase): self.assertAllEqual((1, 2, 3), self.evaluate(ids)) def test_int32_index_table_from_tensor_with_tensor_init(self): - with self.test_session(): + with self.cached_session(): table = lookup.index_table_from_tensor( mapping=(42, 1, -1000), num_oov_buckets=1, dtype=dtypes.int32) ids = table.lookup( @@ -1553,7 +1553,7 @@ class IndexTableFromTensor(test.TestCase): self.assertAllEqual((1, 2, 3), ids.eval()) def test_int64_index_table_from_tensor_with_tensor_init(self): - with self.test_session(): + with self.cached_session(): table = lookup.index_table_from_tensor( mapping=(42, 1, -1000), num_oov_buckets=1, dtype=dtypes.int64) ids = table.lookup( @@ -1565,7 +1565,7 @@ class IndexTableFromTensor(test.TestCase): def test_index_table_from_tensor_with_default_value(self): default_value = -42 - with self.test_session(): + with self.cached_session(): table = lookup.index_table_from_tensor( mapping=["brain", "salad", "surgery"], default_value=default_value) ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) @@ -1575,12 +1575,12 @@ class IndexTableFromTensor(test.TestCase): self.assertAllEqual((1, 2, default_value), ids.eval()) def test_index_table_from_tensor_missing_mapping(self): - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp(ValueError, "mapping must be specified"): lookup.index_table_from_tensor(mapping=None, num_oov_buckets=1) def test_index_table_from_tensor_empty_mapping(self): - with self.test_session(): + with self.cached_session(): table = lookup.index_table_from_tensor( mapping=np.array([], dtype=np.str_), num_oov_buckets=1) ids = table.lookup(constant_op.constant(["salad", "surgery", "brain"])) @@ -1590,7 +1590,7 @@ class IndexTableFromTensor(test.TestCase): lookup_ops.tables_initializer().run() def test_index_table_from_tensor_with_invalid_hashers(self): - with self.test_session(): + with self.cached_session(): with self.assertRaises(TypeError): lookup.index_table_from_tensor( mapping=["brain", "salad", "surgery"], @@ -1609,7 +1609,7 @@ class IndexTableFromTensor(test.TestCase): class StringToIndexTest(test.TestCase): def test_string_to_index(self): - with self.test_session(): + with self.cached_session(): mapping_strings = constant_op.constant(["brain", "salad", "surgery"]) feats = constant_op.constant(["salad", "surgery", "tarkus"]) indices = lookup.string_to_index(feats, mapping=mapping_strings) @@ -1620,7 +1620,7 @@ class StringToIndexTest(test.TestCase): self.assertAllEqual((1, 2, -1), indices.eval()) def test_duplicate_entries(self): - with self.test_session(): + with self.cached_session(): mapping_strings = constant_op.constant(["hello", "hello"]) feats = constant_op.constant(["hello", "hola"]) _ = lookup.string_to_index(feats, mapping=mapping_strings) @@ -1630,7 +1630,7 @@ class StringToIndexTest(test.TestCase): def test_string_to_index_with_default_value(self): default_value = -42 - with self.test_session(): + with self.cached_session(): mapping_strings = constant_op.constant(["brain", "salad", "surgery"]) feats = constant_op.constant(["salad", "surgery", "tarkus"]) indices = lookup.string_to_index( @@ -1651,7 +1651,7 @@ class IndexToStringTableFromFileTest(test.TestCase): def test_index_to_string_table(self): vocabulary_file = self._createVocabFile("i2f_vocab1.txt") - with self.test_session(): + with self.cached_session(): table = lookup.index_to_string_table_from_file( vocabulary_file=vocabulary_file) features = table.lookup(constant_op.constant([0, 1, 2, 3], dtypes.int64)) @@ -1663,7 +1663,7 @@ class IndexToStringTableFromFileTest(test.TestCase): def test_index_to_string_table_with_default_value(self): default_value = b"NONE" vocabulary_file = self._createVocabFile("f2i_vocab2.txt") - with self.test_session(): + with self.cached_session(): table = lookup.index_to_string_table_from_file( vocabulary_file=vocabulary_file, default_value=default_value) features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64)) @@ -1675,7 +1675,7 @@ class IndexToStringTableFromFileTest(test.TestCase): def test_index_to_string_table_with_vocab_size_too_small(self): default_value = b"NONE" vocabulary_file = self._createVocabFile("f2i_vocab2.txt") - with self.test_session(): + with self.cached_session(): table = lookup.index_to_string_table_from_file( vocabulary_file=vocabulary_file, vocab_size=2, @@ -1688,7 +1688,7 @@ class IndexToStringTableFromFileTest(test.TestCase): def test_index_to_string_table_with_vocab_size_too_large(self): vocabulary_file = self._createVocabFile("f2i_vocab6.txt") - with self.test_session(): + with self.cached_session(): table = lookup.index_to_string_table_from_file( vocabulary_file=vocabulary_file, vocab_size=4) features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64)) @@ -1700,7 +1700,7 @@ class IndexToStringTableFromFileTest(test.TestCase): def test_index_to_string_table_with_vocab_size(self): vocabulary_file = self._createVocabFile("f2i_vocab7.txt") - with self.test_session(): + with self.cached_session(): table = lookup.index_to_string_table_from_file( vocabulary_file=vocabulary_file, vocab_size=3) features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64)) @@ -1713,7 +1713,7 @@ class IndexToStringTableFromFileTest(test.TestCase): class IndexToStringTableFromTensorTest(test.TestCase): def test_index_to_string_table_from_tensor(self): - with self.test_session(): + with self.cached_session(): mapping_strings = constant_op.constant(["brain", "salad", "surgery"]) table = lookup.index_to_string_table_from_tensor( mapping=mapping_strings) @@ -1727,7 +1727,7 @@ class IndexToStringTableFromTensorTest(test.TestCase): features.eval()) def test_duplicate_entries(self): - with self.test_session(): + with self.cached_session(): mapping_strings = constant_op.constant(["hello", "hello"]) table = lookup.index_to_string_table_from_tensor( mapping=mapping_strings) @@ -1738,7 +1738,7 @@ class IndexToStringTableFromTensorTest(test.TestCase): def test_index_to_string_with_default_value(self): default_value = b"NONE" - with self.test_session(): + with self.cached_session(): mapping_strings = constant_op.constant(["brain", "salad", "surgery"]) table = lookup.index_to_string_table_from_tensor( mapping=mapping_strings, default_value=default_value) @@ -1754,7 +1754,7 @@ class IndexToStringTableFromTensorTest(test.TestCase): class IndexToStringTest(test.TestCase): def test_index_to_string(self): - with self.test_session(): + with self.cached_session(): mapping_strings = constant_op.constant(["brain", "salad", "surgery"]) indices = constant_op.constant([0, 1, 2, 3], dtypes.int64) feats = lookup.index_to_string(indices, mapping=mapping_strings) @@ -1766,7 +1766,7 @@ class IndexToStringTest(test.TestCase): feats.eval()) def test_duplicate_entries(self): - with self.test_session(): + with self.cached_session(): mapping_strings = constant_op.constant(["hello", "hello"]) indices = constant_op.constant([0, 1, 4], dtypes.int64) feats = lookup.index_to_string(indices, mapping=mapping_strings) @@ -1778,7 +1778,7 @@ class IndexToStringTest(test.TestCase): def test_index_to_string_with_default_value(self): default_value = b"NONE" - with self.test_session(): + with self.cached_session(): mapping_strings = constant_op.constant(["brain", "salad", "surgery"]) indices = constant_op.constant([1, 2, 4], dtypes.int64) feats = lookup.index_to_string( @@ -1818,7 +1818,7 @@ class InitializeTableFromFileOpTest(test.TestCase): vocabulary_file = self._createVocabFile( "one_column_int64.txt", values=("42", "1", "-1000")) - with self.test_session(): + with self.cached_session(): default_value = -1 table = lookup.HashTable( lookup.TextFileInitializer(vocabulary_file, dtypes.int64, @@ -1837,7 +1837,7 @@ class InitializeTableFromFileOpTest(test.TestCase): def testInitializeIndexTable(self): vocabulary_file = self._createVocabFile("one_column_2.txt") - with self.test_session(): + with self.cached_session(): default_value = "UNK" key_index = lookup.TextFileIndex.LINE_NUMBER value_index = lookup.TextFileIndex.WHOLE_LINE @@ -1858,7 +1858,7 @@ class InitializeTableFromFileOpTest(test.TestCase): with open(vocabulary_file, "w") as f: f.write("\n".join(["0\tbrain\t1", "1\tsalad\t5", "2\tsurgery\t6"]) + "\n") - with self.test_session(): + with self.cached_session(): default_value = -1 key_index = 1 value_index = 2 @@ -1880,7 +1880,7 @@ class InitializeTableFromFileOpTest(test.TestCase): with open(vocabulary_file, "w") as f: f.write("\n".join(["0\tbrain\t1", "1\tsalad\t5", "2\tsurgery\t6"]) + "\n") - with self.test_session(): + with self.cached_session(): default_value = -1 key_index = 2 value_index = 1 @@ -1894,7 +1894,7 @@ class InitializeTableFromFileOpTest(test.TestCase): def testInvalidDataType(self): vocabulary_file = self._createVocabFile("one_column_3.txt") - with self.test_session(): + with self.cached_session(): default_value = "UNK" key_index = lookup.TextFileIndex.WHOLE_LINE value_index = lookup.TextFileIndex.LINE_NUMBER @@ -1907,7 +1907,7 @@ class InitializeTableFromFileOpTest(test.TestCase): def testInvalidIndex(self): vocabulary_file = self._createVocabFile("one_column_4.txt") - with self.test_session(): + with self.cached_session(): default_value = -1 key_index = 1 # second column of the line value_index = lookup.TextFileIndex.LINE_NUMBER @@ -1922,7 +1922,7 @@ class InitializeTableFromFileOpTest(test.TestCase): def testInitializeSameTableWithMultipleNodes(self): vocabulary_file = self._createVocabFile("one_column_5.txt") - with self.test_session() as sess: + with self.cached_session() as sess: shared_name = "shared-one-columm" default_value = -1 table1 = lookup.HashTable( @@ -1961,7 +1961,7 @@ class InitializeTableFromFileOpTest(test.TestCase): self.assertAllEqual([0, 1, -1], out3) def testInitializeTableWithNoFilename(self): - with self.test_session(): + with self.cached_session(): default_value = -1 with self.assertRaises(ValueError): lookup.HashTable( @@ -1971,7 +1971,7 @@ class InitializeTableFromFileOpTest(test.TestCase): default_value) def testInitializeWithVocabSize(self): - with self.test_session(): + with self.cached_session(): default_value = -1 vocab_size = 3 vocabulary_file1 = self._createVocabFile("one_column6.txt") @@ -2022,7 +2022,7 @@ class InitializeTableFromFileOpTest(test.TestCase): def testFeedVocabularyName(self): vocabulary_file = self._createVocabFile("feed_vocabulary.txt") - with self.test_session(): + with self.cached_session(): default_value = -1 table = lookup.HashTable( lookup.TextFileInitializer("old_file.txt", dtypes.string, @@ -2049,7 +2049,7 @@ class InitializeTableFromFileOpTest(test.TestCase): def testInvalidFilenames(self): vocabulary_file = self._createVocabFile("filename_shape.txt") - with self.test_session(): + with self.cached_session(): default_value = -1 # Invalid data type @@ -2072,7 +2072,7 @@ class InitializeTableFromFileOpTest(test.TestCase): def testIdToStringTable(self): vocab_file = self._createVocabFile("feat_to_id_1.txt") - with self.test_session(): + with self.cached_session(): default_value = "UNK" vocab_size = 3 table = lookup.HashTable( @@ -2090,7 +2090,7 @@ class InitializeTableFromFileOpTest(test.TestCase): def testStringToIdTable(self): vocab_file = self._createVocabFile("feat_to_id_2.txt") - with self.test_session(): + with self.cached_session(): default_value = -1 vocab_size = 3 table = lookup.HashTable( @@ -2108,7 +2108,7 @@ class InitializeTableFromFileOpTest(test.TestCase): def testInt64ToIdTable(self): vocab_file = self._createVocabFile( "feat_to_id_3.txt", values=("42", "1", "-1000")) - with self.test_session(): + with self.cached_session(): default_value = -1 vocab_size = 3 table = lookup.HashTable( @@ -2133,7 +2133,7 @@ class IdTableWithHashBucketsTest(test.TestCase): def testStringIdTableWithHashBuckets(self): vocab_file = self._createVocabFile("feat_to_id_1.txt") - with self.test_session(): + with self.cached_session(): default_value = -1 vocab_size = 3 oov_buckets = 1 @@ -2154,7 +2154,7 @@ class IdTableWithHashBucketsTest(test.TestCase): def testInt32IdTableWithHashBuckets(self): vocab_file = self._createVocabFile("feat_to_id_2.txt", ("42", "1", "-1000")) - with self.test_session(): + with self.cached_session(): default_value = -1 vocab_size = 3 oov_buckets = 1 @@ -2176,7 +2176,7 @@ class IdTableWithHashBucketsTest(test.TestCase): def testInt64IdTableWithHashBuckets(self): vocab_file = self._createVocabFile("feat_to_id_3.txt", ("42", "1", "-1000")) - with self.test_session(): + with self.cached_session(): default_value = -1 vocab_size = 3 oov_buckets = 1 @@ -2196,7 +2196,7 @@ class IdTableWithHashBucketsTest(test.TestCase): self.assertEquals(vocab_size + oov_buckets, table.size().eval()) def testStringIdTableWithOnlyHashBucket(self): - with self.test_session(): + with self.cached_session(): oov_buckets = 5 # Set a table that only uses hash buckets, for each input value returns @@ -2217,7 +2217,7 @@ class IdTableWithHashBucketsTest(test.TestCase): self.assertEquals(oov_buckets, table.size().eval()) def testInt32IdTableWithOnlyHashBucket(self): - with self.test_session(): + with self.cached_session(): oov_buckets = 5 # Set a table that only uses hash buckets, for each input value returns @@ -2239,20 +2239,20 @@ class IdTableWithHashBucketsTest(test.TestCase): self.assertEquals(oov_buckets, table.size().eval()) def testFloat64IdTableWithOnlyHashBucket(self): - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp(TypeError, "Invalid key_dtype"): lookup.IdTableWithHashBuckets( None, num_oov_buckets=5, key_dtype=dtypes.float64) def testBoolIdTableWithOnlyHashBucket(self): - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp(TypeError, "Invalid key_dtype"): lookup.IdTableWithHashBuckets( None, num_oov_buckets=5, key_dtype=dtypes.bool) def testIdTableWithHashBucketsWithMultipleInitializers(self): vocab_file = self._createVocabFile("feat_to_id_4.txt") - with self.test_session() as sess: + with self.cached_session() as sess: default_value = -1 vocab_size = 3 oov_buckets = 3 @@ -2294,7 +2294,7 @@ class IdTableWithHashBucketsTest(test.TestCase): def testIdTableWithHashBucketsInitializationAcrossSessions(self): vocab_file = self._createVocabFile("feat_to_id_5.txt") shared_name = "across-sessions" - with self.test_session(): + with self.cached_session(): default_value = -1 vocab_size = 3 oov_buckets = 1 @@ -2316,7 +2316,7 @@ class IdTableWithHashBucketsTest(test.TestCase): self.assertAllEqual([0, 1, 2, 3], out1.eval()) self.assertEquals(vocab_size + oov_buckets, table1.size().eval()) - with self.test_session(): + with self.cached_session(): default_value = -1 vocab_size = 3 oov_buckets = 1 @@ -2340,7 +2340,7 @@ class IdTableWithHashBucketsTest(test.TestCase): def testIdTableWithHashBucketsWithMultipleInitializersDifferentDefault(self): vocab_file = self._createVocabFile("feat_to_id_6.txt") - with self.test_session() as sess: + with self.cached_session() as sess: default_value1 = -1 vocab_size = 3 oov_buckets = 0 @@ -2378,7 +2378,7 @@ class IdTableWithHashBucketsTest(test.TestCase): vocab_file = self._createVocabFile("feat_to_id_7.txt") input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]] input_shape = [4, 4] - with self.test_session() as sess: + with self.cached_session() as sess: sp_features = sparse_tensor.SparseTensor( constant_op.constant(input_indices, dtypes.int64), constant_op.constant(["brain", "salad", "brain", "surgery", "tarkus"], @@ -2407,7 +2407,7 @@ class IdTableWithHashBucketsTest(test.TestCase): def testInt32SparseTensor(self): input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]] input_shape = [4, 4] - with self.test_session() as sess: + with self.cached_session() as sess: sp_features = sparse_tensor.SparseTensor( constant_op.constant(input_indices, dtypes.int64), constant_op.constant([42, 1, 42, -1000, 11], dtypes.int32), @@ -2436,7 +2436,7 @@ class IdTableWithHashBucketsTest(test.TestCase): def testInt64SparseTensor(self): input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]] input_shape = [4, 4] - with self.test_session() as sess: + with self.cached_session() as sess: sp_features = sparse_tensor.SparseTensor( constant_op.constant(input_indices, dtypes.int64), constant_op.constant([42, 1, 42, -1000, 11], dtypes.int64), @@ -2464,7 +2464,7 @@ class IdTableWithHashBucketsTest(test.TestCase): def testIdTableWithHashBucketsWithInvalidHashers(self): vocab_file = self._createVocabFile("feat_to_id_4.txt") - with self.test_session(): + with self.cached_session(): default_value = -1 vocab_size = 3 oov_buckets = 1 |