aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lookup/lookup_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lookup/lookup_ops_test.py')
-rw-r--r--tensorflow/contrib/lookup/lookup_ops_test.py336
1 files changed, 277 insertions, 59 deletions
diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py
index 9e9345e875..35b0d1bc44 100644
--- a/tensorflow/contrib/lookup/lookup_ops_test.py
+++ b/tensorflow/contrib/lookup/lookup_ops_test.py
@@ -303,13 +303,17 @@ class MutableHashTableOpTest(test.TestCase):
def testMutableHashTable(self):
with self.cached_session():
default_val = -1
- keys = constant_op.constant(["brain", "salad", "surgery"])
- values = constant_op.constant([0, 1, 2], dtypes.int64)
+ keys = constant_op.constant(["brain", "salad", "surgery", "tarkus"])
+ values = constant_op.constant([0, 1, 2, 3], dtypes.int64)
table = lookup.MutableHashTable(dtypes.string, dtypes.int64,
default_val)
self.assertAllEqual(0, table.size().eval())
table.insert(keys, values).run()
+ self.assertAllEqual(4, table.size().eval())
+
+ remove_string = constant_op.constant(["tarkus", "tank"])
+ table.remove(remove_string).run()
self.assertAllEqual(3, table.size().eval())
input_string = constant_op.constant(["brain", "salad", "tank"])
@@ -472,13 +476,18 @@ class MutableHashTableOpTest(test.TestCase):
def testMutableHashTableOfTensors(self):
with self.cached_session():
default_val = constant_op.constant([-1, -1], dtypes.int64)
- keys = constant_op.constant(["brain", "salad", "surgery"])
- values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64)
+ keys = constant_op.constant(["brain", "salad", "surgery", "tarkus"])
+ values = constant_op.constant([[0, 1], [2, 3], [4, 5], [6, 7]],
+ dtypes.int64)
table = lookup.MutableHashTable(dtypes.string, dtypes.int64,
default_val)
self.assertAllEqual(0, table.size().eval())
table.insert(keys, values).run()
+ self.assertAllEqual(4, table.size().eval())
+
+ remove_string = constant_op.constant(["tarkus", "tank"])
+ table.remove(remove_string).run()
self.assertAllEqual(3, table.size().eval())
input_string = constant_op.constant(["brain", "salad", "tank"])
@@ -624,6 +633,26 @@ class MutableHashTableOpTest(test.TestCase):
result = output.eval()
self.assertAllEqual([0, 1, 3, -1], result)
+ def testMutableHashTableRemoveHighRank(self):
+ with self.test_session():
+ default_val = -1
+ keys = constant_op.constant([["brain", "salad"], ["surgery", "tank"]])
+ values = constant_op.constant([[0, 1], [2, 3]], dtypes.int64)
+ table = lookup.MutableHashTable(dtypes.string, dtypes.int64, default_val)
+
+ table.insert(keys, values).run()
+ self.assertAllEqual(4, table.size().eval())
+
+ remove_string = constant_op.constant(["salad", "tarkus"])
+ table.remove(remove_string).run()
+ self.assertAllEqual(3, table.size().eval())
+
+ input_string = constant_op.constant(["brain", "salad", "tank", "tarkus"])
+ output = table.lookup(input_string)
+
+ result = output.eval()
+ self.assertAllEqual([0, -1, 3, -1], result)
+
def testMutableHashTableOfTensorsFindHighRank(self):
with self.cached_session():
default_val = constant_op.constant([-1, -1, -1], dtypes.int64)
@@ -645,6 +674,30 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllEqual(
[[[0, 1, 2], [2, 3, 4]], [[-1, -1, -1], [-1, -1, -1]]], result)
+ def testMutableHashTableOfTensorsRemoveHighRank(self):
+ with self.test_session():
+ default_val = constant_op.constant([-1, -1, -1], dtypes.int64)
+ keys = constant_op.constant(["brain", "salad", "surgery"])
+ values = constant_op.constant([[0, 1, 2], [2, 3, 4], [4, 5, 6]],
+ dtypes.int64)
+ table = lookup.MutableHashTable(dtypes.string, dtypes.int64, default_val)
+
+ table.insert(keys, values).run()
+ self.assertAllEqual(3, table.size().eval())
+
+ remove_string = constant_op.constant([["brain", "tank"]])
+ table.remove(remove_string).run()
+ self.assertAllEqual(2, table.size().eval())
+
+ input_string = constant_op.constant([["brain", "salad"],
+ ["surgery", "tank"]])
+ output = table.lookup(input_string)
+ self.assertAllEqual([2, 2, 3], output.get_shape())
+
+ result = output.eval()
+ self.assertAllEqual(
+ [[[-1, -1, -1], [2, 3, 4]], [[4, 5, 6], [-1, -1, -1]]], result)
+
def testMultipleMutableHashTables(self):
with self.cached_session() as sess:
default_val = -1
@@ -792,13 +845,22 @@ class MutableDenseHashTableOpTest(test.TestCase):
def testBasic(self):
with self.cached_session():
- keys = constant_op.constant([11, 12, 13], dtypes.int64)
- values = constant_op.constant([0, 1, 2], dtypes.int64)
+
+ keys = constant_op.constant([11, 12, 13, 14], dtypes.int64)
+ values = constant_op.constant([0, 1, 2, 3], dtypes.int64)
table = lookup.MutableDenseHashTable(
- dtypes.int64, dtypes.int64, default_value=-1, empty_key=0)
+ dtypes.int64,
+ dtypes.int64,
+ default_value=-1,
+ empty_key=0,
+ deleted_key=-1)
self.assertAllEqual(0, table.size().eval())
table.insert(keys, values).run()
+ self.assertAllEqual(4, table.size().eval())
+
+ remove_string = constant_op.constant([12, 15], dtypes.int64)
+ table.remove(remove_string).run()
self.assertAllEqual(3, table.size().eval())
input_string = constant_op.constant([11, 12, 15], dtypes.int64)
@@ -806,17 +868,26 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllEqual([3], output.get_shape())
result = output.eval()
- self.assertAllEqual([0, 1, -1], result)
+ self.assertAllEqual([0, -1, -1], result)
def testBasicBool(self):
with self.cached_session():
- keys = constant_op.constant([11, 12, 13], dtypes.int64)
- values = constant_op.constant([True, True, True], dtypes.bool)
+
+ keys = constant_op.constant([11, 12, 13, 14], dtypes.int64)
+ values = constant_op.constant([True, True, True, True], dtypes.bool)
table = lookup.MutableDenseHashTable(
- dtypes.int64, dtypes.bool, default_value=False, empty_key=0)
+ dtypes.int64,
+ dtypes.bool,
+ default_value=False,
+ empty_key=0,
+ deleted_key=-1)
self.assertAllEqual(0, table.size().eval())
table.insert(keys, values).run()
+ self.assertAllEqual(4, table.size().eval())
+
+ remove_string = constant_op.constant([11, 15], dtypes.int64)
+ table.remove(remove_string).run()
self.assertAllEqual(3, table.size().eval())
input_string = constant_op.constant([11, 12, 15], dtypes.int64)
@@ -824,14 +895,30 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllEqual([3], output.get_shape())
result = output.eval()
- self.assertAllEqual([True, True, False], result)
+ self.assertAllEqual([False, True, False], result)
+
+ def testSameEmptyAndDeletedKey(self):
+ with self.cached_session():
+ with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+ "deleted_key"):
+ table = lookup.MutableDenseHashTable(
+ dtypes.int64,
+ dtypes.int64,
+ default_value=-1,
+ empty_key=42,
+ deleted_key=42)
+ self.assertAllEqual(0, table.size().eval())
def testLookupUnknownShape(self):
with self.cached_session():
keys = constant_op.constant([11, 12, 13], dtypes.int64)
values = constant_op.constant([0, 1, 2], dtypes.int64)
table = lookup.MutableDenseHashTable(
- dtypes.int64, dtypes.int64, default_value=-1, empty_key=0)
+ dtypes.int64,
+ dtypes.int64,
+ default_value=-1,
+ empty_key=0,
+ deleted_key=-1)
table.insert(keys, values).run()
self.assertAllEqual(3, table.size().eval())
@@ -844,45 +931,60 @@ class MutableDenseHashTableOpTest(test.TestCase):
def testMapStringToFloat(self):
with self.cached_session():
- keys = constant_op.constant(["a", "b", "c"], dtypes.string)
- values = constant_op.constant([0.0, 1.1, 2.2], dtypes.float32)
+
+ keys = constant_op.constant(["a", "b", "c", "d"], dtypes.string)
+ values = constant_op.constant([0.0, 1.1, 2.2, 3.3], dtypes.float32)
default_value = constant_op.constant(-1.5, dtypes.float32)
table = lookup.MutableDenseHashTable(
dtypes.string,
dtypes.float32,
default_value=default_value,
- empty_key="")
+ empty_key="",
+ deleted_key="$")
self.assertAllEqual(0, table.size().eval())
table.insert(keys, values).run()
+ self.assertAllEqual(4, table.size().eval())
+
+ remove_string = constant_op.constant(["b", "e"])
+ table.remove(remove_string).run()
self.assertAllEqual(3, table.size().eval())
- input_string = constant_op.constant(["a", "b", "d"], dtypes.string)
+ input_string = constant_op.constant(["a", "b", "d", "e"], dtypes.string)
output = table.lookup(input_string)
- self.assertAllEqual([3], output.get_shape())
+ self.assertAllEqual([4], output.get_shape())
result = output.eval()
- self.assertAllClose([0, 1.1, -1.5], result)
+ self.assertAllClose([0, -1.5, 3.3, -1.5], result)
def testMapInt64ToFloat(self):
for float_dtype in [dtypes.float32, dtypes.float64]:
with self.cached_session():
- keys = constant_op.constant([11, 12, 13], dtypes.int64)
- values = constant_op.constant([0.0, 1.1, 2.2], float_dtype)
+
+ keys = constant_op.constant([11, 12, 13, 14], dtypes.int64)
+ values = constant_op.constant([0.0, 1.1, 2.2, 3.3], float_dtype)
default_value = constant_op.constant(-1.5, float_dtype)
table = lookup.MutableDenseHashTable(
- dtypes.int64, float_dtype, default_value=default_value, empty_key=0)
+ dtypes.int64,
+ float_dtype,
+ default_value=default_value,
+ empty_key=0,
+ deleted_key=-1)
self.assertAllEqual(0, table.size().eval())
table.insert(keys, values).run()
+ self.assertAllEqual(4, table.size().eval())
+
+ remove_string = constant_op.constant([12, 15], dtypes.int64)
+ table.remove(remove_string).run()
self.assertAllEqual(3, table.size().eval())
- input_string = constant_op.constant([11, 12, 15], dtypes.int64)
+ input_string = constant_op.constant([11, 12, 14, 15], dtypes.int64)
output = table.lookup(input_string)
- self.assertAllEqual([3], output.get_shape())
+ self.assertAllEqual([4], output.get_shape())
result = output.eval()
- self.assertAllClose([0, 1.1, -1.5], result)
+ self.assertAllClose([0, -1.5, 3.3, -1.5], result)
def testVectorValues(self):
with self.cached_session():
@@ -895,6 +997,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
dtypes.int64,
default_value=default_value,
empty_key=0,
+ deleted_key=-1,
initial_num_buckets=4)
self.assertAllEqual(0, table.size().eval())
@@ -908,26 +1011,35 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllEqual(4, table.size().eval())
self.assertAllEqual(8, len(table.export()[0].eval()))
- input_string = constant_op.constant([11, 12, 15], dtypes.int64)
+ remove_string = constant_op.constant([12, 16], dtypes.int64)
+ table.remove(remove_string).run()
+ self.assertAllEqual(3, table.size().eval())
+ self.assertAllEqual(8, len(table.export()[0].eval()))
+
+ input_string = constant_op.constant([11, 12, 14, 15], dtypes.int64)
output = table.lookup(input_string)
- self.assertAllEqual(
- [3, 4], output.shape, msg="Saw shape: %s" % output.shape)
+ self.assertAllEqual([4, 4],
+ output.shape,
+ msg="Saw shape: %s" % output.shape)
result = output.eval()
- self.assertAllEqual([[0, 1, 2, 3], [3, 4, 5, 6], [-1, -2, -3, -4]],
- result)
+ self.assertAllEqual(
+ [[0, 1, 2, 3], [-1, -2, -3, -4], [2, 3, 4, 5], [-1, -2, -3, -4]],
+ result)
def testVectorKeys(self):
with self.cached_session():
keys = constant_op.constant([[0, 1], [1, 2], [1, 3]], dtypes.int64)
values = constant_op.constant([10, 11, 12], dtypes.int64)
empty_key = constant_op.constant([0, 3], dtypes.int64)
+ deleted_key = constant_op.constant([-1, -1], dtypes.int64)
default_value = constant_op.constant(-1, dtypes.int64)
table = lookup.MutableDenseHashTable(
dtypes.int64,
dtypes.int64,
default_value=default_value,
empty_key=empty_key,
+ deleted_key=deleted_key,
initial_num_buckets=8)
self.assertAllEqual(0, table.size().eval())
@@ -940,13 +1052,18 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllEqual(4, table.size().eval())
self.assertAllEqual(8, len(table.export()[0].eval()))
- input_string = constant_op.constant([[0, 1], [1, 2], [0, 2]],
+ remove_string = constant_op.constant([[1, 2], [7, 8]], dtypes.int64)
+ table.remove(remove_string).run()
+ self.assertAllEqual(3, table.size().eval())
+ self.assertAllEqual(8, len(table.export()[0].eval()))
+
+ input_string = constant_op.constant([[0, 1], [1, 2], [1, 3], [0, 2]],
dtypes.int64)
output = table.lookup(input_string)
- self.assertAllEqual([3], output.get_shape())
+ self.assertAllEqual([4], output.get_shape())
result = output.eval()
- self.assertAllEqual([10, 11, -1], result)
+ self.assertAllEqual([10, -1, 12, -1], result)
def testResize(self):
with self.cached_session():
@@ -957,6 +1074,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
dtypes.int64,
default_value=-1,
empty_key=0,
+ deleted_key=-1,
initial_num_buckets=4)
self.assertAllEqual(0, table.size().eval())
@@ -964,31 +1082,42 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllEqual(3, table.size().eval())
self.assertAllEqual(4, len(table.export()[0].eval()))
- keys2 = constant_op.constant([13, 14, 15, 16, 17], dtypes.int64)
- values2 = constant_op.constant([3, 4, 5, 6, 7], dtypes.int64)
+ keys2 = constant_op.constant([12, 99], dtypes.int64)
+ table.remove(keys2).run()
+ self.assertAllEqual(2, table.size().eval())
+ self.assertAllEqual(4, len(table.export()[0].eval()))
+
+ keys3 = constant_op.constant([13, 14, 15, 16, 17], dtypes.int64)
+ values3 = constant_op.constant([3, 4, 5, 6, 7], dtypes.int64)
- table.insert(keys2, values2).run()
- self.assertAllEqual(7, table.size().eval())
+ table.insert(keys3, values3).run()
+ self.assertAllEqual(6, table.size().eval())
self.assertAllEqual(16, len(table.export()[0].eval()))
- keys3 = constant_op.constant([10, 11, 12, 13, 14, 15, 16, 17, 18],
+ keys4 = constant_op.constant([10, 11, 12, 13, 14, 15, 16, 17, 18],
dtypes.int64)
- output = table.lookup(keys3)
- self.assertAllEqual([-1, 0, 1, 3, 4, 5, 6, 7, -1], output.eval())
+ output = table.lookup(keys4)
+ self.assertAllEqual([-1, 0, -1, 3, 4, 5, 6, 7, -1], output.eval())
def testExport(self):
with self.cached_session():
- keys = constant_op.constant([11, 12, 13], dtypes.int64)
- values = constant_op.constant([1, 2, 3], dtypes.int64)
+
+ keys = constant_op.constant([11, 12, 13, 14], dtypes.int64)
+ values = constant_op.constant([1, 2, 3, 4], dtypes.int64)
table = lookup.MutableDenseHashTable(
dtypes.int64,
dtypes.int64,
default_value=-1,
empty_key=100,
+ deleted_key=200,
initial_num_buckets=8)
self.assertAllEqual(0, table.size().eval())
table.insert(keys, values).run()
+ self.assertAllEqual(4, table.size().eval())
+
+ keys2 = constant_op.constant([12, 15], dtypes.int64)
+ table.remove(keys2).run()
self.assertAllEqual(3, table.size().eval())
exported_keys, exported_values = table.export()
@@ -1005,8 +1134,8 @@ class MutableDenseHashTableOpTest(test.TestCase):
pairs = np.dstack((np_keys.flatten(), np_values.flatten()))[0]
# sort by key
pairs = pairs[pairs[:, 0].argsort()]
- self.assertAllEqual([[11, 1], [12, 2], [13, 3], [100, 0], [100, 0],
- [100, 0], [100, 0], [100, 0]], pairs)
+ self.assertAllEqual([[11, 1], [13, 3], [14, 4], [100, 0], [100, 0],
+ [100, 0], [100, 0], [200, 2]], pairs)
def testSaveRestore(self):
save_dir = os.path.join(self.get_temp_dir(), "save_restore")
@@ -1015,13 +1144,15 @@ class MutableDenseHashTableOpTest(test.TestCase):
with self.session(graph=ops.Graph()) as sess:
default_value = -1
empty_key = 0
- keys = constant_op.constant([11, 12, 13], dtypes.int64)
- values = constant_op.constant([0, 1, 2], dtypes.int64)
+ deleted_key = -1
+ keys = constant_op.constant([11, 12, 13, 14], dtypes.int64)
+ values = constant_op.constant([0, 1, 2, 3], dtypes.int64)
table = lookup.MutableDenseHashTable(
dtypes.int64,
dtypes.int64,
default_value=default_value,
empty_key=empty_key,
+ deleted_key=deleted_key,
name="t1",
checkpoint=True,
initial_num_buckets=32)
@@ -1030,6 +1161,11 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllEqual(0, table.size().eval())
table.insert(keys, values).run()
+ self.assertAllEqual(4, table.size().eval())
+ self.assertAllEqual(32, len(table.export()[0].eval()))
+
+ keys2 = constant_op.constant([12, 15], dtypes.int64)
+ table.remove(keys2).run()
self.assertAllEqual(3, table.size().eval())
self.assertAllEqual(32, len(table.export()[0].eval()))
@@ -1043,6 +1179,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
dtypes.int64,
default_value=default_value,
empty_key=empty_key,
+ deleted_key=deleted_key,
name="t1",
checkpoint=True,
initial_num_buckets=64)
@@ -1062,7 +1199,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
input_string = constant_op.constant([10, 11, 12, 13, 14], dtypes.int64)
output = table.lookup(input_string)
- self.assertAllEqual([-1, 0, 1, 2, -1], output.eval())
+ self.assertAllEqual([-1, 0, -1, 2, 3], output.eval())
@test_util.run_in_graph_and_eager_modes
def testObjectSaveRestore(self):
@@ -1071,6 +1208,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
default_value = -1
empty_key = 0
+ deleted_key = -1
keys = constant_op.constant([11, 12, 13], dtypes.int64)
values = constant_op.constant([0, 1, 2], dtypes.int64)
save_table = lookup.MutableDenseHashTable(
@@ -1078,6 +1216,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
dtypes.int64,
default_value=default_value,
empty_key=empty_key,
+ deleted_key=deleted_key,
name="t1",
checkpoint=True,
initial_num_buckets=32)
@@ -1097,6 +1236,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
dtypes.int64,
default_value=default_value,
empty_key=empty_key,
+ deleted_key=deleted_key,
name="t1",
checkpoint=True,
initial_num_buckets=64)
@@ -1124,14 +1264,18 @@ class MutableDenseHashTableOpTest(test.TestCase):
with self.session(graph=ops.Graph()) as sess:
empty_key = constant_op.constant([11, 13], dtypes.int64)
+ deleted_key = constant_op.constant([-2, -3], dtypes.int64)
default_value = constant_op.constant([-1, -2], dtypes.int64)
- keys = constant_op.constant([[11, 12], [11, 14], [13, 14]], dtypes.int64)
- values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64)
+ keys = constant_op.constant([[11, 12], [11, 14], [12, 13], [13, 14]],
+ dtypes.int64)
+ values = constant_op.constant([[0, 1], [2, 3], [2, 4], [4, 5]],
+ dtypes.int64)
table = lookup.MutableDenseHashTable(
dtypes.int64,
dtypes.int64,
default_value=default_value,
empty_key=empty_key,
+ deleted_key=deleted_key,
name="t1",
checkpoint=True,
initial_num_buckets=32)
@@ -1140,6 +1284,11 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllEqual(0, table.size().eval())
table.insert(keys, values).run()
+ self.assertAllEqual(4, table.size().eval())
+ self.assertAllEqual(32, len(table.export()[0].eval()))
+
+ keys2 = constant_op.constant([[12, 13], [16, 17]], dtypes.int64)
+ table.remove(keys2).run()
self.assertAllEqual(3, table.size().eval())
self.assertAllEqual(32, len(table.export()[0].eval()))
@@ -1149,12 +1298,14 @@ class MutableDenseHashTableOpTest(test.TestCase):
with self.session(graph=ops.Graph()) as sess:
empty_key = constant_op.constant([11, 13], dtypes.int64)
+ deleted_key = constant_op.constant([-2, -3], dtypes.int64)
default_value = constant_op.constant([-1, -2], dtypes.int64)
table = lookup.MutableDenseHashTable(
dtypes.int64,
dtypes.int64,
default_value=default_value,
empty_key=empty_key,
+ deleted_key=deleted_key,
name="t1",
checkpoint=True,
initial_num_buckets=64)
@@ -1184,14 +1335,17 @@ class MutableDenseHashTableOpTest(test.TestCase):
with self.session(graph=ops.Graph()) as sess:
empty_key = constant_op.constant([11, 13], dtypes.int64)
+ deleted_key = constant_op.constant([-1, -1], dtypes.int64)
default_value = constant_op.constant(-1, dtypes.int64)
- keys = constant_op.constant([[11, 12], [11, 14], [13, 14]], dtypes.int64)
- values = constant_op.constant([0, 1, 2], dtypes.int64)
+ keys = constant_op.constant([[11, 12], [11, 14], [12, 13], [13, 14]],
+ dtypes.int64)
+ values = constant_op.constant([0, 1, 2, 3], dtypes.int64)
table = lookup.MutableDenseHashTable(
dtypes.int64,
dtypes.int64,
default_value=default_value,
empty_key=empty_key,
+ deleted_key=deleted_key,
name="t2",
checkpoint=True,
initial_num_buckets=32)
@@ -1200,6 +1354,11 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllEqual(0, table.size().eval())
table.insert(keys, values).run()
+ self.assertAllEqual(4, table.size().eval())
+ self.assertAllEqual(32, len(table.export()[0].eval()))
+
+ keys2 = constant_op.constant([[12, 13], [15, 16]], dtypes.int64)
+ table.remove(keys2).run()
self.assertAllEqual(3, table.size().eval())
self.assertAllEqual(32, len(table.export()[0].eval()))
@@ -1209,12 +1368,14 @@ class MutableDenseHashTableOpTest(test.TestCase):
with self.session(graph=ops.Graph()) as sess:
empty_key = constant_op.constant([11, 13], dtypes.int64)
+ deleted_key = constant_op.constant([-1, -1], dtypes.int64)
default_value = constant_op.constant(-1, dtypes.int64)
table = lookup.MutableDenseHashTable(
dtypes.int64,
dtypes.int64,
default_value=default_value,
empty_key=empty_key,
+ deleted_key=deleted_key,
name="t2",
checkpoint=True,
initial_num_buckets=64)
@@ -1235,7 +1396,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
input_string = constant_op.constant(
[[11, 12], [11, 14], [11, 15], [13, 14], [13, 15]], dtypes.int64)
output = table.lookup(input_string)
- self.assertAllEqual([0, 1, -1, 2, -1], output.eval())
+ self.assertAllEqual([0, 1, -1, 3, -1], output.eval())
def testReprobe(self):
with self.cached_session():
@@ -1248,6 +1409,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
dtypes.int64,
default_value=-1,
empty_key=0,
+ deleted_key=-1,
initial_num_buckets=8)
self.assertAllEqual(0, table.size().eval())
@@ -1267,7 +1429,11 @@ class MutableDenseHashTableOpTest(test.TestCase):
keys = constant_op.constant([11, 0, 13], dtypes.int64)
values = constant_op.constant([0, 1, 2], dtypes.int64)
table = lookup.MutableDenseHashTable(
- dtypes.int64, dtypes.int64, default_value=-1, empty_key=12)
+ dtypes.int64,
+ dtypes.int64,
+ default_value=-1,
+ empty_key=12,
+ deleted_key=-1)
self.assertAllEqual(0, table.size().eval())
table.insert(keys, values).run()
@@ -1283,19 +1449,35 @@ class MutableDenseHashTableOpTest(test.TestCase):
def testErrors(self):
with self.cached_session():
table = lookup.MutableDenseHashTable(
- dtypes.int64, dtypes.int64, default_value=-1, empty_key=0)
+ dtypes.int64,
+ dtypes.int64,
+ default_value=-1,
+ empty_key=0,
+ deleted_key=-1)
# Inserting the empty key returns an error
- keys = constant_op.constant([11, 0], dtypes.int64)
- values = constant_op.constant([0, 1], dtypes.int64)
+ keys1 = constant_op.constant([11, 0], dtypes.int64)
+ values1 = constant_op.constant([0, 1], dtypes.int64)
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"empty_key"):
- table.insert(keys, values).run()
+ table.insert(keys1, values1).run()
# Looking up the empty key returns an error
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"empty_key"):
- table.lookup(keys).eval()
+ table.lookup(keys1).eval()
+
+ # Inserting the deleted key returns an error
+ keys2 = constant_op.constant([11, -1], dtypes.int64)
+ values2 = constant_op.constant([0, 1], dtypes.int64)
+ with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+ "deleted_key"):
+ table.insert(keys2, values2).run()
+
+ # Looking up the empty key returns an error
+ with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+ "deleted_key"):
+ table.lookup(keys2).eval()
# Arbitrary tensors of keys are not supported
keys = constant_op.constant([[11, 0], [12, 1]], dtypes.int64)
@@ -1312,11 +1494,43 @@ class MutableDenseHashTableOpTest(test.TestCase):
dtypes.int64,
default_value=-1,
empty_key=17,
+ deleted_key=-1,
initial_num_buckets=12)
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"Number of buckets must be"):
self.assertAllEqual(0, table2.size().eval())
+ with self.assertRaisesRegexp(
+ errors_impl.InvalidArgumentError,
+ "Empty and deleted keys must have same shape"):
+ table3 = lookup.MutableDenseHashTable(
+ dtypes.int64,
+ dtypes.int64,
+ default_value=-1,
+ empty_key=42,
+ deleted_key=[1, 2])
+ self.assertAllEqual(0, table3.size().eval())
+
+ with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+ "Empty and deleted keys cannot be equal"):
+ table4 = lookup.MutableDenseHashTable(
+ dtypes.int64,
+ dtypes.int64,
+ default_value=-1,
+ empty_key=42,
+ deleted_key=42)
+ self.assertAllEqual(0, table4.size().eval())
+
+ with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+ "Empty and deleted keys cannot be equal"):
+ table5 = lookup.MutableDenseHashTable(
+ dtypes.int64,
+ dtypes.int64,
+ default_value=-1,
+ empty_key=[1, 2, 3],
+ deleted_key=[1, 2, 3])
+ self.assertAllEqual(0, table5.size().eval())
+
class IndexTableFromFile(test.TestCase):
@@ -2558,7 +2772,11 @@ class MutableDenseHashTableBenchmark(MutableHashTableBenchmark):
def _create_table(self):
return lookup.MutableDenseHashTable(
- dtypes.int64, dtypes.float32, default_value=0.0, empty_key=-1)
+ dtypes.int64,
+ dtypes.float32,
+ default_value=0.0,
+ empty_key=-1,
+ deleted_key=-2)
if __name__ == "__main__":