diff options
Diffstat (limited to 'tensorflow/contrib/lookup/lookup_ops_test.py')
-rw-r--r-- | tensorflow/contrib/lookup/lookup_ops_test.py | 336 |
1 files changed, 277 insertions, 59 deletions
diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py index 9e9345e875..35b0d1bc44 100644 --- a/tensorflow/contrib/lookup/lookup_ops_test.py +++ b/tensorflow/contrib/lookup/lookup_ops_test.py @@ -303,13 +303,17 @@ class MutableHashTableOpTest(test.TestCase): def testMutableHashTable(self): with self.cached_session(): default_val = -1 - keys = constant_op.constant(["brain", "salad", "surgery"]) - values = constant_op.constant([0, 1, 2], dtypes.int64) + keys = constant_op.constant(["brain", "salad", "surgery", "tarkus"]) + values = constant_op.constant([0, 1, 2, 3], dtypes.int64) table = lookup.MutableHashTable(dtypes.string, dtypes.int64, default_val) self.assertAllEqual(0, table.size().eval()) table.insert(keys, values).run() + self.assertAllEqual(4, table.size().eval()) + + remove_string = constant_op.constant(["tarkus", "tank"]) + table.remove(remove_string).run() self.assertAllEqual(3, table.size().eval()) input_string = constant_op.constant(["brain", "salad", "tank"]) @@ -472,13 +476,18 @@ class MutableHashTableOpTest(test.TestCase): def testMutableHashTableOfTensors(self): with self.cached_session(): default_val = constant_op.constant([-1, -1], dtypes.int64) - keys = constant_op.constant(["brain", "salad", "surgery"]) - values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64) + keys = constant_op.constant(["brain", "salad", "surgery", "tarkus"]) + values = constant_op.constant([[0, 1], [2, 3], [4, 5], [6, 7]], + dtypes.int64) table = lookup.MutableHashTable(dtypes.string, dtypes.int64, default_val) self.assertAllEqual(0, table.size().eval()) table.insert(keys, values).run() + self.assertAllEqual(4, table.size().eval()) + + remove_string = constant_op.constant(["tarkus", "tank"]) + table.remove(remove_string).run() self.assertAllEqual(3, table.size().eval()) input_string = constant_op.constant(["brain", "salad", "tank"]) @@ -624,6 +633,26 @@ class MutableHashTableOpTest(test.TestCase): result = output.eval() self.assertAllEqual([0, 1, 3, -1], result) + def testMutableHashTableRemoveHighRank(self): + with self.test_session(): + default_val = -1 + keys = constant_op.constant([["brain", "salad"], ["surgery", "tank"]]) + values = constant_op.constant([[0, 1], [2, 3]], dtypes.int64) + table = lookup.MutableHashTable(dtypes.string, dtypes.int64, default_val) + + table.insert(keys, values).run() + self.assertAllEqual(4, table.size().eval()) + + remove_string = constant_op.constant(["salad", "tarkus"]) + table.remove(remove_string).run() + self.assertAllEqual(3, table.size().eval()) + + input_string = constant_op.constant(["brain", "salad", "tank", "tarkus"]) + output = table.lookup(input_string) + + result = output.eval() + self.assertAllEqual([0, -1, 3, -1], result) + def testMutableHashTableOfTensorsFindHighRank(self): with self.cached_session(): default_val = constant_op.constant([-1, -1, -1], dtypes.int64) @@ -645,6 +674,30 @@ class MutableHashTableOpTest(test.TestCase): self.assertAllEqual( [[[0, 1, 2], [2, 3, 4]], [[-1, -1, -1], [-1, -1, -1]]], result) + def testMutableHashTableOfTensorsRemoveHighRank(self): + with self.test_session(): + default_val = constant_op.constant([-1, -1, -1], dtypes.int64) + keys = constant_op.constant(["brain", "salad", "surgery"]) + values = constant_op.constant([[0, 1, 2], [2, 3, 4], [4, 5, 6]], + dtypes.int64) + table = lookup.MutableHashTable(dtypes.string, dtypes.int64, default_val) + + table.insert(keys, values).run() + self.assertAllEqual(3, table.size().eval()) + + remove_string = constant_op.constant([["brain", "tank"]]) + table.remove(remove_string).run() + self.assertAllEqual(2, table.size().eval()) + + input_string = constant_op.constant([["brain", "salad"], + ["surgery", "tank"]]) + output = table.lookup(input_string) + self.assertAllEqual([2, 2, 3], output.get_shape()) + + result = output.eval() + self.assertAllEqual( + [[[-1, -1, -1], [2, 3, 4]], [[4, 5, 6], [-1, -1, -1]]], result) + def testMultipleMutableHashTables(self): with self.cached_session() as sess: default_val = -1 @@ -792,13 +845,22 @@ class MutableDenseHashTableOpTest(test.TestCase): def testBasic(self): with self.cached_session(): - keys = constant_op.constant([11, 12, 13], dtypes.int64) - values = constant_op.constant([0, 1, 2], dtypes.int64) + + keys = constant_op.constant([11, 12, 13, 14], dtypes.int64) + values = constant_op.constant([0, 1, 2, 3], dtypes.int64) table = lookup.MutableDenseHashTable( - dtypes.int64, dtypes.int64, default_value=-1, empty_key=0) + dtypes.int64, + dtypes.int64, + default_value=-1, + empty_key=0, + deleted_key=-1) self.assertAllEqual(0, table.size().eval()) table.insert(keys, values).run() + self.assertAllEqual(4, table.size().eval()) + + remove_string = constant_op.constant([12, 15], dtypes.int64) + table.remove(remove_string).run() self.assertAllEqual(3, table.size().eval()) input_string = constant_op.constant([11, 12, 15], dtypes.int64) @@ -806,17 +868,26 @@ class MutableDenseHashTableOpTest(test.TestCase): self.assertAllEqual([3], output.get_shape()) result = output.eval() - self.assertAllEqual([0, 1, -1], result) + self.assertAllEqual([0, -1, -1], result) def testBasicBool(self): with self.cached_session(): - keys = constant_op.constant([11, 12, 13], dtypes.int64) - values = constant_op.constant([True, True, True], dtypes.bool) + + keys = constant_op.constant([11, 12, 13, 14], dtypes.int64) + values = constant_op.constant([True, True, True, True], dtypes.bool) table = lookup.MutableDenseHashTable( - dtypes.int64, dtypes.bool, default_value=False, empty_key=0) + dtypes.int64, + dtypes.bool, + default_value=False, + empty_key=0, + deleted_key=-1) self.assertAllEqual(0, table.size().eval()) table.insert(keys, values).run() + self.assertAllEqual(4, table.size().eval()) + + remove_string = constant_op.constant([11, 15], dtypes.int64) + table.remove(remove_string).run() self.assertAllEqual(3, table.size().eval()) input_string = constant_op.constant([11, 12, 15], dtypes.int64) @@ -824,14 +895,30 @@ class MutableDenseHashTableOpTest(test.TestCase): self.assertAllEqual([3], output.get_shape()) result = output.eval() - self.assertAllEqual([True, True, False], result) + self.assertAllEqual([False, True, False], result) + + def testSameEmptyAndDeletedKey(self): + with self.cached_session(): + with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, + "deleted_key"): + table = lookup.MutableDenseHashTable( + dtypes.int64, + dtypes.int64, + default_value=-1, + empty_key=42, + deleted_key=42) + self.assertAllEqual(0, table.size().eval()) def testLookupUnknownShape(self): with self.cached_session(): keys = constant_op.constant([11, 12, 13], dtypes.int64) values = constant_op.constant([0, 1, 2], dtypes.int64) table = lookup.MutableDenseHashTable( - dtypes.int64, dtypes.int64, default_value=-1, empty_key=0) + dtypes.int64, + dtypes.int64, + default_value=-1, + empty_key=0, + deleted_key=-1) table.insert(keys, values).run() self.assertAllEqual(3, table.size().eval()) @@ -844,45 +931,60 @@ class MutableDenseHashTableOpTest(test.TestCase): def testMapStringToFloat(self): with self.cached_session(): - keys = constant_op.constant(["a", "b", "c"], dtypes.string) - values = constant_op.constant([0.0, 1.1, 2.2], dtypes.float32) + + keys = constant_op.constant(["a", "b", "c", "d"], dtypes.string) + values = constant_op.constant([0.0, 1.1, 2.2, 3.3], dtypes.float32) default_value = constant_op.constant(-1.5, dtypes.float32) table = lookup.MutableDenseHashTable( dtypes.string, dtypes.float32, default_value=default_value, - empty_key="") + empty_key="", + deleted_key="$") self.assertAllEqual(0, table.size().eval()) table.insert(keys, values).run() + self.assertAllEqual(4, table.size().eval()) + + remove_string = constant_op.constant(["b", "e"]) + table.remove(remove_string).run() self.assertAllEqual(3, table.size().eval()) - input_string = constant_op.constant(["a", "b", "d"], dtypes.string) + input_string = constant_op.constant(["a", "b", "d", "e"], dtypes.string) output = table.lookup(input_string) - self.assertAllEqual([3], output.get_shape()) + self.assertAllEqual([4], output.get_shape()) result = output.eval() - self.assertAllClose([0, 1.1, -1.5], result) + self.assertAllClose([0, -1.5, 3.3, -1.5], result) def testMapInt64ToFloat(self): for float_dtype in [dtypes.float32, dtypes.float64]: with self.cached_session(): - keys = constant_op.constant([11, 12, 13], dtypes.int64) - values = constant_op.constant([0.0, 1.1, 2.2], float_dtype) + + keys = constant_op.constant([11, 12, 13, 14], dtypes.int64) + values = constant_op.constant([0.0, 1.1, 2.2, 3.3], float_dtype) default_value = constant_op.constant(-1.5, float_dtype) table = lookup.MutableDenseHashTable( - dtypes.int64, float_dtype, default_value=default_value, empty_key=0) + dtypes.int64, + float_dtype, + default_value=default_value, + empty_key=0, + deleted_key=-1) self.assertAllEqual(0, table.size().eval()) table.insert(keys, values).run() + self.assertAllEqual(4, table.size().eval()) + + remove_string = constant_op.constant([12, 15], dtypes.int64) + table.remove(remove_string).run() self.assertAllEqual(3, table.size().eval()) - input_string = constant_op.constant([11, 12, 15], dtypes.int64) + input_string = constant_op.constant([11, 12, 14, 15], dtypes.int64) output = table.lookup(input_string) - self.assertAllEqual([3], output.get_shape()) + self.assertAllEqual([4], output.get_shape()) result = output.eval() - self.assertAllClose([0, 1.1, -1.5], result) + self.assertAllClose([0, -1.5, 3.3, -1.5], result) def testVectorValues(self): with self.cached_session(): @@ -895,6 +997,7 @@ class MutableDenseHashTableOpTest(test.TestCase): dtypes.int64, default_value=default_value, empty_key=0, + deleted_key=-1, initial_num_buckets=4) self.assertAllEqual(0, table.size().eval()) @@ -908,26 +1011,35 @@ class MutableDenseHashTableOpTest(test.TestCase): self.assertAllEqual(4, table.size().eval()) self.assertAllEqual(8, len(table.export()[0].eval())) - input_string = constant_op.constant([11, 12, 15], dtypes.int64) + remove_string = constant_op.constant([12, 16], dtypes.int64) + table.remove(remove_string).run() + self.assertAllEqual(3, table.size().eval()) + self.assertAllEqual(8, len(table.export()[0].eval())) + + input_string = constant_op.constant([11, 12, 14, 15], dtypes.int64) output = table.lookup(input_string) - self.assertAllEqual( - [3, 4], output.shape, msg="Saw shape: %s" % output.shape) + self.assertAllEqual([4, 4], + output.shape, + msg="Saw shape: %s" % output.shape) result = output.eval() - self.assertAllEqual([[0, 1, 2, 3], [3, 4, 5, 6], [-1, -2, -3, -4]], - result) + self.assertAllEqual( + [[0, 1, 2, 3], [-1, -2, -3, -4], [2, 3, 4, 5], [-1, -2, -3, -4]], + result) def testVectorKeys(self): with self.cached_session(): keys = constant_op.constant([[0, 1], [1, 2], [1, 3]], dtypes.int64) values = constant_op.constant([10, 11, 12], dtypes.int64) empty_key = constant_op.constant([0, 3], dtypes.int64) + deleted_key = constant_op.constant([-1, -1], dtypes.int64) default_value = constant_op.constant(-1, dtypes.int64) table = lookup.MutableDenseHashTable( dtypes.int64, dtypes.int64, default_value=default_value, empty_key=empty_key, + deleted_key=deleted_key, initial_num_buckets=8) self.assertAllEqual(0, table.size().eval()) @@ -940,13 +1052,18 @@ class MutableDenseHashTableOpTest(test.TestCase): self.assertAllEqual(4, table.size().eval()) self.assertAllEqual(8, len(table.export()[0].eval())) - input_string = constant_op.constant([[0, 1], [1, 2], [0, 2]], + remove_string = constant_op.constant([[1, 2], [7, 8]], dtypes.int64) + table.remove(remove_string).run() + self.assertAllEqual(3, table.size().eval()) + self.assertAllEqual(8, len(table.export()[0].eval())) + + input_string = constant_op.constant([[0, 1], [1, 2], [1, 3], [0, 2]], dtypes.int64) output = table.lookup(input_string) - self.assertAllEqual([3], output.get_shape()) + self.assertAllEqual([4], output.get_shape()) result = output.eval() - self.assertAllEqual([10, 11, -1], result) + self.assertAllEqual([10, -1, 12, -1], result) def testResize(self): with self.cached_session(): @@ -957,6 +1074,7 @@ class MutableDenseHashTableOpTest(test.TestCase): dtypes.int64, default_value=-1, empty_key=0, + deleted_key=-1, initial_num_buckets=4) self.assertAllEqual(0, table.size().eval()) @@ -964,31 +1082,42 @@ class MutableDenseHashTableOpTest(test.TestCase): self.assertAllEqual(3, table.size().eval()) self.assertAllEqual(4, len(table.export()[0].eval())) - keys2 = constant_op.constant([13, 14, 15, 16, 17], dtypes.int64) - values2 = constant_op.constant([3, 4, 5, 6, 7], dtypes.int64) + keys2 = constant_op.constant([12, 99], dtypes.int64) + table.remove(keys2).run() + self.assertAllEqual(2, table.size().eval()) + self.assertAllEqual(4, len(table.export()[0].eval())) + + keys3 = constant_op.constant([13, 14, 15, 16, 17], dtypes.int64) + values3 = constant_op.constant([3, 4, 5, 6, 7], dtypes.int64) - table.insert(keys2, values2).run() - self.assertAllEqual(7, table.size().eval()) + table.insert(keys3, values3).run() + self.assertAllEqual(6, table.size().eval()) self.assertAllEqual(16, len(table.export()[0].eval())) - keys3 = constant_op.constant([10, 11, 12, 13, 14, 15, 16, 17, 18], + keys4 = constant_op.constant([10, 11, 12, 13, 14, 15, 16, 17, 18], dtypes.int64) - output = table.lookup(keys3) - self.assertAllEqual([-1, 0, 1, 3, 4, 5, 6, 7, -1], output.eval()) + output = table.lookup(keys4) + self.assertAllEqual([-1, 0, -1, 3, 4, 5, 6, 7, -1], output.eval()) def testExport(self): with self.cached_session(): - keys = constant_op.constant([11, 12, 13], dtypes.int64) - values = constant_op.constant([1, 2, 3], dtypes.int64) + + keys = constant_op.constant([11, 12, 13, 14], dtypes.int64) + values = constant_op.constant([1, 2, 3, 4], dtypes.int64) table = lookup.MutableDenseHashTable( dtypes.int64, dtypes.int64, default_value=-1, empty_key=100, + deleted_key=200, initial_num_buckets=8) self.assertAllEqual(0, table.size().eval()) table.insert(keys, values).run() + self.assertAllEqual(4, table.size().eval()) + + keys2 = constant_op.constant([12, 15], dtypes.int64) + table.remove(keys2).run() self.assertAllEqual(3, table.size().eval()) exported_keys, exported_values = table.export() @@ -1005,8 +1134,8 @@ class MutableDenseHashTableOpTest(test.TestCase): pairs = np.dstack((np_keys.flatten(), np_values.flatten()))[0] # sort by key pairs = pairs[pairs[:, 0].argsort()] - self.assertAllEqual([[11, 1], [12, 2], [13, 3], [100, 0], [100, 0], - [100, 0], [100, 0], [100, 0]], pairs) + self.assertAllEqual([[11, 1], [13, 3], [14, 4], [100, 0], [100, 0], + [100, 0], [100, 0], [200, 2]], pairs) def testSaveRestore(self): save_dir = os.path.join(self.get_temp_dir(), "save_restore") @@ -1015,13 +1144,15 @@ class MutableDenseHashTableOpTest(test.TestCase): with self.session(graph=ops.Graph()) as sess: default_value = -1 empty_key = 0 - keys = constant_op.constant([11, 12, 13], dtypes.int64) - values = constant_op.constant([0, 1, 2], dtypes.int64) + deleted_key = -1 + keys = constant_op.constant([11, 12, 13, 14], dtypes.int64) + values = constant_op.constant([0, 1, 2, 3], dtypes.int64) table = lookup.MutableDenseHashTable( dtypes.int64, dtypes.int64, default_value=default_value, empty_key=empty_key, + deleted_key=deleted_key, name="t1", checkpoint=True, initial_num_buckets=32) @@ -1030,6 +1161,11 @@ class MutableDenseHashTableOpTest(test.TestCase): self.assertAllEqual(0, table.size().eval()) table.insert(keys, values).run() + self.assertAllEqual(4, table.size().eval()) + self.assertAllEqual(32, len(table.export()[0].eval())) + + keys2 = constant_op.constant([12, 15], dtypes.int64) + table.remove(keys2).run() self.assertAllEqual(3, table.size().eval()) self.assertAllEqual(32, len(table.export()[0].eval())) @@ -1043,6 +1179,7 @@ class MutableDenseHashTableOpTest(test.TestCase): dtypes.int64, default_value=default_value, empty_key=empty_key, + deleted_key=deleted_key, name="t1", checkpoint=True, initial_num_buckets=64) @@ -1062,7 +1199,7 @@ class MutableDenseHashTableOpTest(test.TestCase): input_string = constant_op.constant([10, 11, 12, 13, 14], dtypes.int64) output = table.lookup(input_string) - self.assertAllEqual([-1, 0, 1, 2, -1], output.eval()) + self.assertAllEqual([-1, 0, -1, 2, 3], output.eval()) @test_util.run_in_graph_and_eager_modes def testObjectSaveRestore(self): @@ -1071,6 +1208,7 @@ class MutableDenseHashTableOpTest(test.TestCase): default_value = -1 empty_key = 0 + deleted_key = -1 keys = constant_op.constant([11, 12, 13], dtypes.int64) values = constant_op.constant([0, 1, 2], dtypes.int64) save_table = lookup.MutableDenseHashTable( @@ -1078,6 +1216,7 @@ class MutableDenseHashTableOpTest(test.TestCase): dtypes.int64, default_value=default_value, empty_key=empty_key, + deleted_key=deleted_key, name="t1", checkpoint=True, initial_num_buckets=32) @@ -1097,6 +1236,7 @@ class MutableDenseHashTableOpTest(test.TestCase): dtypes.int64, default_value=default_value, empty_key=empty_key, + deleted_key=deleted_key, name="t1", checkpoint=True, initial_num_buckets=64) @@ -1124,14 +1264,18 @@ class MutableDenseHashTableOpTest(test.TestCase): with self.session(graph=ops.Graph()) as sess: empty_key = constant_op.constant([11, 13], dtypes.int64) + deleted_key = constant_op.constant([-2, -3], dtypes.int64) default_value = constant_op.constant([-1, -2], dtypes.int64) - keys = constant_op.constant([[11, 12], [11, 14], [13, 14]], dtypes.int64) - values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64) + keys = constant_op.constant([[11, 12], [11, 14], [12, 13], [13, 14]], + dtypes.int64) + values = constant_op.constant([[0, 1], [2, 3], [2, 4], [4, 5]], + dtypes.int64) table = lookup.MutableDenseHashTable( dtypes.int64, dtypes.int64, default_value=default_value, empty_key=empty_key, + deleted_key=deleted_key, name="t1", checkpoint=True, initial_num_buckets=32) @@ -1140,6 +1284,11 @@ class MutableDenseHashTableOpTest(test.TestCase): self.assertAllEqual(0, table.size().eval()) table.insert(keys, values).run() + self.assertAllEqual(4, table.size().eval()) + self.assertAllEqual(32, len(table.export()[0].eval())) + + keys2 = constant_op.constant([[12, 13], [16, 17]], dtypes.int64) + table.remove(keys2).run() self.assertAllEqual(3, table.size().eval()) self.assertAllEqual(32, len(table.export()[0].eval())) @@ -1149,12 +1298,14 @@ class MutableDenseHashTableOpTest(test.TestCase): with self.session(graph=ops.Graph()) as sess: empty_key = constant_op.constant([11, 13], dtypes.int64) + deleted_key = constant_op.constant([-2, -3], dtypes.int64) default_value = constant_op.constant([-1, -2], dtypes.int64) table = lookup.MutableDenseHashTable( dtypes.int64, dtypes.int64, default_value=default_value, empty_key=empty_key, + deleted_key=deleted_key, name="t1", checkpoint=True, initial_num_buckets=64) @@ -1184,14 +1335,17 @@ class MutableDenseHashTableOpTest(test.TestCase): with self.session(graph=ops.Graph()) as sess: empty_key = constant_op.constant([11, 13], dtypes.int64) + deleted_key = constant_op.constant([-1, -1], dtypes.int64) default_value = constant_op.constant(-1, dtypes.int64) - keys = constant_op.constant([[11, 12], [11, 14], [13, 14]], dtypes.int64) - values = constant_op.constant([0, 1, 2], dtypes.int64) + keys = constant_op.constant([[11, 12], [11, 14], [12, 13], [13, 14]], + dtypes.int64) + values = constant_op.constant([0, 1, 2, 3], dtypes.int64) table = lookup.MutableDenseHashTable( dtypes.int64, dtypes.int64, default_value=default_value, empty_key=empty_key, + deleted_key=deleted_key, name="t2", checkpoint=True, initial_num_buckets=32) @@ -1200,6 +1354,11 @@ class MutableDenseHashTableOpTest(test.TestCase): self.assertAllEqual(0, table.size().eval()) table.insert(keys, values).run() + self.assertAllEqual(4, table.size().eval()) + self.assertAllEqual(32, len(table.export()[0].eval())) + + keys2 = constant_op.constant([[12, 13], [15, 16]], dtypes.int64) + table.remove(keys2).run() self.assertAllEqual(3, table.size().eval()) self.assertAllEqual(32, len(table.export()[0].eval())) @@ -1209,12 +1368,14 @@ class MutableDenseHashTableOpTest(test.TestCase): with self.session(graph=ops.Graph()) as sess: empty_key = constant_op.constant([11, 13], dtypes.int64) + deleted_key = constant_op.constant([-1, -1], dtypes.int64) default_value = constant_op.constant(-1, dtypes.int64) table = lookup.MutableDenseHashTable( dtypes.int64, dtypes.int64, default_value=default_value, empty_key=empty_key, + deleted_key=deleted_key, name="t2", checkpoint=True, initial_num_buckets=64) @@ -1235,7 +1396,7 @@ class MutableDenseHashTableOpTest(test.TestCase): input_string = constant_op.constant( [[11, 12], [11, 14], [11, 15], [13, 14], [13, 15]], dtypes.int64) output = table.lookup(input_string) - self.assertAllEqual([0, 1, -1, 2, -1], output.eval()) + self.assertAllEqual([0, 1, -1, 3, -1], output.eval()) def testReprobe(self): with self.cached_session(): @@ -1248,6 +1409,7 @@ class MutableDenseHashTableOpTest(test.TestCase): dtypes.int64, default_value=-1, empty_key=0, + deleted_key=-1, initial_num_buckets=8) self.assertAllEqual(0, table.size().eval()) @@ -1267,7 +1429,11 @@ class MutableDenseHashTableOpTest(test.TestCase): keys = constant_op.constant([11, 0, 13], dtypes.int64) values = constant_op.constant([0, 1, 2], dtypes.int64) table = lookup.MutableDenseHashTable( - dtypes.int64, dtypes.int64, default_value=-1, empty_key=12) + dtypes.int64, + dtypes.int64, + default_value=-1, + empty_key=12, + deleted_key=-1) self.assertAllEqual(0, table.size().eval()) table.insert(keys, values).run() @@ -1283,19 +1449,35 @@ class MutableDenseHashTableOpTest(test.TestCase): def testErrors(self): with self.cached_session(): table = lookup.MutableDenseHashTable( - dtypes.int64, dtypes.int64, default_value=-1, empty_key=0) + dtypes.int64, + dtypes.int64, + default_value=-1, + empty_key=0, + deleted_key=-1) # Inserting the empty key returns an error - keys = constant_op.constant([11, 0], dtypes.int64) - values = constant_op.constant([0, 1], dtypes.int64) + keys1 = constant_op.constant([11, 0], dtypes.int64) + values1 = constant_op.constant([0, 1], dtypes.int64) with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, "empty_key"): - table.insert(keys, values).run() + table.insert(keys1, values1).run() # Looking up the empty key returns an error with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, "empty_key"): - table.lookup(keys).eval() + table.lookup(keys1).eval() + + # Inserting the deleted key returns an error + keys2 = constant_op.constant([11, -1], dtypes.int64) + values2 = constant_op.constant([0, 1], dtypes.int64) + with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, + "deleted_key"): + table.insert(keys2, values2).run() + + # Looking up the empty key returns an error + with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, + "deleted_key"): + table.lookup(keys2).eval() # Arbitrary tensors of keys are not supported keys = constant_op.constant([[11, 0], [12, 1]], dtypes.int64) @@ -1312,11 +1494,43 @@ class MutableDenseHashTableOpTest(test.TestCase): dtypes.int64, default_value=-1, empty_key=17, + deleted_key=-1, initial_num_buckets=12) with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, "Number of buckets must be"): self.assertAllEqual(0, table2.size().eval()) + with self.assertRaisesRegexp( + errors_impl.InvalidArgumentError, + "Empty and deleted keys must have same shape"): + table3 = lookup.MutableDenseHashTable( + dtypes.int64, + dtypes.int64, + default_value=-1, + empty_key=42, + deleted_key=[1, 2]) + self.assertAllEqual(0, table3.size().eval()) + + with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, + "Empty and deleted keys cannot be equal"): + table4 = lookup.MutableDenseHashTable( + dtypes.int64, + dtypes.int64, + default_value=-1, + empty_key=42, + deleted_key=42) + self.assertAllEqual(0, table4.size().eval()) + + with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, + "Empty and deleted keys cannot be equal"): + table5 = lookup.MutableDenseHashTable( + dtypes.int64, + dtypes.int64, + default_value=-1, + empty_key=[1, 2, 3], + deleted_key=[1, 2, 3]) + self.assertAllEqual(0, table5.size().eval()) + class IndexTableFromFile(test.TestCase): @@ -2558,7 +2772,11 @@ class MutableDenseHashTableBenchmark(MutableHashTableBenchmark): def _create_table(self): return lookup.MutableDenseHashTable( - dtypes.int64, dtypes.float32, default_value=0.0, empty_key=-1) + dtypes.int64, + dtypes.float32, + default_value=0.0, + empty_key=-1, + deleted_key=-2) if __name__ == "__main__": |