diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/lookup_ops_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/lookup_ops_test.py | 156 |
1 files changed, 78 insertions, 78 deletions
diff --git a/tensorflow/python/kernel_tests/lookup_ops_test.py b/tensorflow/python/kernel_tests/lookup_ops_test.py index 5f08339fe5..38b14e34cc 100644 --- a/tensorflow/python/kernel_tests/lookup_ops_test.py +++ b/tensorflow/python/kernel_tests/lookup_ops_test.py @@ -36,7 +36,7 @@ from tensorflow.python.training import server_lib class HashTableOpTest(test.TestCase): def testHashTable(self): - with self.test_session(): + with self.cached_session(): default_val = -1 keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([0, 1, 2], dtypes.int64) @@ -54,7 +54,7 @@ class HashTableOpTest(test.TestCase): self.assertAllEqual([0, 1, -1], result) def testHashTableFindHighRank(self): - with self.test_session(): + with self.cached_session(): default_val = -1 keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([0, 1, 2], dtypes.int64) @@ -72,7 +72,7 @@ class HashTableOpTest(test.TestCase): self.assertAllEqual([[0, 1], [-1, -1]], result) def testHashTableInitWithPythonArrays(self): - with self.test_session(): + with self.cached_session(): default_val = -1 keys = ["brain", "salad", "surgery"] values = [0, 1, 2] @@ -90,7 +90,7 @@ class HashTableOpTest(test.TestCase): self.assertAllEqual([0, 1, -1], result) def testHashTableInitWithNumPyArrays(self): - with self.test_session(): + with self.cached_session(): default_val = -1 keys = np.array(["brain", "salad", "surgery"], dtype=np.str) values = np.array([0, 1, 2], dtype=np.int64) @@ -107,7 +107,7 @@ class HashTableOpTest(test.TestCase): self.assertAllEqual([0, 1, -1], result) def testMultipleHashTables(self): - with self.test_session() as sess: + with self.cached_session() as sess: default_val = -1 keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([0, 1, 2], dtypes.int64) @@ -135,7 +135,7 @@ class HashTableOpTest(test.TestCase): self.assertAllEqual([0, 1, -1], out3) def testHashTableWithTensorDefault(self): - with self.test_session(): + with self.cached_session(): default_val = constant_op.constant(-1, dtypes.int64) keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([0, 1, 2], dtypes.int64) @@ -150,7 +150,7 @@ class HashTableOpTest(test.TestCase): self.assertAllEqual([0, 1, -1], result) def testHashTableWithSparseTensorInput(self): - with self.test_session() as sess: + with self.cached_session() as sess: default_val = constant_op.constant(-1, dtypes.int64) keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([0, 1, 2], dtypes.int64) @@ -173,7 +173,7 @@ class HashTableOpTest(test.TestCase): self.assertAllEqual(sp_shape, out_shape) def testSignatureMismatch(self): - with self.test_session(): + with self.cached_session(): default_val = -1 keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([0, 1, 2], dtypes.int64) @@ -190,7 +190,7 @@ class HashTableOpTest(test.TestCase): lookup_ops.KeyValueTensorInitializer(keys, values), "UNK") def testDTypes(self): - with self.test_session(): + with self.cached_session(): default_val = -1 with self.assertRaises(TypeError): lookup_ops.HashTable( @@ -198,7 +198,7 @@ class HashTableOpTest(test.TestCase): dtypes.int64), default_val) def testNotInitialized(self): - with self.test_session(): + with self.cached_session(): default_val = -1 table = lookup_ops.HashTable( lookup_ops.KeyValueTensorInitializer( @@ -211,7 +211,7 @@ class HashTableOpTest(test.TestCase): output.eval() def testInitializeTwice(self): - with self.test_session(): + with self.cached_session(): default_val = -1 keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([0, 1, 2], dtypes.int64) @@ -223,7 +223,7 @@ class HashTableOpTest(test.TestCase): table.init.run() def testInitializationWithInvalidDimensions(self): - with self.test_session(): + with self.cached_session(): default_val = -1 keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([0, 1, 2, 3, 4], dtypes.int64) @@ -272,7 +272,7 @@ class IndexTableFromFile(test.TestCase): def test_string_index_table_from_file(self): vocabulary_file = self._createVocabFile("f2i_vocab1.txt") - with self.test_session(): + with self.cached_session(): table = lookup_ops.index_table_from_file( vocabulary_file=vocabulary_file, num_oov_buckets=1) ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) @@ -284,7 +284,7 @@ class IndexTableFromFile(test.TestCase): def test_string_index_table_from_multicolumn_file(self): vocabulary_file = self._createVocabFile( "f2i_vocab1.txt", values=("brain\t300", "salad\t20", "surgery\t1")) - with self.test_session(): + with self.cached_session(): table = lookup_ops.index_table_from_file( vocabulary_file=vocabulary_file, num_oov_buckets=1, @@ -299,7 +299,7 @@ class IndexTableFromFile(test.TestCase): def test_string_index_table_from_multicolumn_file_custom_delimiter(self): vocabulary_file = self._createVocabFile( "f2i_vocab1.txt", values=("brain 300", "salad 20", "surgery 1")) - with self.test_session(): + with self.cached_session(): table = lookup_ops.index_table_from_file( vocabulary_file=vocabulary_file, num_oov_buckets=1, @@ -314,7 +314,7 @@ class IndexTableFromFile(test.TestCase): def test_string_index_table_from_file_tensor_filename(self): vocabulary_file = self._createVocabFile("f2i_vocab1.txt") - with self.test_session(): + with self.cached_session(): vocabulary_file = constant_op.constant(vocabulary_file) table = lookup_ops.index_table_from_file( vocabulary_file=vocabulary_file, num_oov_buckets=1) @@ -328,7 +328,7 @@ class IndexTableFromFile(test.TestCase): def test_string_index_table_from_file_placeholder_filename(self): vocabulary_file = self._createVocabFile("f2i_vocab1.txt") - with self.test_session(): + with self.cached_session(): vocabulary_placeholder = array_ops.placeholder(dtypes.string, []) table = lookup_ops.index_table_from_file( vocabulary_file=vocabulary_placeholder, num_oov_buckets=1) @@ -344,7 +344,7 @@ class IndexTableFromFile(test.TestCase): def test_int32_index_table_from_file(self): vocabulary_file = self._createVocabFile( "f2i_vocab2.txt", values=("42", "1", "-1000")) - with self.test_session(): + with self.cached_session(): table = lookup_ops.index_table_from_file( vocabulary_file=vocabulary_file, num_oov_buckets=1, @@ -359,7 +359,7 @@ class IndexTableFromFile(test.TestCase): def test_int64_index_table_from_file(self): vocabulary_file = self._createVocabFile( "f2i_vocab3.txt", values=("42", "1", "-1000")) - with self.test_session(): + with self.cached_session(): table = lookup_ops.index_table_from_file( vocabulary_file=vocabulary_file, num_oov_buckets=1, @@ -374,7 +374,7 @@ class IndexTableFromFile(test.TestCase): def test_index_table_from_file_with_default_value(self): default_value = -42 vocabulary_file = self._createVocabFile("f2i_vocab4.txt") - with self.test_session(): + with self.cached_session(): table = lookup_ops.index_table_from_file( vocabulary_file=vocabulary_file, default_value=default_value) ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) @@ -385,7 +385,7 @@ class IndexTableFromFile(test.TestCase): def test_index_table_from_file_with_oov_buckets(self): vocabulary_file = self._createVocabFile("f2i_vocab5.txt") - with self.test_session(): + with self.cached_session(): table = lookup_ops.index_table_from_file( vocabulary_file=vocabulary_file, num_oov_buckets=1000) ids = table.lookup( @@ -432,7 +432,7 @@ class IndexTableFromFile(test.TestCase): def test_index_table_from_file_with_vocab_size_too_small(self): vocabulary_file = self._createVocabFile("f2i_vocab6.txt") - with self.test_session(): + with self.cached_session(): table = lookup_ops.index_table_from_file( vocabulary_file=vocabulary_file, vocab_size=2) ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) @@ -444,7 +444,7 @@ class IndexTableFromFile(test.TestCase): def test_index_table_from_file_with_vocab_size_too_large(self): vocabulary_file = self._createVocabFile("f2i_vocab7.txt") - with self.test_session(): + with self.cached_session(): table = lookup_ops.index_table_from_file( vocabulary_file=vocabulary_file, vocab_size=4) self.assertRaisesRegexp(errors_impl.InvalidArgumentError, @@ -459,7 +459,7 @@ class IndexTableFromFile(test.TestCase): vocabulary_file=vocabulary_file, vocab_size=0) - with self.test_session(): + with self.cached_session(): table = lookup_ops.index_table_from_file( vocabulary_file=vocabulary_file, vocab_size=3) ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) @@ -471,7 +471,7 @@ class IndexTableFromFile(test.TestCase): def test_index_table_from_file_with_invalid_hashers(self): vocabulary_file = self._createVocabFile("invalid_hasher.txt") - with self.test_session(): + with self.cached_session(): with self.assertRaises(TypeError): lookup_ops.index_table_from_file( vocabulary_file=vocabulary_file, @@ -490,14 +490,14 @@ class IndexTableFromFile(test.TestCase): def test_index_table_from_file_table_ref_with_oov_buckets(self): vocabulary_file = self._createVocabFile("f2i_vocab9.txt") - with self.test_session(): + with self.cached_session(): table = lookup_ops.index_table_from_file( vocabulary_file=vocabulary_file, num_oov_buckets=1) self.assertIsNotNone(table.table_ref) def test_index_table_from_file_table_ref_without_oov_buckets(self): vocabulary_file = self._createVocabFile("f2i_vocab10.txt") - with self.test_session(): + with self.cached_session(): table = lookup_ops.index_table_from_file( vocabulary_file=vocabulary_file, num_oov_buckets=0) self.assertIsNotNone(table.table_ref) @@ -506,21 +506,21 @@ class IndexTableFromFile(test.TestCase): class KeyValueTensorInitializerTest(test.TestCase): def test_string(self): - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): init = lookup_ops.KeyValueTensorInitializer( ("brain", "salad", "surgery"), (0, 1, 2), dtypes.string, dtypes.int64) table = lookup_ops.HashTable(init, default_value=-1) table.init.run() def test_int64(self): - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): init = lookup_ops.KeyValueTensorInitializer((42, 1, -1000), (0, 1, 2), dtypes.int64, dtypes.int64) table = lookup_ops.HashTable(init, default_value=-1) table.init.run() def test_int32(self): - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): init = lookup_ops.KeyValueTensorInitializer((42, 1, -1000), (0, 1, 2), dtypes.int32, dtypes.int64) table = lookup_ops.HashTable(init, default_value=-1) @@ -532,7 +532,7 @@ class KeyValueTensorInitializerTest(test.TestCase): class IndexTableFromTensor(test.TestCase): def test_index_table_from_tensor_with_tensor_init(self): - with self.test_session(): + with self.cached_session(): table = lookup_ops.index_table_from_tensor( vocabulary_list=("brain", "salad", "surgery"), num_oov_buckets=1) ids = table.lookup(constant_op.constant(("salad", "surgery", "tarkus"))) @@ -542,7 +542,7 @@ class IndexTableFromTensor(test.TestCase): self.assertAllEqual((1, 2, 3), ids.eval()) def test_int32_index_table_from_tensor_with_tensor_init(self): - with self.test_session(): + with self.cached_session(): table = lookup_ops.index_table_from_tensor( vocabulary_list=(42, 1, -1000), num_oov_buckets=1, dtype=dtypes.int32) ids = table.lookup( @@ -553,7 +553,7 @@ class IndexTableFromTensor(test.TestCase): self.assertAllEqual((1, 2, 3), ids.eval()) def test_int64_index_table_from_tensor_with_tensor_init(self): - with self.test_session(): + with self.cached_session(): table = lookup_ops.index_table_from_tensor( vocabulary_list=(42, 1, -1000), num_oov_buckets=1, dtype=dtypes.int64) ids = table.lookup( @@ -565,7 +565,7 @@ class IndexTableFromTensor(test.TestCase): def test_index_table_from_tensor_with_default_value(self): default_value = -42 - with self.test_session(): + with self.cached_session(): table = lookup_ops.index_table_from_tensor( vocabulary_list=["brain", "salad", "surgery"], default_value=default_value) @@ -576,14 +576,14 @@ class IndexTableFromTensor(test.TestCase): self.assertAllEqual((1, 2, default_value), ids.eval()) def test_index_table_from_tensor_missing_vocabulary_list(self): - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp(ValueError, "vocabulary_list must be specified"): lookup_ops.index_table_from_tensor( vocabulary_list=None, num_oov_buckets=1) def test_index_table_from_tensor_empty_vocabulary_list(self): - with self.test_session(): + with self.cached_session(): table = lookup_ops.index_table_from_tensor( vocabulary_list=np.array([], dtype=np.str_), num_oov_buckets=1) ids = table.lookup(constant_op.constant(["salad", "surgery", "brain"])) @@ -593,7 +593,7 @@ class IndexTableFromTensor(test.TestCase): lookup_ops.tables_initializer().run() def test_index_table_from_tensor_with_invalid_hashers(self): - with self.test_session(): + with self.cached_session(): with self.assertRaises(TypeError): lookup_ops.index_table_from_tensor( vocabulary_list=["brain", "salad", "surgery"], @@ -623,7 +623,7 @@ class IndexToStringTableFromFileTest(test.TestCase): type_funcs = [str, constant_op.constant] for type_func in type_funcs: vocabulary_file = type_func(vocabulary_path) - with self.test_session(): + with self.cached_session(): table = lookup_ops.index_to_string_table_from_file( vocabulary_file=vocabulary_file) features = table.lookup( @@ -636,7 +636,7 @@ class IndexToStringTableFromFileTest(test.TestCase): def test_index_to_string_table_from_multicolumn_file(self): vocabulary_file = self._createVocabFile( "f2i_vocab1.txt", values=("brain\t300", "salad\t20", "surgery\t1")) - with self.test_session(): + with self.cached_session(): table = lookup_ops.index_to_string_table_from_file( vocabulary_file=vocabulary_file, key_column_index=lookup_ops.TextFileIndex.LINE_NUMBER, @@ -650,7 +650,7 @@ class IndexToStringTableFromFileTest(test.TestCase): def test_index_to_string_table_from_multicolumn_file_custom_delimiter(self): vocabulary_file = self._createVocabFile( "f2i_vocab1.txt", values=("brain 300", "salad 20", "surgery 1")) - with self.test_session(): + with self.cached_session(): table = lookup_ops.index_to_string_table_from_file( vocabulary_file=vocabulary_file, key_column_index=lookup_ops.TextFileIndex.LINE_NUMBER, @@ -665,7 +665,7 @@ class IndexToStringTableFromFileTest(test.TestCase): def test_index_to_string_table_with_default_value(self): default_value = b"NONE" vocabulary_file = self._createVocabFile("f2i_vocab2.txt") - with self.test_session(): + with self.cached_session(): table = lookup_ops.index_to_string_table_from_file( vocabulary_file=vocabulary_file, default_value=default_value) features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64)) @@ -677,7 +677,7 @@ class IndexToStringTableFromFileTest(test.TestCase): def test_index_to_string_table_with_vocab_size_too_small(self): default_value = b"NONE" vocabulary_file = self._createVocabFile("f2i_vocab2.txt") - with self.test_session(): + with self.cached_session(): table = lookup_ops.index_to_string_table_from_file( vocabulary_file=vocabulary_file, vocab_size=2, @@ -690,7 +690,7 @@ class IndexToStringTableFromFileTest(test.TestCase): def test_index_to_string_table_with_vocab_size_too_large(self): vocabulary_file = self._createVocabFile("f2i_vocab6.txt") - with self.test_session(): + with self.cached_session(): table = lookup_ops.index_to_string_table_from_file( vocabulary_file=vocabulary_file, vocab_size=4) features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64)) @@ -702,7 +702,7 @@ class IndexToStringTableFromFileTest(test.TestCase): def test_index_to_string_table_with_vocab_size(self): vocabulary_file = self._createVocabFile("f2i_vocab7.txt") - with self.test_session(): + with self.cached_session(): table = lookup_ops.index_to_string_table_from_file( vocabulary_file=vocabulary_file, vocab_size=3) features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64)) @@ -715,7 +715,7 @@ class IndexToStringTableFromFileTest(test.TestCase): class IndexToStringTableFromTensorTest(test.TestCase): def test_index_to_string_table_from_tensor(self): - with self.test_session(): + with self.cached_session(): vocabulary_list = constant_op.constant(["brain", "salad", "surgery"]) table = lookup_ops.index_to_string_table_from_tensor( vocabulary_list=vocabulary_list) @@ -729,7 +729,7 @@ class IndexToStringTableFromTensorTest(test.TestCase): features.eval()) def test_duplicate_entries(self): - with self.test_session(): + with self.cached_session(): vocabulary_list = constant_op.constant(["hello", "hello"]) table = lookup_ops.index_to_string_table_from_tensor( vocabulary_list=vocabulary_list) @@ -740,7 +740,7 @@ class IndexToStringTableFromTensorTest(test.TestCase): def test_index_to_string_with_default_value(self): default_value = b"NONE" - with self.test_session(): + with self.cached_session(): vocabulary_list = constant_op.constant(["brain", "salad", "surgery"]) table = lookup_ops.index_to_string_table_from_tensor( vocabulary_list=vocabulary_list, default_value=default_value) @@ -764,7 +764,7 @@ class InitializeTableFromFileOpTest(test.TestCase): def testInitializeStringTable(self): vocabulary_file = self._createVocabFile("one_column_1.txt") - with self.test_session(): + with self.cached_session(): default_value = -1 table = lookup_ops.HashTable( lookup_ops.TextFileInitializer( @@ -782,7 +782,7 @@ class InitializeTableFromFileOpTest(test.TestCase): vocabulary_file = self._createVocabFile( "one_column_int64.txt", values=("42", "1", "-1000")) - with self.test_session(): + with self.cached_session(): default_value = -1 table = lookup_ops.HashTable( lookup_ops.TextFileInitializer( @@ -800,7 +800,7 @@ class InitializeTableFromFileOpTest(test.TestCase): def testInitializeIndexTable(self): vocabulary_file = self._createVocabFile("one_column_2.txt") - with self.test_session(): + with self.cached_session(): default_value = "UNK" key_index = lookup_ops.TextFileIndex.LINE_NUMBER value_index = lookup_ops.TextFileIndex.WHOLE_LINE @@ -821,7 +821,7 @@ class InitializeTableFromFileOpTest(test.TestCase): with open(vocabulary_file, "w") as f: f.write("\n".join(["0\tbrain\t1", "1\tsalad\t5", "2\tsurgery\t6"]) + "\n") - with self.test_session(): + with self.cached_session(): default_value = -1 key_index = 1 value_index = 2 @@ -843,7 +843,7 @@ class InitializeTableFromFileOpTest(test.TestCase): with open(vocabulary_file, "w") as f: f.write("\n".join(["0\tbrain\t1", "1\tsalad\t5", "2\tsurgery\t6"]) + "\n") - with self.test_session(): + with self.cached_session(): default_value = -1 key_index = 2 value_index = 1 @@ -857,7 +857,7 @@ class InitializeTableFromFileOpTest(test.TestCase): def testInvalidDataType(self): vocabulary_file = self._createVocabFile("one_column_3.txt") - with self.test_session(): + with self.cached_session(): default_value = "UNK" key_index = lookup_ops.TextFileIndex.WHOLE_LINE value_index = lookup_ops.TextFileIndex.LINE_NUMBER @@ -870,7 +870,7 @@ class InitializeTableFromFileOpTest(test.TestCase): def testInvalidIndex(self): vocabulary_file = self._createVocabFile("one_column_4.txt") - with self.test_session(): + with self.cached_session(): default_value = -1 key_index = 1 # second column of the line value_index = lookup_ops.TextFileIndex.LINE_NUMBER @@ -885,7 +885,7 @@ class InitializeTableFromFileOpTest(test.TestCase): def testInitializeSameTableWithMultipleNodes(self): vocabulary_file = self._createVocabFile("one_column_5.txt") - with self.test_session() as sess: + with self.cached_session() as sess: shared_name = "shared-one-columm" default_value = -1 table1 = lookup_ops.HashTable( @@ -924,7 +924,7 @@ class InitializeTableFromFileOpTest(test.TestCase): self.assertAllEqual([0, 1, -1], out3) def testInitializeTableWithNoFilename(self): - with self.test_session(): + with self.cached_session(): default_value = -1 with self.assertRaises(ValueError): lookup_ops.HashTable( @@ -934,7 +934,7 @@ class InitializeTableFromFileOpTest(test.TestCase): default_value) def testInitializeWithVocabSize(self): - with self.test_session(): + with self.cached_session(): default_value = -1 vocab_size = 3 vocabulary_file1 = self._createVocabFile("one_column6.txt") @@ -982,7 +982,7 @@ class InitializeTableFromFileOpTest(test.TestCase): def testFeedVocabularyName(self): vocabulary_file = self._createVocabFile("feed_vocabulary.txt") - with self.test_session(): + with self.cached_session(): default_value = -1 table = lookup_ops.HashTable( lookup_ops.TextFileInitializer( @@ -1008,7 +1008,7 @@ class InitializeTableFromFileOpTest(test.TestCase): def testInvalidFilenames(self): vocabulary_file = self._createVocabFile("filename_shape.txt") - with self.test_session(): + with self.cached_session(): default_value = -1 # Invalid data type @@ -1031,7 +1031,7 @@ class InitializeTableFromFileOpTest(test.TestCase): def testIdToStringTable(self): vocab_file = self._createVocabFile("feat_to_id_1.txt") - with self.test_session(): + with self.cached_session(): default_value = "UNK" vocab_size = 3 table = lookup_ops.HashTable( @@ -1048,7 +1048,7 @@ class InitializeTableFromFileOpTest(test.TestCase): def testStringToIdTable(self): vocab_file = self._createVocabFile("feat_to_id_2.txt") - with self.test_session(): + with self.cached_session(): default_value = -1 vocab_size = 3 table = lookup_ops.HashTable( @@ -1065,7 +1065,7 @@ class InitializeTableFromFileOpTest(test.TestCase): def testInt64ToIdTable(self): vocab_file = self._createVocabFile( "feat_to_id_3.txt", values=("42", "1", "-1000")) - with self.test_session(): + with self.cached_session(): default_value = -1 vocab_size = 3 table = lookup_ops.HashTable( @@ -1090,7 +1090,7 @@ class IdTableWithHashBucketsTest(test.TestCase): def testStringIdTableWithHashBuckets(self): vocab_file = self._createVocabFile("feat_to_id_1.txt") - with self.test_session(): + with self.cached_session(): default_value = -1 vocab_size = 3 oov_buckets = 1 @@ -1110,7 +1110,7 @@ class IdTableWithHashBucketsTest(test.TestCase): def testInt32IdTableWithHashBuckets(self): vocab_file = self._createVocabFile("feat_to_id_2.txt", ("42", "1", "-1000")) - with self.test_session(): + with self.cached_session(): default_value = -1 vocab_size = 3 oov_buckets = 1 @@ -1132,7 +1132,7 @@ class IdTableWithHashBucketsTest(test.TestCase): def testInt64IdTableWithHashBuckets(self): vocab_file = self._createVocabFile("feat_to_id_3.txt", ("42", "1", "-1000")) - with self.test_session(): + with self.cached_session(): default_value = -1 vocab_size = 3 oov_buckets = 1 @@ -1151,7 +1151,7 @@ class IdTableWithHashBucketsTest(test.TestCase): self.assertEquals(vocab_size + oov_buckets, table.size().eval()) def testStringIdTableWithOnlyHashBucket(self): - with self.test_session(): + with self.cached_session(): oov_buckets = 5 # Set a table that only uses hash buckets, for each input value returns @@ -1172,7 +1172,7 @@ class IdTableWithHashBucketsTest(test.TestCase): self.assertEquals(oov_buckets, table.size().eval()) def testInt32IdTableWithOnlyHashBucket(self): - with self.test_session(): + with self.cached_session(): oov_buckets = 5 # Set a table that only uses hash buckets, for each input value returns @@ -1194,20 +1194,20 @@ class IdTableWithHashBucketsTest(test.TestCase): self.assertEquals(oov_buckets, table.size().eval()) def testFloat64IdTableWithOnlyHashBucket(self): - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp(TypeError, "Invalid key_dtype"): lookup_ops.IdTableWithHashBuckets( None, num_oov_buckets=5, key_dtype=dtypes.float64) def testBoolIdTableWithOnlyHashBucket(self): - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp(TypeError, "Invalid key_dtype"): lookup_ops.IdTableWithHashBuckets( None, num_oov_buckets=5, key_dtype=dtypes.bool) def testIdTableWithHashBucketsWithMultipleInitializers(self): vocab_file = self._createVocabFile("feat_to_id_4.txt") - with self.test_session() as sess: + with self.cached_session() as sess: default_value = -1 vocab_size = 3 oov_buckets = 3 @@ -1248,7 +1248,7 @@ class IdTableWithHashBucketsTest(test.TestCase): def testIdTableWithHashBucketsInitializationAcrossSessions(self): vocab_file = self._createVocabFile("feat_to_id_5.txt") shared_name = "across-sessions" - with self.test_session(): + with self.cached_session(): default_value = -1 vocab_size = 3 oov_buckets = 1 @@ -1269,7 +1269,7 @@ class IdTableWithHashBucketsTest(test.TestCase): self.assertAllEqual([0, 1, 2, 3], out1.eval()) self.assertEquals(vocab_size + oov_buckets, table1.size().eval()) - with self.test_session(): + with self.cached_session(): default_value = -1 vocab_size = 3 oov_buckets = 1 @@ -1292,7 +1292,7 @@ class IdTableWithHashBucketsTest(test.TestCase): def testIdTableWithHashBucketsWithMultipleInitializersDifferentDefault(self): vocab_file = self._createVocabFile("feat_to_id_6.txt") - with self.test_session() as sess: + with self.cached_session() as sess: default_value1 = -1 vocab_size = 3 oov_buckets = 0 @@ -1328,7 +1328,7 @@ class IdTableWithHashBucketsTest(test.TestCase): vocab_file = self._createVocabFile("feat_to_id_7.txt") input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]] input_shape = [4, 4] - with self.test_session() as sess: + with self.cached_session() as sess: sp_features = sparse_tensor.SparseTensor( constant_op.constant(input_indices, dtypes.int64), constant_op.constant(["brain", "salad", "brain", "surgery", "tarkus"], @@ -1355,7 +1355,7 @@ class IdTableWithHashBucketsTest(test.TestCase): def testInt32SparseTensor(self): input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]] input_shape = [4, 4] - with self.test_session() as sess: + with self.cached_session() as sess: sp_features = sparse_tensor.SparseTensor( constant_op.constant(input_indices, dtypes.int64), constant_op.constant([42, 1, 42, -1000, 11], dtypes.int32), @@ -1383,7 +1383,7 @@ class IdTableWithHashBucketsTest(test.TestCase): def testInt64SparseTensor(self): input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]] input_shape = [4, 4] - with self.test_session() as sess: + with self.cached_session() as sess: sp_features = sparse_tensor.SparseTensor( constant_op.constant(input_indices, dtypes.int64), constant_op.constant([42, 1, 42, -1000, 11], dtypes.int64), @@ -1410,7 +1410,7 @@ class IdTableWithHashBucketsTest(test.TestCase): def testIdTableWithHashBucketsWithInvalidHashers(self): vocab_file = self._createVocabFile("feat_to_id_4.txt") - with self.test_session(): + with self.cached_session(): default_value = -1 vocab_size = 3 oov_buckets = 1 @@ -1451,7 +1451,7 @@ class IdTableWithHashBucketsTest(test.TestCase): hasher_spec=lookup_ops.StrongHashSpec([None, 2])) def testIdTableWithHashBucketsNoInnerTable(self): - with self.test_session(): + with self.cached_session(): table = lookup_ops.IdTableWithHashBuckets(None, num_oov_buckets=1) self.assertIsNone(table.table_ref) |