author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-09-10 14:36:26 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-09-10 14:49:41 -0700
commit    890e16594a005fe703a5556530b0dc3e6527fa47
tree      99140efb13f392ae13a58f08c08754c61bf66f13 /tensorflow/contrib/lookup
parent    132babebf5b1026cb33cad7c4eb7e03810c2acdf
Move from deprecated self.test_session() to self.cached_session().
self.test_session() has been deprecated in 9962eb5e84b15e309410071b06c2ed2d6148ed44 because its name confuses readers of the test. Moving to cached_session() instead, which is more explicit about:
* the fact that the session may be reused, and
* the fact that the session is not closed even when it is used in a "with" statement.

PiperOrigin-RevId: 212336321
Diffstat (limited to 'tensorflow/contrib/lookup')
-rw-r--r--  tensorflow/contrib/lookup/lookup_ops_test.py | 206
1 file changed, 103 insertions(+), 103 deletions(-)
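For context on the pattern being applied, here is a minimal, hypothetical TF 1.x test (not part of this diff) sketching how cached_session() is used; it assumes tf.contrib.lookup is available and mirrors the lookup-table tests touched below. Per the deprecation note above, cached_session() may hand back the same underlying session on repeated calls within a test, and leaving the "with" block does not close it.

import tensorflow as tf


class CachedSessionSketchTest(tf.test.TestCase):
  # Hypothetical example; names below are illustrative, not from the diff.

  def testHashTableLookup(self):
    # The cached session becomes the default session inside the block,
    # so .run() and .eval() work without passing a session explicitly.
    with self.cached_session():
      keys = tf.constant(["brain", "salad", "surgery"])
      values = tf.constant([0, 1, 2], dtype=tf.int64)
      table = tf.contrib.lookup.HashTable(
          tf.contrib.lookup.KeyValueTensorInitializer(keys, values),
          default_value=-1)
      table.init.run()
      output = table.lookup(tf.constant(["brain", "salad", "tank"]))
      self.assertAllEqual([0, 1, -1], output.eval())


if __name__ == "__main__":
  tf.test.main()

Each hunk below makes the same one-line change: only the context-manager call is swapped from self.test_session() to self.cached_session(); the body of every test is left untouched.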
diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py
index 0a54bb1f5e..89b538d1ba 100644
--- a/tensorflow/contrib/lookup/lookup_ops_test.py
+++ b/tensorflow/contrib/lookup/lookup_ops_test.py
@@ -44,7 +44,7 @@ from tensorflow.python.training.checkpointable import util as checkpointable
class HashTableOpTest(test.TestCase):
def testHashTable(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -68,7 +68,7 @@ class HashTableOpTest(test.TestCase):
self.assertItemsEqual([0, 1, 2], exported_values_tensor.eval())
def testHashTableFindHighRank(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -86,7 +86,7 @@ class HashTableOpTest(test.TestCase):
self.assertAllEqual([[0, 1], [-1, -1]], result)
def testHashTableInitWithPythonArrays(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = ["brain", "salad", "surgery"]
values = [0, 1, 2]
@@ -105,7 +105,7 @@ class HashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], result)
def testHashTableInitWithNumPyArrays(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = np.array(["brain", "salad", "surgery"], dtype=np.str)
values = np.array([0, 1, 2], dtype=np.int64)
@@ -122,7 +122,7 @@ class HashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], result)
def testMultipleHashTables(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -150,7 +150,7 @@ class HashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], out3)
def testHashTableWithTensorDefault(self):
- with self.test_session():
+ with self.cached_session():
default_val = constant_op.constant(-1, dtypes.int64)
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -165,7 +165,7 @@ class HashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], result)
def testHashTableWithSparseTensorInput(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
default_val = constant_op.constant(-1, dtypes.int64)
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -188,7 +188,7 @@ class HashTableOpTest(test.TestCase):
self.assertAllEqual(sp_shape, out_shape)
def testSignatureMismatch(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -210,7 +210,7 @@ class HashTableOpTest(test.TestCase):
lookup.KeyValueTensorInitializer(keys, values), "UNK")
def testDTypes(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
with self.assertRaises(TypeError):
lookup.HashTable(
@@ -218,7 +218,7 @@ class HashTableOpTest(test.TestCase):
dtypes.int64), default_val)
def testNotInitialized(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
table = lookup.HashTable(
lookup.KeyValueTensorInitializer(
@@ -232,7 +232,7 @@ class HashTableOpTest(test.TestCase):
output.eval()
def testInitializeTwice(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -244,7 +244,7 @@ class HashTableOpTest(test.TestCase):
table.init.run()
def testInitializationWithInvalidDimensions(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2, 3, 4], dtypes.int64)
@@ -283,7 +283,7 @@ class HashTableOpTest(test.TestCase):
self.assertAllEqual(3, table.size().eval())
def testHashTableInt32String(self):
- with self.test_session():
+ with self.cached_session():
default_val = "n/a"
keys = constant_op.constant([0, 1, 2], dtypes.int32)
values = constant_op.constant(["brain", "salad", "surgery"])
@@ -301,7 +301,7 @@ class HashTableOpTest(test.TestCase):
class MutableHashTableOpTest(test.TestCase):
def testMutableHashTable(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -470,7 +470,7 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllEqual([b"-", b"a", b"b"], output.eval())
def testMutableHashTableOfTensors(self):
- with self.test_session():
+ with self.cached_session():
default_val = constant_op.constant([-1, -1], dtypes.int64)
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64)
@@ -500,7 +500,7 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllEqual([[4, 5], [2, 3], [0, 1]], sorted_values)
def testMutableHashTableExportInsert(self):
- with self.test_session():
+ with self.cached_session():
default_val = constant_op.constant([-1, -1], dtypes.int64)
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64)
@@ -531,7 +531,7 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllEqual(expected_output, output2.eval())
def testMutableHashTableOfTensorsInvalidShape(self):
- with self.test_session():
+ with self.cached_session():
default_val = constant_op.constant([-1, -1], dtypes.int64)
keys = constant_op.constant(["brain", "salad", "surgery"])
table = lookup.MutableHashTable(dtypes.string, dtypes.int64,
@@ -563,7 +563,7 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllEqual(3, table.size().eval())
def testMutableHashTableInvalidDefaultValue(self):
- with self.test_session():
+ with self.cached_session():
default_val = constant_op.constant([[-1, -1]], dtypes.int64)
table = lookup.MutableHashTable(dtypes.string, dtypes.int64,
default_val)
@@ -571,7 +571,7 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllEqual(0, table.size().eval())
def testMutableHashTableDuplicateInsert(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery", "brain"])
values = constant_op.constant([0, 1, 2, 3], dtypes.int64)
@@ -589,7 +589,7 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllEqual([3, 1, -1], result)
def testMutableHashTableFindHighRank(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -608,7 +608,7 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllEqual([[0, 1], [-1, -1]], result)
def testMutableHashTableInsertHighRank(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant([["brain", "salad"], ["surgery", "tank"]])
values = constant_op.constant([[0, 1], [2, 3]], dtypes.int64)
@@ -625,7 +625,7 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, 3, -1], result)
def testMutableHashTableOfTensorsFindHighRank(self):
- with self.test_session():
+ with self.cached_session():
default_val = constant_op.constant([-1, -1, -1], dtypes.int64)
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([[0, 1, 2], [2, 3, 4], [4, 5, 6]],
@@ -646,7 +646,7 @@ class MutableHashTableOpTest(test.TestCase):
[[[0, 1, 2], [2, 3, 4]], [[-1, -1, -1], [-1, -1, -1]]], result)
def testMultipleMutableHashTables(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -676,7 +676,7 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], out3)
def testMutableHashTableWithTensorDefault(self):
- with self.test_session():
+ with self.cached_session():
default_val = constant_op.constant(-1, dtypes.int64)
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -693,7 +693,7 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], result)
def testSignatureMismatch(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -734,7 +734,7 @@ class MutableHashTableOpTest(test.TestCase):
lookup.MutableHashTable(dtypes.string, dtypes.int64, "UNK")
def testMutableHashTableStringFloat(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1.5
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1.1, 2.2], dtypes.float32)
@@ -752,7 +752,7 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllClose([0, 1.1, default_val], result)
def testMutableHashTableIntFloat(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1.0
keys = constant_op.constant([3, 7, 0], dtypes.int64)
values = constant_op.constant([7.5, -1.2, 9.9], dtypes.float32)
@@ -770,7 +770,7 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllClose([-1.2, 9.9, default_val], result)
def testMutableHashTableInt64String(self):
- with self.test_session():
+ with self.cached_session():
default_val = "n/a"
keys = constant_op.constant([0, 1, 2], dtypes.int64)
values = constant_op.constant(["brain", "salad", "surgery"])
@@ -791,7 +791,7 @@ class MutableHashTableOpTest(test.TestCase):
class MutableDenseHashTableOpTest(test.TestCase):
def testBasic(self):
- with self.test_session():
+ with self.cached_session():
keys = constant_op.constant([11, 12, 13], dtypes.int64)
values = constant_op.constant([0, 1, 2], dtypes.int64)
table = lookup.MutableDenseHashTable(
@@ -809,7 +809,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], result)
def testBasicBool(self):
- with self.test_session():
+ with self.cached_session():
keys = constant_op.constant([11, 12, 13], dtypes.int64)
values = constant_op.constant([True, True, True], dtypes.bool)
table = lookup.MutableDenseHashTable(
@@ -827,7 +827,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllEqual([True, True, False], result)
def testLookupUnknownShape(self):
- with self.test_session():
+ with self.cached_session():
keys = constant_op.constant([11, 12, 13], dtypes.int64)
values = constant_op.constant([0, 1, 2], dtypes.int64)
table = lookup.MutableDenseHashTable(
@@ -843,7 +843,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], result)
def testMapStringToFloat(self):
- with self.test_session():
+ with self.cached_session():
keys = constant_op.constant(["a", "b", "c"], dtypes.string)
values = constant_op.constant([0.0, 1.1, 2.2], dtypes.float32)
default_value = constant_op.constant(-1.5, dtypes.float32)
@@ -866,7 +866,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
def testMapInt64ToFloat(self):
for float_dtype in [dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
keys = constant_op.constant([11, 12, 13], dtypes.int64)
values = constant_op.constant([0.0, 1.1, 2.2], float_dtype)
default_value = constant_op.constant(-1.5, float_dtype)
@@ -885,7 +885,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllClose([0, 1.1, -1.5], result)
def testVectorValues(self):
- with self.test_session():
+ with self.cached_session():
keys = constant_op.constant([11, 12, 13], dtypes.int64)
values = constant_op.constant([[0, 1, 2, 3], [3, 4, 5, 6], [6, 7, 8, 9]],
dtypes.int64)
@@ -918,7 +918,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
result)
def testVectorKeys(self):
- with self.test_session():
+ with self.cached_session():
keys = constant_op.constant([[0, 1], [1, 2], [1, 3]], dtypes.int64)
values = constant_op.constant([10, 11, 12], dtypes.int64)
empty_key = constant_op.constant([0, 3], dtypes.int64)
@@ -949,7 +949,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllEqual([10, 11, -1], result)
def testResize(self):
- with self.test_session():
+ with self.cached_session():
keys = constant_op.constant([11, 12, 13], dtypes.int64)
values = constant_op.constant([0, 1, 2], dtypes.int64)
table = lookup.MutableDenseHashTable(
@@ -977,7 +977,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllEqual([-1, 0, 1, 3, 4, 5, 6, 7, -1], output.eval())
def testExport(self):
- with self.test_session():
+ with self.cached_session():
keys = constant_op.constant([11, 12, 13], dtypes.int64)
values = constant_op.constant([1, 2, 3], dtypes.int64)
table = lookup.MutableDenseHashTable(
@@ -1238,7 +1238,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1, 2, -1], output.eval())
def testReprobe(self):
- with self.test_session():
+ with self.cached_session():
# Insert 6 keys into a table with 8 buckets.
# The values are chosen to make sure collisions occur when using GCC STL
keys = constant_op.constant([11, 12, 13, 19, 20, 21], dtypes.int64)
@@ -1263,7 +1263,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllEqual([-1, 51, 52, 53, -1, 54, 55, 56, -1], result)
def testCustomEmptyKey(self):
- with self.test_session():
+ with self.cached_session():
keys = constant_op.constant([11, 0, 13], dtypes.int64)
values = constant_op.constant([0, 1, 2], dtypes.int64)
table = lookup.MutableDenseHashTable(
@@ -1281,7 +1281,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], result)
def testErrors(self):
- with self.test_session():
+ with self.cached_session():
table = lookup.MutableDenseHashTable(
dtypes.int64, dtypes.int64, default_value=-1, empty_key=0)
@@ -1328,7 +1328,7 @@ class IndexTableFromFile(test.TestCase):
def test_string_index_table_from_file(self):
vocabulary_file = self._createVocabFile("f2i_vocab1.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup.index_table_from_file(
vocabulary_file=vocabulary_file, num_oov_buckets=1)
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
@@ -1339,7 +1339,7 @@ class IndexTableFromFile(test.TestCase):
def test_string_index_table_from_file_tensor_filename(self):
vocabulary_file = self._createVocabFile("f2i_vocab1.txt")
- with self.test_session():
+ with self.cached_session():
vocabulary_file = constant_op.constant(vocabulary_file)
table = lookup.index_table_from_file(
vocabulary_file=vocabulary_file, num_oov_buckets=1)
@@ -1353,7 +1353,7 @@ class IndexTableFromFile(test.TestCase):
def test_string_index_table_from_file_placeholder_filename(self):
vocabulary_file = self._createVocabFile("f2i_vocab1.txt")
- with self.test_session():
+ with self.cached_session():
vocabulary_placeholder = array_ops.placeholder(dtypes.string, [])
table = lookup.index_table_from_file(
vocabulary_file=vocabulary_placeholder, num_oov_buckets=1)
@@ -1370,7 +1370,7 @@ class IndexTableFromFile(test.TestCase):
def test_int32_index_table_from_file(self):
vocabulary_file = self._createVocabFile(
"f2i_vocab2.txt", values=("42", "1", "-1000"))
- with self.test_session():
+ with self.cached_session():
table = lookup.index_table_from_file(
vocabulary_file=vocabulary_file, num_oov_buckets=1,
key_dtype=dtypes.int32)
@@ -1384,7 +1384,7 @@ class IndexTableFromFile(test.TestCase):
def test_int64_index_table_from_file(self):
vocabulary_file = self._createVocabFile(
"f2i_vocab3.txt", values=("42", "1", "-1000"))
- with self.test_session():
+ with self.cached_session():
table = lookup.index_table_from_file(
vocabulary_file=vocabulary_file, num_oov_buckets=1,
key_dtype=dtypes.int64)
@@ -1398,7 +1398,7 @@ class IndexTableFromFile(test.TestCase):
def test_index_table_from_file_with_default_value(self):
default_value = -42
vocabulary_file = self._createVocabFile("f2i_vocab4.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup.index_table_from_file(
vocabulary_file=vocabulary_file, default_value=default_value)
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
@@ -1409,7 +1409,7 @@ class IndexTableFromFile(test.TestCase):
def test_index_table_from_file_with_oov_buckets(self):
vocabulary_file = self._createVocabFile("f2i_vocab5.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup.index_table_from_file(
vocabulary_file=vocabulary_file, num_oov_buckets=1000)
ids = table.lookup(
@@ -1439,7 +1439,7 @@ class IndexTableFromFile(test.TestCase):
def test_index_table_from_file_with_vocab_size_too_small(self):
vocabulary_file = self._createVocabFile("f2i_vocab6.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup.index_table_from_file(
vocabulary_file=vocabulary_file, vocab_size=2)
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
@@ -1451,7 +1451,7 @@ class IndexTableFromFile(test.TestCase):
def test_index_table_from_file_with_vocab_size_too_large(self):
vocabulary_file = self._createVocabFile("f2i_vocab7.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup.index_table_from_file(
vocabulary_file=vocabulary_file, vocab_size=4)
self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
@@ -1466,7 +1466,7 @@ class IndexTableFromFile(test.TestCase):
vocabulary_file=vocabulary_file,
vocab_size=0)
- with self.test_session():
+ with self.cached_session():
table = lookup.index_table_from_file(
vocabulary_file=vocabulary_file, vocab_size=3)
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
@@ -1478,7 +1478,7 @@ class IndexTableFromFile(test.TestCase):
def test_index_table_from_file_with_invalid_hashers(self):
vocabulary_file = self._createVocabFile("invalid_hasher.txt")
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(TypeError):
lookup.index_table_from_file(
vocabulary_file=vocabulary_file,
@@ -1499,21 +1499,21 @@ class IndexTableFromFile(test.TestCase):
class KeyValueTensorInitializerTest(test.TestCase):
def test_string(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
init = lookup.KeyValueTensorInitializer(
("brain", "salad", "surgery"), (0, 1, 2), dtypes.string, dtypes.int64)
table = lookup.HashTable(init, default_value=-1)
table.init.run()
def test_int64(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
init = lookup.KeyValueTensorInitializer(
(42, 1, -1000), (0, 1, 2), dtypes.int64, dtypes.int64)
table = lookup.HashTable(init, default_value=-1)
table.init.run()
def test_int32(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
init = lookup.KeyValueTensorInitializer(
(42, 1, -1000), (0, 1, 2), dtypes.int32, dtypes.int64)
table = lookup.HashTable(init, default_value=-1)
@@ -1542,7 +1542,7 @@ class IndexTableFromTensor(test.TestCase):
self.assertAllEqual((1, 2, 3), self.evaluate(ids))
def test_int32_index_table_from_tensor_with_tensor_init(self):
- with self.test_session():
+ with self.cached_session():
table = lookup.index_table_from_tensor(
mapping=(42, 1, -1000), num_oov_buckets=1, dtype=dtypes.int32)
ids = table.lookup(
@@ -1553,7 +1553,7 @@ class IndexTableFromTensor(test.TestCase):
self.assertAllEqual((1, 2, 3), ids.eval())
def test_int64_index_table_from_tensor_with_tensor_init(self):
- with self.test_session():
+ with self.cached_session():
table = lookup.index_table_from_tensor(
mapping=(42, 1, -1000), num_oov_buckets=1, dtype=dtypes.int64)
ids = table.lookup(
@@ -1565,7 +1565,7 @@ class IndexTableFromTensor(test.TestCase):
def test_index_table_from_tensor_with_default_value(self):
default_value = -42
- with self.test_session():
+ with self.cached_session():
table = lookup.index_table_from_tensor(
mapping=["brain", "salad", "surgery"], default_value=default_value)
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
@@ -1575,12 +1575,12 @@ class IndexTableFromTensor(test.TestCase):
self.assertAllEqual((1, 2, default_value), ids.eval())
def test_index_table_from_tensor_missing_mapping(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(ValueError, "mapping must be specified"):
lookup.index_table_from_tensor(mapping=None, num_oov_buckets=1)
def test_index_table_from_tensor_empty_mapping(self):
- with self.test_session():
+ with self.cached_session():
table = lookup.index_table_from_tensor(
mapping=np.array([], dtype=np.str_), num_oov_buckets=1)
ids = table.lookup(constant_op.constant(["salad", "surgery", "brain"]))
@@ -1590,7 +1590,7 @@ class IndexTableFromTensor(test.TestCase):
lookup_ops.tables_initializer().run()
def test_index_table_from_tensor_with_invalid_hashers(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(TypeError):
lookup.index_table_from_tensor(
mapping=["brain", "salad", "surgery"],
@@ -1609,7 +1609,7 @@ class IndexTableFromTensor(test.TestCase):
class StringToIndexTest(test.TestCase):
def test_string_to_index(self):
- with self.test_session():
+ with self.cached_session():
mapping_strings = constant_op.constant(["brain", "salad", "surgery"])
feats = constant_op.constant(["salad", "surgery", "tarkus"])
indices = lookup.string_to_index(feats, mapping=mapping_strings)
@@ -1620,7 +1620,7 @@ class StringToIndexTest(test.TestCase):
self.assertAllEqual((1, 2, -1), indices.eval())
def test_duplicate_entries(self):
- with self.test_session():
+ with self.cached_session():
mapping_strings = constant_op.constant(["hello", "hello"])
feats = constant_op.constant(["hello", "hola"])
_ = lookup.string_to_index(feats, mapping=mapping_strings)
@@ -1630,7 +1630,7 @@ class StringToIndexTest(test.TestCase):
def test_string_to_index_with_default_value(self):
default_value = -42
- with self.test_session():
+ with self.cached_session():
mapping_strings = constant_op.constant(["brain", "salad", "surgery"])
feats = constant_op.constant(["salad", "surgery", "tarkus"])
indices = lookup.string_to_index(
@@ -1651,7 +1651,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
def test_index_to_string_table(self):
vocabulary_file = self._createVocabFile("i2f_vocab1.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup.index_to_string_table_from_file(
vocabulary_file=vocabulary_file)
features = table.lookup(constant_op.constant([0, 1, 2, 3], dtypes.int64))
@@ -1663,7 +1663,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
def test_index_to_string_table_with_default_value(self):
default_value = b"NONE"
vocabulary_file = self._createVocabFile("f2i_vocab2.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup.index_to_string_table_from_file(
vocabulary_file=vocabulary_file, default_value=default_value)
features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
@@ -1675,7 +1675,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
def test_index_to_string_table_with_vocab_size_too_small(self):
default_value = b"NONE"
vocabulary_file = self._createVocabFile("f2i_vocab2.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup.index_to_string_table_from_file(
vocabulary_file=vocabulary_file,
vocab_size=2,
@@ -1688,7 +1688,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
def test_index_to_string_table_with_vocab_size_too_large(self):
vocabulary_file = self._createVocabFile("f2i_vocab6.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup.index_to_string_table_from_file(
vocabulary_file=vocabulary_file, vocab_size=4)
features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
@@ -1700,7 +1700,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
def test_index_to_string_table_with_vocab_size(self):
vocabulary_file = self._createVocabFile("f2i_vocab7.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup.index_to_string_table_from_file(
vocabulary_file=vocabulary_file, vocab_size=3)
features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
@@ -1713,7 +1713,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
class IndexToStringTableFromTensorTest(test.TestCase):
def test_index_to_string_table_from_tensor(self):
- with self.test_session():
+ with self.cached_session():
mapping_strings = constant_op.constant(["brain", "salad", "surgery"])
table = lookup.index_to_string_table_from_tensor(
mapping=mapping_strings)
@@ -1727,7 +1727,7 @@ class IndexToStringTableFromTensorTest(test.TestCase):
features.eval())
def test_duplicate_entries(self):
- with self.test_session():
+ with self.cached_session():
mapping_strings = constant_op.constant(["hello", "hello"])
table = lookup.index_to_string_table_from_tensor(
mapping=mapping_strings)
@@ -1738,7 +1738,7 @@ class IndexToStringTableFromTensorTest(test.TestCase):
def test_index_to_string_with_default_value(self):
default_value = b"NONE"
- with self.test_session():
+ with self.cached_session():
mapping_strings = constant_op.constant(["brain", "salad", "surgery"])
table = lookup.index_to_string_table_from_tensor(
mapping=mapping_strings, default_value=default_value)
@@ -1754,7 +1754,7 @@ class IndexToStringTableFromTensorTest(test.TestCase):
class IndexToStringTest(test.TestCase):
def test_index_to_string(self):
- with self.test_session():
+ with self.cached_session():
mapping_strings = constant_op.constant(["brain", "salad", "surgery"])
indices = constant_op.constant([0, 1, 2, 3], dtypes.int64)
feats = lookup.index_to_string(indices, mapping=mapping_strings)
@@ -1766,7 +1766,7 @@ class IndexToStringTest(test.TestCase):
feats.eval())
def test_duplicate_entries(self):
- with self.test_session():
+ with self.cached_session():
mapping_strings = constant_op.constant(["hello", "hello"])
indices = constant_op.constant([0, 1, 4], dtypes.int64)
feats = lookup.index_to_string(indices, mapping=mapping_strings)
@@ -1778,7 +1778,7 @@ class IndexToStringTest(test.TestCase):
def test_index_to_string_with_default_value(self):
default_value = b"NONE"
- with self.test_session():
+ with self.cached_session():
mapping_strings = constant_op.constant(["brain", "salad", "surgery"])
indices = constant_op.constant([1, 2, 4], dtypes.int64)
feats = lookup.index_to_string(
@@ -1818,7 +1818,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
vocabulary_file = self._createVocabFile(
"one_column_int64.txt", values=("42", "1", "-1000"))
- with self.test_session():
+ with self.cached_session():
default_value = -1
table = lookup.HashTable(
lookup.TextFileInitializer(vocabulary_file, dtypes.int64,
@@ -1837,7 +1837,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testInitializeIndexTable(self):
vocabulary_file = self._createVocabFile("one_column_2.txt")
- with self.test_session():
+ with self.cached_session():
default_value = "UNK"
key_index = lookup.TextFileIndex.LINE_NUMBER
value_index = lookup.TextFileIndex.WHOLE_LINE
@@ -1858,7 +1858,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
with open(vocabulary_file, "w") as f:
f.write("\n".join(["0\tbrain\t1", "1\tsalad\t5", "2\tsurgery\t6"]) + "\n")
- with self.test_session():
+ with self.cached_session():
default_value = -1
key_index = 1
value_index = 2
@@ -1880,7 +1880,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
with open(vocabulary_file, "w") as f:
f.write("\n".join(["0\tbrain\t1", "1\tsalad\t5", "2\tsurgery\t6"]) + "\n")
- with self.test_session():
+ with self.cached_session():
default_value = -1
key_index = 2
value_index = 1
@@ -1894,7 +1894,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testInvalidDataType(self):
vocabulary_file = self._createVocabFile("one_column_3.txt")
- with self.test_session():
+ with self.cached_session():
default_value = "UNK"
key_index = lookup.TextFileIndex.WHOLE_LINE
value_index = lookup.TextFileIndex.LINE_NUMBER
@@ -1907,7 +1907,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testInvalidIndex(self):
vocabulary_file = self._createVocabFile("one_column_4.txt")
- with self.test_session():
+ with self.cached_session():
default_value = -1
key_index = 1 # second column of the line
value_index = lookup.TextFileIndex.LINE_NUMBER
@@ -1922,7 +1922,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testInitializeSameTableWithMultipleNodes(self):
vocabulary_file = self._createVocabFile("one_column_5.txt")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
shared_name = "shared-one-columm"
default_value = -1
table1 = lookup.HashTable(
@@ -1961,7 +1961,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], out3)
def testInitializeTableWithNoFilename(self):
- with self.test_session():
+ with self.cached_session():
default_value = -1
with self.assertRaises(ValueError):
lookup.HashTable(
@@ -1971,7 +1971,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
default_value)
def testInitializeWithVocabSize(self):
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
vocabulary_file1 = self._createVocabFile("one_column6.txt")
@@ -2022,7 +2022,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testFeedVocabularyName(self):
vocabulary_file = self._createVocabFile("feed_vocabulary.txt")
- with self.test_session():
+ with self.cached_session():
default_value = -1
table = lookup.HashTable(
lookup.TextFileInitializer("old_file.txt", dtypes.string,
@@ -2049,7 +2049,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testInvalidFilenames(self):
vocabulary_file = self._createVocabFile("filename_shape.txt")
- with self.test_session():
+ with self.cached_session():
default_value = -1
# Invalid data type
@@ -2072,7 +2072,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testIdToStringTable(self):
vocab_file = self._createVocabFile("feat_to_id_1.txt")
- with self.test_session():
+ with self.cached_session():
default_value = "UNK"
vocab_size = 3
table = lookup.HashTable(
@@ -2090,7 +2090,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testStringToIdTable(self):
vocab_file = self._createVocabFile("feat_to_id_2.txt")
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
table = lookup.HashTable(
@@ -2108,7 +2108,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testInt64ToIdTable(self):
vocab_file = self._createVocabFile(
"feat_to_id_3.txt", values=("42", "1", "-1000"))
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
table = lookup.HashTable(
@@ -2133,7 +2133,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testStringIdTableWithHashBuckets(self):
vocab_file = self._createVocabFile("feat_to_id_1.txt")
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
oov_buckets = 1
@@ -2154,7 +2154,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testInt32IdTableWithHashBuckets(self):
vocab_file = self._createVocabFile("feat_to_id_2.txt", ("42", "1", "-1000"))
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
oov_buckets = 1
@@ -2176,7 +2176,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testInt64IdTableWithHashBuckets(self):
vocab_file = self._createVocabFile("feat_to_id_3.txt", ("42", "1", "-1000"))
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
oov_buckets = 1
@@ -2196,7 +2196,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
self.assertEquals(vocab_size + oov_buckets, table.size().eval())
def testStringIdTableWithOnlyHashBucket(self):
- with self.test_session():
+ with self.cached_session():
oov_buckets = 5
# Set a table that only uses hash buckets, for each input value returns
@@ -2217,7 +2217,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
self.assertEquals(oov_buckets, table.size().eval())
def testInt32IdTableWithOnlyHashBucket(self):
- with self.test_session():
+ with self.cached_session():
oov_buckets = 5
# Set a table that only uses hash buckets, for each input value returns
@@ -2239,20 +2239,20 @@ class IdTableWithHashBucketsTest(test.TestCase):
self.assertEquals(oov_buckets, table.size().eval())
def testFloat64IdTableWithOnlyHashBucket(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(TypeError, "Invalid key_dtype"):
lookup.IdTableWithHashBuckets(
None, num_oov_buckets=5, key_dtype=dtypes.float64)
def testBoolIdTableWithOnlyHashBucket(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(TypeError, "Invalid key_dtype"):
lookup.IdTableWithHashBuckets(
None, num_oov_buckets=5, key_dtype=dtypes.bool)
def testIdTableWithHashBucketsWithMultipleInitializers(self):
vocab_file = self._createVocabFile("feat_to_id_4.txt")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
default_value = -1
vocab_size = 3
oov_buckets = 3
@@ -2294,7 +2294,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testIdTableWithHashBucketsInitializationAcrossSessions(self):
vocab_file = self._createVocabFile("feat_to_id_5.txt")
shared_name = "across-sessions"
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
oov_buckets = 1
@@ -2316,7 +2316,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
self.assertAllEqual([0, 1, 2, 3], out1.eval())
self.assertEquals(vocab_size + oov_buckets, table1.size().eval())
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
oov_buckets = 1
@@ -2340,7 +2340,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testIdTableWithHashBucketsWithMultipleInitializersDifferentDefault(self):
vocab_file = self._createVocabFile("feat_to_id_6.txt")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
default_value1 = -1
vocab_size = 3
oov_buckets = 0
@@ -2378,7 +2378,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
vocab_file = self._createVocabFile("feat_to_id_7.txt")
input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
input_shape = [4, 4]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sp_features = sparse_tensor.SparseTensor(
constant_op.constant(input_indices, dtypes.int64),
constant_op.constant(["brain", "salad", "brain", "surgery", "tarkus"],
@@ -2407,7 +2407,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testInt32SparseTensor(self):
input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
input_shape = [4, 4]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sp_features = sparse_tensor.SparseTensor(
constant_op.constant(input_indices, dtypes.int64),
constant_op.constant([42, 1, 42, -1000, 11], dtypes.int32),
@@ -2436,7 +2436,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testInt64SparseTensor(self):
input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
input_shape = [4, 4]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sp_features = sparse_tensor.SparseTensor(
constant_op.constant(input_indices, dtypes.int64),
constant_op.constant([42, 1, 42, -1000, 11], dtypes.int64),
@@ -2464,7 +2464,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testIdTableWithHashBucketsWithInvalidHashers(self):
vocab_file = self._createVocabFile("feat_to_id_4.txt")
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
oov_buckets = 1