Diffstat (limited to 'tensorflow/python/kernel_tests/lookup_ops_test.py')
-rw-r--r--  tensorflow/python/kernel_tests/lookup_ops_test.py  156
1 file changed, 78 insertions(+), 78 deletions(-)
diff --git a/tensorflow/python/kernel_tests/lookup_ops_test.py b/tensorflow/python/kernel_tests/lookup_ops_test.py
index 5f08339fe5..38b14e34cc 100644
--- a/tensorflow/python/kernel_tests/lookup_ops_test.py
+++ b/tensorflow/python/kernel_tests/lookup_ops_test.py
@@ -36,7 +36,7 @@ from tensorflow.python.training import server_lib
class HashTableOpTest(test.TestCase):
def testHashTable(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -54,7 +54,7 @@ class HashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], result)
def testHashTableFindHighRank(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -72,7 +72,7 @@ class HashTableOpTest(test.TestCase):
self.assertAllEqual([[0, 1], [-1, -1]], result)
def testHashTableInitWithPythonArrays(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = ["brain", "salad", "surgery"]
values = [0, 1, 2]
@@ -90,7 +90,7 @@ class HashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], result)
def testHashTableInitWithNumPyArrays(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = np.array(["brain", "salad", "surgery"], dtype=np.str)
values = np.array([0, 1, 2], dtype=np.int64)
@@ -107,7 +107,7 @@ class HashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], result)
def testMultipleHashTables(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -135,7 +135,7 @@ class HashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], out3)
def testHashTableWithTensorDefault(self):
- with self.test_session():
+ with self.cached_session():
default_val = constant_op.constant(-1, dtypes.int64)
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -150,7 +150,7 @@ class HashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], result)
def testHashTableWithSparseTensorInput(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
default_val = constant_op.constant(-1, dtypes.int64)
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -173,7 +173,7 @@ class HashTableOpTest(test.TestCase):
self.assertAllEqual(sp_shape, out_shape)
def testSignatureMismatch(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -190,7 +190,7 @@ class HashTableOpTest(test.TestCase):
lookup_ops.KeyValueTensorInitializer(keys, values), "UNK")
def testDTypes(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
with self.assertRaises(TypeError):
lookup_ops.HashTable(
@@ -198,7 +198,7 @@ class HashTableOpTest(test.TestCase):
dtypes.int64), default_val)
def testNotInitialized(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
table = lookup_ops.HashTable(
lookup_ops.KeyValueTensorInitializer(
@@ -211,7 +211,7 @@ class HashTableOpTest(test.TestCase):
output.eval()
def testInitializeTwice(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -223,7 +223,7 @@ class HashTableOpTest(test.TestCase):
table.init.run()
def testInitializationWithInvalidDimensions(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2, 3, 4], dtypes.int64)
@@ -272,7 +272,7 @@ class IndexTableFromFile(test.TestCase):
def test_string_index_table_from_file(self):
vocabulary_file = self._createVocabFile("f2i_vocab1.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_table_from_file(
vocabulary_file=vocabulary_file, num_oov_buckets=1)
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
@@ -284,7 +284,7 @@ class IndexTableFromFile(test.TestCase):
def test_string_index_table_from_multicolumn_file(self):
vocabulary_file = self._createVocabFile(
"f2i_vocab1.txt", values=("brain\t300", "salad\t20", "surgery\t1"))
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_table_from_file(
vocabulary_file=vocabulary_file,
num_oov_buckets=1,
@@ -299,7 +299,7 @@ class IndexTableFromFile(test.TestCase):
def test_string_index_table_from_multicolumn_file_custom_delimiter(self):
vocabulary_file = self._createVocabFile(
"f2i_vocab1.txt", values=("brain 300", "salad 20", "surgery 1"))
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_table_from_file(
vocabulary_file=vocabulary_file,
num_oov_buckets=1,
@@ -314,7 +314,7 @@ class IndexTableFromFile(test.TestCase):
def test_string_index_table_from_file_tensor_filename(self):
vocabulary_file = self._createVocabFile("f2i_vocab1.txt")
- with self.test_session():
+ with self.cached_session():
vocabulary_file = constant_op.constant(vocabulary_file)
table = lookup_ops.index_table_from_file(
vocabulary_file=vocabulary_file, num_oov_buckets=1)
@@ -328,7 +328,7 @@ class IndexTableFromFile(test.TestCase):
def test_string_index_table_from_file_placeholder_filename(self):
vocabulary_file = self._createVocabFile("f2i_vocab1.txt")
- with self.test_session():
+ with self.cached_session():
vocabulary_placeholder = array_ops.placeholder(dtypes.string, [])
table = lookup_ops.index_table_from_file(
vocabulary_file=vocabulary_placeholder, num_oov_buckets=1)
@@ -344,7 +344,7 @@ class IndexTableFromFile(test.TestCase):
def test_int32_index_table_from_file(self):
vocabulary_file = self._createVocabFile(
"f2i_vocab2.txt", values=("42", "1", "-1000"))
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_table_from_file(
vocabulary_file=vocabulary_file,
num_oov_buckets=1,
@@ -359,7 +359,7 @@ class IndexTableFromFile(test.TestCase):
def test_int64_index_table_from_file(self):
vocabulary_file = self._createVocabFile(
"f2i_vocab3.txt", values=("42", "1", "-1000"))
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_table_from_file(
vocabulary_file=vocabulary_file,
num_oov_buckets=1,
@@ -374,7 +374,7 @@ class IndexTableFromFile(test.TestCase):
def test_index_table_from_file_with_default_value(self):
default_value = -42
vocabulary_file = self._createVocabFile("f2i_vocab4.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_table_from_file(
vocabulary_file=vocabulary_file, default_value=default_value)
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
@@ -385,7 +385,7 @@ class IndexTableFromFile(test.TestCase):
def test_index_table_from_file_with_oov_buckets(self):
vocabulary_file = self._createVocabFile("f2i_vocab5.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_table_from_file(
vocabulary_file=vocabulary_file, num_oov_buckets=1000)
ids = table.lookup(
@@ -432,7 +432,7 @@ class IndexTableFromFile(test.TestCase):
def test_index_table_from_file_with_vocab_size_too_small(self):
vocabulary_file = self._createVocabFile("f2i_vocab6.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_table_from_file(
vocabulary_file=vocabulary_file, vocab_size=2)
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
@@ -444,7 +444,7 @@ class IndexTableFromFile(test.TestCase):
def test_index_table_from_file_with_vocab_size_too_large(self):
vocabulary_file = self._createVocabFile("f2i_vocab7.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_table_from_file(
vocabulary_file=vocabulary_file, vocab_size=4)
self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
@@ -459,7 +459,7 @@ class IndexTableFromFile(test.TestCase):
vocabulary_file=vocabulary_file,
vocab_size=0)
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_table_from_file(
vocabulary_file=vocabulary_file, vocab_size=3)
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
@@ -471,7 +471,7 @@ class IndexTableFromFile(test.TestCase):
def test_index_table_from_file_with_invalid_hashers(self):
vocabulary_file = self._createVocabFile("invalid_hasher.txt")
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(TypeError):
lookup_ops.index_table_from_file(
vocabulary_file=vocabulary_file,
@@ -490,14 +490,14 @@ class IndexTableFromFile(test.TestCase):
def test_index_table_from_file_table_ref_with_oov_buckets(self):
vocabulary_file = self._createVocabFile("f2i_vocab9.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_table_from_file(
vocabulary_file=vocabulary_file, num_oov_buckets=1)
self.assertIsNotNone(table.table_ref)
def test_index_table_from_file_table_ref_without_oov_buckets(self):
vocabulary_file = self._createVocabFile("f2i_vocab10.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_table_from_file(
vocabulary_file=vocabulary_file, num_oov_buckets=0)
self.assertIsNotNone(table.table_ref)
@@ -506,21 +506,21 @@ class IndexTableFromFile(test.TestCase):
class KeyValueTensorInitializerTest(test.TestCase):
def test_string(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
init = lookup_ops.KeyValueTensorInitializer(
("brain", "salad", "surgery"), (0, 1, 2), dtypes.string, dtypes.int64)
table = lookup_ops.HashTable(init, default_value=-1)
table.init.run()
def test_int64(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
init = lookup_ops.KeyValueTensorInitializer((42, 1, -1000), (0, 1, 2),
dtypes.int64, dtypes.int64)
table = lookup_ops.HashTable(init, default_value=-1)
table.init.run()
def test_int32(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
init = lookup_ops.KeyValueTensorInitializer((42, 1, -1000), (0, 1, 2),
dtypes.int32, dtypes.int64)
table = lookup_ops.HashTable(init, default_value=-1)
@@ -532,7 +532,7 @@ class KeyValueTensorInitializerTest(test.TestCase):
class IndexTableFromTensor(test.TestCase):
def test_index_table_from_tensor_with_tensor_init(self):
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_table_from_tensor(
vocabulary_list=("brain", "salad", "surgery"), num_oov_buckets=1)
ids = table.lookup(constant_op.constant(("salad", "surgery", "tarkus")))
@@ -542,7 +542,7 @@ class IndexTableFromTensor(test.TestCase):
self.assertAllEqual((1, 2, 3), ids.eval())
def test_int32_index_table_from_tensor_with_tensor_init(self):
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_table_from_tensor(
vocabulary_list=(42, 1, -1000), num_oov_buckets=1, dtype=dtypes.int32)
ids = table.lookup(
@@ -553,7 +553,7 @@ class IndexTableFromTensor(test.TestCase):
self.assertAllEqual((1, 2, 3), ids.eval())
def test_int64_index_table_from_tensor_with_tensor_init(self):
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_table_from_tensor(
vocabulary_list=(42, 1, -1000), num_oov_buckets=1, dtype=dtypes.int64)
ids = table.lookup(
@@ -565,7 +565,7 @@ class IndexTableFromTensor(test.TestCase):
def test_index_table_from_tensor_with_default_value(self):
default_value = -42
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_table_from_tensor(
vocabulary_list=["brain", "salad", "surgery"],
default_value=default_value)
@@ -576,14 +576,14 @@ class IndexTableFromTensor(test.TestCase):
self.assertAllEqual((1, 2, default_value), ids.eval())
def test_index_table_from_tensor_missing_vocabulary_list(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(ValueError,
"vocabulary_list must be specified"):
lookup_ops.index_table_from_tensor(
vocabulary_list=None, num_oov_buckets=1)
def test_index_table_from_tensor_empty_vocabulary_list(self):
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_table_from_tensor(
vocabulary_list=np.array([], dtype=np.str_), num_oov_buckets=1)
ids = table.lookup(constant_op.constant(["salad", "surgery", "brain"]))
@@ -593,7 +593,7 @@ class IndexTableFromTensor(test.TestCase):
lookup_ops.tables_initializer().run()
def test_index_table_from_tensor_with_invalid_hashers(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(TypeError):
lookup_ops.index_table_from_tensor(
vocabulary_list=["brain", "salad", "surgery"],
@@ -623,7 +623,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
type_funcs = [str, constant_op.constant]
for type_func in type_funcs:
vocabulary_file = type_func(vocabulary_path)
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_to_string_table_from_file(
vocabulary_file=vocabulary_file)
features = table.lookup(
@@ -636,7 +636,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
def test_index_to_string_table_from_multicolumn_file(self):
vocabulary_file = self._createVocabFile(
"f2i_vocab1.txt", values=("brain\t300", "salad\t20", "surgery\t1"))
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_to_string_table_from_file(
vocabulary_file=vocabulary_file,
key_column_index=lookup_ops.TextFileIndex.LINE_NUMBER,
@@ -650,7 +650,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
def test_index_to_string_table_from_multicolumn_file_custom_delimiter(self):
vocabulary_file = self._createVocabFile(
"f2i_vocab1.txt", values=("brain 300", "salad 20", "surgery 1"))
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_to_string_table_from_file(
vocabulary_file=vocabulary_file,
key_column_index=lookup_ops.TextFileIndex.LINE_NUMBER,
@@ -665,7 +665,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
def test_index_to_string_table_with_default_value(self):
default_value = b"NONE"
vocabulary_file = self._createVocabFile("f2i_vocab2.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_to_string_table_from_file(
vocabulary_file=vocabulary_file, default_value=default_value)
features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
@@ -677,7 +677,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
def test_index_to_string_table_with_vocab_size_too_small(self):
default_value = b"NONE"
vocabulary_file = self._createVocabFile("f2i_vocab2.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_to_string_table_from_file(
vocabulary_file=vocabulary_file,
vocab_size=2,
@@ -690,7 +690,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
def test_index_to_string_table_with_vocab_size_too_large(self):
vocabulary_file = self._createVocabFile("f2i_vocab6.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_to_string_table_from_file(
vocabulary_file=vocabulary_file, vocab_size=4)
features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
@@ -702,7 +702,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
def test_index_to_string_table_with_vocab_size(self):
vocabulary_file = self._createVocabFile("f2i_vocab7.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_to_string_table_from_file(
vocabulary_file=vocabulary_file, vocab_size=3)
features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
@@ -715,7 +715,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
class IndexToStringTableFromTensorTest(test.TestCase):
def test_index_to_string_table_from_tensor(self):
- with self.test_session():
+ with self.cached_session():
vocabulary_list = constant_op.constant(["brain", "salad", "surgery"])
table = lookup_ops.index_to_string_table_from_tensor(
vocabulary_list=vocabulary_list)
@@ -729,7 +729,7 @@ class IndexToStringTableFromTensorTest(test.TestCase):
features.eval())
def test_duplicate_entries(self):
- with self.test_session():
+ with self.cached_session():
vocabulary_list = constant_op.constant(["hello", "hello"])
table = lookup_ops.index_to_string_table_from_tensor(
vocabulary_list=vocabulary_list)
@@ -740,7 +740,7 @@ class IndexToStringTableFromTensorTest(test.TestCase):
def test_index_to_string_with_default_value(self):
default_value = b"NONE"
- with self.test_session():
+ with self.cached_session():
vocabulary_list = constant_op.constant(["brain", "salad", "surgery"])
table = lookup_ops.index_to_string_table_from_tensor(
vocabulary_list=vocabulary_list, default_value=default_value)
@@ -764,7 +764,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testInitializeStringTable(self):
vocabulary_file = self._createVocabFile("one_column_1.txt")
- with self.test_session():
+ with self.cached_session():
default_value = -1
table = lookup_ops.HashTable(
lookup_ops.TextFileInitializer(
@@ -782,7 +782,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
vocabulary_file = self._createVocabFile(
"one_column_int64.txt", values=("42", "1", "-1000"))
- with self.test_session():
+ with self.cached_session():
default_value = -1
table = lookup_ops.HashTable(
lookup_ops.TextFileInitializer(
@@ -800,7 +800,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testInitializeIndexTable(self):
vocabulary_file = self._createVocabFile("one_column_2.txt")
- with self.test_session():
+ with self.cached_session():
default_value = "UNK"
key_index = lookup_ops.TextFileIndex.LINE_NUMBER
value_index = lookup_ops.TextFileIndex.WHOLE_LINE
@@ -821,7 +821,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
with open(vocabulary_file, "w") as f:
f.write("\n".join(["0\tbrain\t1", "1\tsalad\t5", "2\tsurgery\t6"]) + "\n")
- with self.test_session():
+ with self.cached_session():
default_value = -1
key_index = 1
value_index = 2
@@ -843,7 +843,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
with open(vocabulary_file, "w") as f:
f.write("\n".join(["0\tbrain\t1", "1\tsalad\t5", "2\tsurgery\t6"]) + "\n")
- with self.test_session():
+ with self.cached_session():
default_value = -1
key_index = 2
value_index = 1
@@ -857,7 +857,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testInvalidDataType(self):
vocabulary_file = self._createVocabFile("one_column_3.txt")
- with self.test_session():
+ with self.cached_session():
default_value = "UNK"
key_index = lookup_ops.TextFileIndex.WHOLE_LINE
value_index = lookup_ops.TextFileIndex.LINE_NUMBER
@@ -870,7 +870,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testInvalidIndex(self):
vocabulary_file = self._createVocabFile("one_column_4.txt")
- with self.test_session():
+ with self.cached_session():
default_value = -1
key_index = 1 # second column of the line
value_index = lookup_ops.TextFileIndex.LINE_NUMBER
@@ -885,7 +885,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testInitializeSameTableWithMultipleNodes(self):
vocabulary_file = self._createVocabFile("one_column_5.txt")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
shared_name = "shared-one-columm"
default_value = -1
table1 = lookup_ops.HashTable(
@@ -924,7 +924,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], out3)
def testInitializeTableWithNoFilename(self):
- with self.test_session():
+ with self.cached_session():
default_value = -1
with self.assertRaises(ValueError):
lookup_ops.HashTable(
@@ -934,7 +934,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
default_value)
def testInitializeWithVocabSize(self):
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
vocabulary_file1 = self._createVocabFile("one_column6.txt")
@@ -982,7 +982,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testFeedVocabularyName(self):
vocabulary_file = self._createVocabFile("feed_vocabulary.txt")
- with self.test_session():
+ with self.cached_session():
default_value = -1
table = lookup_ops.HashTable(
lookup_ops.TextFileInitializer(
@@ -1008,7 +1008,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testInvalidFilenames(self):
vocabulary_file = self._createVocabFile("filename_shape.txt")
- with self.test_session():
+ with self.cached_session():
default_value = -1
# Invalid data type
@@ -1031,7 +1031,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testIdToStringTable(self):
vocab_file = self._createVocabFile("feat_to_id_1.txt")
- with self.test_session():
+ with self.cached_session():
default_value = "UNK"
vocab_size = 3
table = lookup_ops.HashTable(
@@ -1048,7 +1048,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testStringToIdTable(self):
vocab_file = self._createVocabFile("feat_to_id_2.txt")
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
table = lookup_ops.HashTable(
@@ -1065,7 +1065,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testInt64ToIdTable(self):
vocab_file = self._createVocabFile(
"feat_to_id_3.txt", values=("42", "1", "-1000"))
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
table = lookup_ops.HashTable(
@@ -1090,7 +1090,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testStringIdTableWithHashBuckets(self):
vocab_file = self._createVocabFile("feat_to_id_1.txt")
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
oov_buckets = 1
@@ -1110,7 +1110,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testInt32IdTableWithHashBuckets(self):
vocab_file = self._createVocabFile("feat_to_id_2.txt", ("42", "1", "-1000"))
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
oov_buckets = 1
@@ -1132,7 +1132,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testInt64IdTableWithHashBuckets(self):
vocab_file = self._createVocabFile("feat_to_id_3.txt", ("42", "1", "-1000"))
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
oov_buckets = 1
@@ -1151,7 +1151,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
self.assertEquals(vocab_size + oov_buckets, table.size().eval())
def testStringIdTableWithOnlyHashBucket(self):
- with self.test_session():
+ with self.cached_session():
oov_buckets = 5
# Set a table that only uses hash buckets, for each input value returns
@@ -1172,7 +1172,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
self.assertEquals(oov_buckets, table.size().eval())
def testInt32IdTableWithOnlyHashBucket(self):
- with self.test_session():
+ with self.cached_session():
oov_buckets = 5
# Set a table that only uses hash buckets, for each input value returns
@@ -1194,20 +1194,20 @@ class IdTableWithHashBucketsTest(test.TestCase):
self.assertEquals(oov_buckets, table.size().eval())
def testFloat64IdTableWithOnlyHashBucket(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(TypeError, "Invalid key_dtype"):
lookup_ops.IdTableWithHashBuckets(
None, num_oov_buckets=5, key_dtype=dtypes.float64)
def testBoolIdTableWithOnlyHashBucket(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(TypeError, "Invalid key_dtype"):
lookup_ops.IdTableWithHashBuckets(
None, num_oov_buckets=5, key_dtype=dtypes.bool)
def testIdTableWithHashBucketsWithMultipleInitializers(self):
vocab_file = self._createVocabFile("feat_to_id_4.txt")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
default_value = -1
vocab_size = 3
oov_buckets = 3
@@ -1248,7 +1248,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testIdTableWithHashBucketsInitializationAcrossSessions(self):
vocab_file = self._createVocabFile("feat_to_id_5.txt")
shared_name = "across-sessions"
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
oov_buckets = 1
@@ -1269,7 +1269,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
self.assertAllEqual([0, 1, 2, 3], out1.eval())
self.assertEquals(vocab_size + oov_buckets, table1.size().eval())
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
oov_buckets = 1
@@ -1292,7 +1292,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testIdTableWithHashBucketsWithMultipleInitializersDifferentDefault(self):
vocab_file = self._createVocabFile("feat_to_id_6.txt")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
default_value1 = -1
vocab_size = 3
oov_buckets = 0
@@ -1328,7 +1328,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
vocab_file = self._createVocabFile("feat_to_id_7.txt")
input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
input_shape = [4, 4]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sp_features = sparse_tensor.SparseTensor(
constant_op.constant(input_indices, dtypes.int64),
constant_op.constant(["brain", "salad", "brain", "surgery", "tarkus"],
@@ -1355,7 +1355,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testInt32SparseTensor(self):
input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
input_shape = [4, 4]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sp_features = sparse_tensor.SparseTensor(
constant_op.constant(input_indices, dtypes.int64),
constant_op.constant([42, 1, 42, -1000, 11], dtypes.int32),
@@ -1383,7 +1383,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testInt64SparseTensor(self):
input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
input_shape = [4, 4]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sp_features = sparse_tensor.SparseTensor(
constant_op.constant(input_indices, dtypes.int64),
constant_op.constant([42, 1, 42, -1000, 11], dtypes.int64),
@@ -1410,7 +1410,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testIdTableWithHashBucketsWithInvalidHashers(self):
vocab_file = self._createVocabFile("feat_to_id_4.txt")
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
oov_buckets = 1
@@ -1451,7 +1451,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
hasher_spec=lookup_ops.StrongHashSpec([None, 2]))
def testIdTableWithHashBucketsNoInnerTable(self):
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.IdTableWithHashBuckets(None, num_oov_buckets=1)
self.assertIsNone(table.table_ref)
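
Every hunk above makes the same mechanical substitution: self.test_session() becomes self.cached_session(), the non-deprecated TestCase helper that returns a session cached and reused across calls within the same test. A minimal sketch of the resulting test shape, assembled from the testHashTable hunk at the top of the diff (the class name HashTableSketchTest is illustrative and not part of this change):

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import lookup_ops
from tensorflow.python.platform import test


class HashTableSketchTest(test.TestCase):

  def testLookup(self):
    # cached_session() yields a session that is cached for this test's graph,
    # so repeated `with` blocks in one test reuse the same session.
    with self.cached_session():
      default_val = -1
      keys = constant_op.constant(["brain", "salad", "surgery"])
      values = constant_op.constant([0, 1, 2], dtypes.int64)
      table = lookup_ops.HashTable(
          lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
      table.init.run()

      output = table.lookup(constant_op.constant(["brain", "salad", "tank"]))
      self.assertAllEqual([0, 1, -1], output.eval())


if __name__ == "__main__":
  test.main()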