From 6c391166b8b6ba43d2b0151e6fb9cf14864131a2 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 9 Oct 2018 16:23:35 -0700 Subject: Add 'remove' operation to MutableHashTable and MutableDenseHashTable. PiperOrigin-RevId: 216443201 --- .../linear_optimizer/python/ops/sdca_ops.py | 3 +- .../python/ops/sharded_mutable_dense_hashtable.py | 2 + .../ops/sharded_mutable_dense_hashtable_test.py | 6 + tensorflow/contrib/lookup/lookup_ops.py | 81 ++++- tensorflow/contrib/lookup/lookup_ops_test.py | 336 +++++++++++++++++---- .../timeseries/python/timeseries/math_utils.py | 19 +- .../python/timeseries/math_utils_test.py | 8 +- .../python/timeseries/state_management.py | 1 + .../base_api/api_def_LookupTableRemoveV2.pbtxt | 24 ++ tensorflow/core/framework/lookup_interface.cc | 8 + tensorflow/core/framework/lookup_interface.h | 17 ++ .../core/kernels/initializable_lookup_table.h | 6 + tensorflow/core/kernels/lookup_table_op.cc | 184 ++++++++++- tensorflow/core/ops/compat/ops_history.v1.pbtxt | 20 ++ tensorflow/core/ops/lookup_ops.cc | 14 + 15 files changed, 643 insertions(+), 86 deletions(-) create mode 100644 tensorflow/core/api_def/base_api/api_def_LookupTableRemoveV2.pbtxt (limited to 'tensorflow') diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py index 48ac429701..b5099a0bf6 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py @@ -152,7 +152,8 @@ class SdcaModel(object): default_value=[0.0, 0.0, 0.0, 0.0], # SdcaFprint never returns 0 or 1 for the low64 bits, so this a safe # empty_key (that will never collide with actual payloads). - empty_key=[0, 0]) + empty_key=[0, 0], + deleted_key=[1, 1]) summary.scalar('approximate_duality_gap', self.approximate_duality_gap()) summary.scalar('examples_seen', self._hashtable.size()) diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py index 5015fb0848..44a869f7c2 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py @@ -48,6 +48,7 @@ class ShardedMutableDenseHashTable(lookup.LookupInterface): value_dtype, default_value, empty_key, + deleted_key, num_shards=1, checkpoint=True, name='ShardedMutableHashTable'): @@ -62,6 +63,7 @@ class ShardedMutableDenseHashTable(lookup.LookupInterface): value_dtype=value_dtype, default_value=default_value, empty_key=empty_key, + deleted_key=deleted_key, checkpoint=checkpoint, name='%s-%d-of-%d' % (name, i + 1, num_shards))) self._table_shards = table_shards diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py index 553b116a3b..2b56d0fa3a 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py @@ -33,6 +33,7 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase): with self.cached_session(): default_val = -1 empty_key = 0 + deleted_key = -1 keys = constant_op.constant([11, 12, 13], dtypes.int64) values = constant_op.constant([0, 1, 2], dtypes.int64) table = ShardedMutableDenseHashTable( @@ -40,6 +41,7 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase): dtypes.int64, default_val, empty_key, + deleted_key, num_shards=num_shards) self.assertAllEqual(0, table.size().eval()) @@ -56,6 +58,7 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase): with self.cached_session(): default_val = [-0.1, 0.2] empty_key = [0, 1] + deleted_key = [1, 0] keys = constant_op.constant([[11, 12], [13, 14], [15, 16]], dtypes.int64) values = constant_op.constant([[0.5, 0.6], [1.5, 1.6], [2.5, 2.6]], @@ -65,6 +68,7 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase): dtypes.float32, default_val, empty_key, + deleted_key, num_shards=num_shards) self.assertAllEqual(0, table.size().eval()) @@ -81,6 +85,7 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase): def testExportSharded(self): with self.cached_session(): empty_key = -2 + deleted_key = -3 default_val = -1 num_shards = 2 keys = constant_op.constant([10, 11, 12], dtypes.int64) @@ -90,6 +95,7 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase): dtypes.int64, default_val, empty_key, + deleted_key, num_shards=num_shards) self.assertAllEqual(0, table.size().eval()) diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py index f83765a48d..5abef822e8 100644 --- a/tensorflow/contrib/lookup/lookup_ops.py +++ b/tensorflow/contrib/lookup/lookup_ops.py @@ -292,8 +292,8 @@ def index_to_string(tensor, mapping, default_value="UNK", name=None): class MutableHashTable(LookupInterface, checkpointable.CheckpointableBase): """A generic mutable hash table implementation. - Data can be inserted by calling the insert method. It does not support - initialization via the init method. + Data can be inserted by calling the insert method and removed by calling the + remove method. It does not support initialization via the init method. Example usage: @@ -391,6 +391,34 @@ class MutableHashTable(LookupInterface, checkpointable.CheckpointableBase): with ops.colocate_with(self._table_ref): return gen_lookup_ops.lookup_table_size_v2(self._table_ref, name=name) + def remove(self, keys, name=None): + """Removes `keys` and its associated values from the table. + + If a key is not present in the table, it is silently ignored. + + Args: + keys: Keys to remove. Can be a tensor of any shape. Must match the table's + key type. + name: A name for the operation (optional). + + Returns: + The created Operation. + + Raises: + TypeError: when `keys` do not match the table data types. + """ + if keys.dtype != self._key_dtype: + raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." % + (self._key_dtype, keys.dtype)) + + with ops.name_scope(name, "%s_lookup_table_remove" % self._name, + (self._table_ref, keys, self._default_value)) as name: + # pylint: disable=protected-access + op = gen_lookup_ops.lookup_table_remove_v2( + self._table_ref, keys, name=name) + + return op + def lookup(self, keys, name=None): """Looks up `keys` in a table, outputs the corresponding values. @@ -487,11 +515,11 @@ class MutableHashTable(LookupInterface, checkpointable.CheckpointableBase): class MutableDenseHashTable(LookupInterface, checkpointable.CheckpointableBase): """A generic mutable hash table implementation using tensors as backing store. - Data can be inserted by calling the insert method. It does not support - initialization via the init method. + Data can be inserted by calling the insert method and removed by calling the + remove method. It does not support initialization via the init method. It uses "open addressing" with quadratic reprobing to resolve collisions. - Compared to `MutableHashTable` the insert and lookup operations in a + Compared to `MutableHashTable` the insert, remove and lookup operations in a `MutableDenseHashTable` are typically faster, but memory usage can be higher. However, `MutableDenseHashTable` does not require additional memory for temporary tensors created during checkpointing and restore operations. @@ -502,7 +530,9 @@ class MutableDenseHashTable(LookupInterface, checkpointable.CheckpointableBase): table = tf.contrib.lookup.MutableDenseHashTable(key_dtype=tf.int64, value_dtype=tf.int64, default_value=-1, - empty_key=0) + empty_key=0, + deleted_key=-1) + sess.run(table.insert(keys, values)) out = table.lookup(query_keys) print(out.eval()) @@ -516,6 +546,7 @@ class MutableDenseHashTable(LookupInterface, checkpointable.CheckpointableBase): value_dtype, default_value, empty_key, + deleted_key, initial_num_buckets=None, shared_name=None, name="MutableDenseHashTable", @@ -530,7 +561,7 @@ class MutableDenseHashTable(LookupInterface, checkpointable.CheckpointableBase): value_dtype: the type of the value tensors. default_value: The value to use if a key is missing in the table. empty_key: the key to use to represent empty buckets internally. Must not - be used in insert or lookup operations. + be used in insert, remove or lookup operations. initial_num_buckets: the initial number of buckets. shared_name: If non-empty, this table will be shared under the given name across multiple sessions. @@ -538,9 +569,12 @@ class MutableDenseHashTable(LookupInterface, checkpointable.CheckpointableBase): checkpoint: if True, the contents of the table are saved to and restored from checkpoints. If `shared_name` is empty for a checkpointed table, it is shared using the table node name. + deleted_key: the key to use to represent deleted buckets internally. Must + not be used in insert, remove or lookup operations and be different from + the empty_key. Returns: - A `MutableHashTable` object. + A `MutableDenseHashTable` object. Raises: ValueError: If checkpoint is True and no name was specified. @@ -555,6 +589,8 @@ class MutableDenseHashTable(LookupInterface, checkpointable.CheckpointableBase): use_node_name_sharing = checkpoint and shared_name is None empty_key = ops.convert_to_tensor( empty_key, dtype=key_dtype, name="empty_key") + deleted_key = ops.convert_to_tensor( + deleted_key, dtype=key_dtype, name="deleted_key") executing_eagerly = context.executing_eagerly() if executing_eagerly and shared_name is None: # TODO(allenl): This will leak memory due to kernel caching by the @@ -564,6 +600,7 @@ class MutableDenseHashTable(LookupInterface, checkpointable.CheckpointableBase): shared_name = "table_%d" % (ops.uid(),) self._table_ref = gen_lookup_ops.mutable_dense_hash_table_v2( empty_key=empty_key, + deleted_key=deleted_key, shared_name=shared_name, use_node_name_sharing=use_node_name_sharing, value_dtype=value_dtype, @@ -648,6 +685,34 @@ class MutableDenseHashTable(LookupInterface, checkpointable.CheckpointableBase): self._table_ref, keys, values, name=name) return op + def remove(self, keys, name=None): + """Removes `keys` and its associated values from the table. + + If a key is not present in the table, it is silently ignored. + + Args: + keys: Keys to remove. Can be a tensor of any shape. Must match the table's + key type. + name: A name for the operation (optional). + + Returns: + The created Operation. + + Raises: + TypeError: when `keys` do not match the table data types. + """ + if keys.dtype != self._key_dtype: + raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." % + (self._key_dtype, keys.dtype)) + + with ops.name_scope(name, "%s_lookup_table_remove" % self._name, + (self._table_ref, keys, self._default_value)) as name: + # pylint: disable=protected-access + op = gen_lookup_ops.lookup_table_remove_v2( + self._table_ref, keys, name=name) + + return op + def export(self, name=None): """Returns tensors of all keys and values in the table. diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py index 9e9345e875..35b0d1bc44 100644 --- a/tensorflow/contrib/lookup/lookup_ops_test.py +++ b/tensorflow/contrib/lookup/lookup_ops_test.py @@ -303,13 +303,17 @@ class MutableHashTableOpTest(test.TestCase): def testMutableHashTable(self): with self.cached_session(): default_val = -1 - keys = constant_op.constant(["brain", "salad", "surgery"]) - values = constant_op.constant([0, 1, 2], dtypes.int64) + keys = constant_op.constant(["brain", "salad", "surgery", "tarkus"]) + values = constant_op.constant([0, 1, 2, 3], dtypes.int64) table = lookup.MutableHashTable(dtypes.string, dtypes.int64, default_val) self.assertAllEqual(0, table.size().eval()) table.insert(keys, values).run() + self.assertAllEqual(4, table.size().eval()) + + remove_string = constant_op.constant(["tarkus", "tank"]) + table.remove(remove_string).run() self.assertAllEqual(3, table.size().eval()) input_string = constant_op.constant(["brain", "salad", "tank"]) @@ -472,13 +476,18 @@ class MutableHashTableOpTest(test.TestCase): def testMutableHashTableOfTensors(self): with self.cached_session(): default_val = constant_op.constant([-1, -1], dtypes.int64) - keys = constant_op.constant(["brain", "salad", "surgery"]) - values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64) + keys = constant_op.constant(["brain", "salad", "surgery", "tarkus"]) + values = constant_op.constant([[0, 1], [2, 3], [4, 5], [6, 7]], + dtypes.int64) table = lookup.MutableHashTable(dtypes.string, dtypes.int64, default_val) self.assertAllEqual(0, table.size().eval()) table.insert(keys, values).run() + self.assertAllEqual(4, table.size().eval()) + + remove_string = constant_op.constant(["tarkus", "tank"]) + table.remove(remove_string).run() self.assertAllEqual(3, table.size().eval()) input_string = constant_op.constant(["brain", "salad", "tank"]) @@ -624,6 +633,26 @@ class MutableHashTableOpTest(test.TestCase): result = output.eval() self.assertAllEqual([0, 1, 3, -1], result) + def testMutableHashTableRemoveHighRank(self): + with self.test_session(): + default_val = -1 + keys = constant_op.constant([["brain", "salad"], ["surgery", "tank"]]) + values = constant_op.constant([[0, 1], [2, 3]], dtypes.int64) + table = lookup.MutableHashTable(dtypes.string, dtypes.int64, default_val) + + table.insert(keys, values).run() + self.assertAllEqual(4, table.size().eval()) + + remove_string = constant_op.constant(["salad", "tarkus"]) + table.remove(remove_string).run() + self.assertAllEqual(3, table.size().eval()) + + input_string = constant_op.constant(["brain", "salad", "tank", "tarkus"]) + output = table.lookup(input_string) + + result = output.eval() + self.assertAllEqual([0, -1, 3, -1], result) + def testMutableHashTableOfTensorsFindHighRank(self): with self.cached_session(): default_val = constant_op.constant([-1, -1, -1], dtypes.int64) @@ -645,6 +674,30 @@ class MutableHashTableOpTest(test.TestCase): self.assertAllEqual( [[[0, 1, 2], [2, 3, 4]], [[-1, -1, -1], [-1, -1, -1]]], result) + def testMutableHashTableOfTensorsRemoveHighRank(self): + with self.test_session(): + default_val = constant_op.constant([-1, -1, -1], dtypes.int64) + keys = constant_op.constant(["brain", "salad", "surgery"]) + values = constant_op.constant([[0, 1, 2], [2, 3, 4], [4, 5, 6]], + dtypes.int64) + table = lookup.MutableHashTable(dtypes.string, dtypes.int64, default_val) + + table.insert(keys, values).run() + self.assertAllEqual(3, table.size().eval()) + + remove_string = constant_op.constant([["brain", "tank"]]) + table.remove(remove_string).run() + self.assertAllEqual(2, table.size().eval()) + + input_string = constant_op.constant([["brain", "salad"], + ["surgery", "tank"]]) + output = table.lookup(input_string) + self.assertAllEqual([2, 2, 3], output.get_shape()) + + result = output.eval() + self.assertAllEqual( + [[[-1, -1, -1], [2, 3, 4]], [[4, 5, 6], [-1, -1, -1]]], result) + def testMultipleMutableHashTables(self): with self.cached_session() as sess: default_val = -1 @@ -792,13 +845,22 @@ class MutableDenseHashTableOpTest(test.TestCase): def testBasic(self): with self.cached_session(): - keys = constant_op.constant([11, 12, 13], dtypes.int64) - values = constant_op.constant([0, 1, 2], dtypes.int64) + + keys = constant_op.constant([11, 12, 13, 14], dtypes.int64) + values = constant_op.constant([0, 1, 2, 3], dtypes.int64) table = lookup.MutableDenseHashTable( - dtypes.int64, dtypes.int64, default_value=-1, empty_key=0) + dtypes.int64, + dtypes.int64, + default_value=-1, + empty_key=0, + deleted_key=-1) self.assertAllEqual(0, table.size().eval()) table.insert(keys, values).run() + self.assertAllEqual(4, table.size().eval()) + + remove_string = constant_op.constant([12, 15], dtypes.int64) + table.remove(remove_string).run() self.assertAllEqual(3, table.size().eval()) input_string = constant_op.constant([11, 12, 15], dtypes.int64) @@ -806,17 +868,26 @@ class MutableDenseHashTableOpTest(test.TestCase): self.assertAllEqual([3], output.get_shape()) result = output.eval() - self.assertAllEqual([0, 1, -1], result) + self.assertAllEqual([0, -1, -1], result) def testBasicBool(self): with self.cached_session(): - keys = constant_op.constant([11, 12, 13], dtypes.int64) - values = constant_op.constant([True, True, True], dtypes.bool) + + keys = constant_op.constant([11, 12, 13, 14], dtypes.int64) + values = constant_op.constant([True, True, True, True], dtypes.bool) table = lookup.MutableDenseHashTable( - dtypes.int64, dtypes.bool, default_value=False, empty_key=0) + dtypes.int64, + dtypes.bool, + default_value=False, + empty_key=0, + deleted_key=-1) self.assertAllEqual(0, table.size().eval()) table.insert(keys, values).run() + self.assertAllEqual(4, table.size().eval()) + + remove_string = constant_op.constant([11, 15], dtypes.int64) + table.remove(remove_string).run() self.assertAllEqual(3, table.size().eval()) input_string = constant_op.constant([11, 12, 15], dtypes.int64) @@ -824,14 +895,30 @@ class MutableDenseHashTableOpTest(test.TestCase): self.assertAllEqual([3], output.get_shape()) result = output.eval() - self.assertAllEqual([True, True, False], result) + self.assertAllEqual([False, True, False], result) + + def testSameEmptyAndDeletedKey(self): + with self.cached_session(): + with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, + "deleted_key"): + table = lookup.MutableDenseHashTable( + dtypes.int64, + dtypes.int64, + default_value=-1, + empty_key=42, + deleted_key=42) + self.assertAllEqual(0, table.size().eval()) def testLookupUnknownShape(self): with self.cached_session(): keys = constant_op.constant([11, 12, 13], dtypes.int64) values = constant_op.constant([0, 1, 2], dtypes.int64) table = lookup.MutableDenseHashTable( - dtypes.int64, dtypes.int64, default_value=-1, empty_key=0) + dtypes.int64, + dtypes.int64, + default_value=-1, + empty_key=0, + deleted_key=-1) table.insert(keys, values).run() self.assertAllEqual(3, table.size().eval()) @@ -844,45 +931,60 @@ class MutableDenseHashTableOpTest(test.TestCase): def testMapStringToFloat(self): with self.cached_session(): - keys = constant_op.constant(["a", "b", "c"], dtypes.string) - values = constant_op.constant([0.0, 1.1, 2.2], dtypes.float32) + + keys = constant_op.constant(["a", "b", "c", "d"], dtypes.string) + values = constant_op.constant([0.0, 1.1, 2.2, 3.3], dtypes.float32) default_value = constant_op.constant(-1.5, dtypes.float32) table = lookup.MutableDenseHashTable( dtypes.string, dtypes.float32, default_value=default_value, - empty_key="") + empty_key="", + deleted_key="$") self.assertAllEqual(0, table.size().eval()) table.insert(keys, values).run() + self.assertAllEqual(4, table.size().eval()) + + remove_string = constant_op.constant(["b", "e"]) + table.remove(remove_string).run() self.assertAllEqual(3, table.size().eval()) - input_string = constant_op.constant(["a", "b", "d"], dtypes.string) + input_string = constant_op.constant(["a", "b", "d", "e"], dtypes.string) output = table.lookup(input_string) - self.assertAllEqual([3], output.get_shape()) + self.assertAllEqual([4], output.get_shape()) result = output.eval() - self.assertAllClose([0, 1.1, -1.5], result) + self.assertAllClose([0, -1.5, 3.3, -1.5], result) def testMapInt64ToFloat(self): for float_dtype in [dtypes.float32, dtypes.float64]: with self.cached_session(): - keys = constant_op.constant([11, 12, 13], dtypes.int64) - values = constant_op.constant([0.0, 1.1, 2.2], float_dtype) + + keys = constant_op.constant([11, 12, 13, 14], dtypes.int64) + values = constant_op.constant([0.0, 1.1, 2.2, 3.3], float_dtype) default_value = constant_op.constant(-1.5, float_dtype) table = lookup.MutableDenseHashTable( - dtypes.int64, float_dtype, default_value=default_value, empty_key=0) + dtypes.int64, + float_dtype, + default_value=default_value, + empty_key=0, + deleted_key=-1) self.assertAllEqual(0, table.size().eval()) table.insert(keys, values).run() + self.assertAllEqual(4, table.size().eval()) + + remove_string = constant_op.constant([12, 15], dtypes.int64) + table.remove(remove_string).run() self.assertAllEqual(3, table.size().eval()) - input_string = constant_op.constant([11, 12, 15], dtypes.int64) + input_string = constant_op.constant([11, 12, 14, 15], dtypes.int64) output = table.lookup(input_string) - self.assertAllEqual([3], output.get_shape()) + self.assertAllEqual([4], output.get_shape()) result = output.eval() - self.assertAllClose([0, 1.1, -1.5], result) + self.assertAllClose([0, -1.5, 3.3, -1.5], result) def testVectorValues(self): with self.cached_session(): @@ -895,6 +997,7 @@ class MutableDenseHashTableOpTest(test.TestCase): dtypes.int64, default_value=default_value, empty_key=0, + deleted_key=-1, initial_num_buckets=4) self.assertAllEqual(0, table.size().eval()) @@ -908,26 +1011,35 @@ class MutableDenseHashTableOpTest(test.TestCase): self.assertAllEqual(4, table.size().eval()) self.assertAllEqual(8, len(table.export()[0].eval())) - input_string = constant_op.constant([11, 12, 15], dtypes.int64) + remove_string = constant_op.constant([12, 16], dtypes.int64) + table.remove(remove_string).run() + self.assertAllEqual(3, table.size().eval()) + self.assertAllEqual(8, len(table.export()[0].eval())) + + input_string = constant_op.constant([11, 12, 14, 15], dtypes.int64) output = table.lookup(input_string) - self.assertAllEqual( - [3, 4], output.shape, msg="Saw shape: %s" % output.shape) + self.assertAllEqual([4, 4], + output.shape, + msg="Saw shape: %s" % output.shape) result = output.eval() - self.assertAllEqual([[0, 1, 2, 3], [3, 4, 5, 6], [-1, -2, -3, -4]], - result) + self.assertAllEqual( + [[0, 1, 2, 3], [-1, -2, -3, -4], [2, 3, 4, 5], [-1, -2, -3, -4]], + result) def testVectorKeys(self): with self.cached_session(): keys = constant_op.constant([[0, 1], [1, 2], [1, 3]], dtypes.int64) values = constant_op.constant([10, 11, 12], dtypes.int64) empty_key = constant_op.constant([0, 3], dtypes.int64) + deleted_key = constant_op.constant([-1, -1], dtypes.int64) default_value = constant_op.constant(-1, dtypes.int64) table = lookup.MutableDenseHashTable( dtypes.int64, dtypes.int64, default_value=default_value, empty_key=empty_key, + deleted_key=deleted_key, initial_num_buckets=8) self.assertAllEqual(0, table.size().eval()) @@ -940,13 +1052,18 @@ class MutableDenseHashTableOpTest(test.TestCase): self.assertAllEqual(4, table.size().eval()) self.assertAllEqual(8, len(table.export()[0].eval())) - input_string = constant_op.constant([[0, 1], [1, 2], [0, 2]], + remove_string = constant_op.constant([[1, 2], [7, 8]], dtypes.int64) + table.remove(remove_string).run() + self.assertAllEqual(3, table.size().eval()) + self.assertAllEqual(8, len(table.export()[0].eval())) + + input_string = constant_op.constant([[0, 1], [1, 2], [1, 3], [0, 2]], dtypes.int64) output = table.lookup(input_string) - self.assertAllEqual([3], output.get_shape()) + self.assertAllEqual([4], output.get_shape()) result = output.eval() - self.assertAllEqual([10, 11, -1], result) + self.assertAllEqual([10, -1, 12, -1], result) def testResize(self): with self.cached_session(): @@ -957,6 +1074,7 @@ class MutableDenseHashTableOpTest(test.TestCase): dtypes.int64, default_value=-1, empty_key=0, + deleted_key=-1, initial_num_buckets=4) self.assertAllEqual(0, table.size().eval()) @@ -964,31 +1082,42 @@ class MutableDenseHashTableOpTest(test.TestCase): self.assertAllEqual(3, table.size().eval()) self.assertAllEqual(4, len(table.export()[0].eval())) - keys2 = constant_op.constant([13, 14, 15, 16, 17], dtypes.int64) - values2 = constant_op.constant([3, 4, 5, 6, 7], dtypes.int64) + keys2 = constant_op.constant([12, 99], dtypes.int64) + table.remove(keys2).run() + self.assertAllEqual(2, table.size().eval()) + self.assertAllEqual(4, len(table.export()[0].eval())) + + keys3 = constant_op.constant([13, 14, 15, 16, 17], dtypes.int64) + values3 = constant_op.constant([3, 4, 5, 6, 7], dtypes.int64) - table.insert(keys2, values2).run() - self.assertAllEqual(7, table.size().eval()) + table.insert(keys3, values3).run() + self.assertAllEqual(6, table.size().eval()) self.assertAllEqual(16, len(table.export()[0].eval())) - keys3 = constant_op.constant([10, 11, 12, 13, 14, 15, 16, 17, 18], + keys4 = constant_op.constant([10, 11, 12, 13, 14, 15, 16, 17, 18], dtypes.int64) - output = table.lookup(keys3) - self.assertAllEqual([-1, 0, 1, 3, 4, 5, 6, 7, -1], output.eval()) + output = table.lookup(keys4) + self.assertAllEqual([-1, 0, -1, 3, 4, 5, 6, 7, -1], output.eval()) def testExport(self): with self.cached_session(): - keys = constant_op.constant([11, 12, 13], dtypes.int64) - values = constant_op.constant([1, 2, 3], dtypes.int64) + + keys = constant_op.constant([11, 12, 13, 14], dtypes.int64) + values = constant_op.constant([1, 2, 3, 4], dtypes.int64) table = lookup.MutableDenseHashTable( dtypes.int64, dtypes.int64, default_value=-1, empty_key=100, + deleted_key=200, initial_num_buckets=8) self.assertAllEqual(0, table.size().eval()) table.insert(keys, values).run() + self.assertAllEqual(4, table.size().eval()) + + keys2 = constant_op.constant([12, 15], dtypes.int64) + table.remove(keys2).run() self.assertAllEqual(3, table.size().eval()) exported_keys, exported_values = table.export() @@ -1005,8 +1134,8 @@ class MutableDenseHashTableOpTest(test.TestCase): pairs = np.dstack((np_keys.flatten(), np_values.flatten()))[0] # sort by key pairs = pairs[pairs[:, 0].argsort()] - self.assertAllEqual([[11, 1], [12, 2], [13, 3], [100, 0], [100, 0], - [100, 0], [100, 0], [100, 0]], pairs) + self.assertAllEqual([[11, 1], [13, 3], [14, 4], [100, 0], [100, 0], + [100, 0], [100, 0], [200, 2]], pairs) def testSaveRestore(self): save_dir = os.path.join(self.get_temp_dir(), "save_restore") @@ -1015,13 +1144,15 @@ class MutableDenseHashTableOpTest(test.TestCase): with self.session(graph=ops.Graph()) as sess: default_value = -1 empty_key = 0 - keys = constant_op.constant([11, 12, 13], dtypes.int64) - values = constant_op.constant([0, 1, 2], dtypes.int64) + deleted_key = -1 + keys = constant_op.constant([11, 12, 13, 14], dtypes.int64) + values = constant_op.constant([0, 1, 2, 3], dtypes.int64) table = lookup.MutableDenseHashTable( dtypes.int64, dtypes.int64, default_value=default_value, empty_key=empty_key, + deleted_key=deleted_key, name="t1", checkpoint=True, initial_num_buckets=32) @@ -1030,6 +1161,11 @@ class MutableDenseHashTableOpTest(test.TestCase): self.assertAllEqual(0, table.size().eval()) table.insert(keys, values).run() + self.assertAllEqual(4, table.size().eval()) + self.assertAllEqual(32, len(table.export()[0].eval())) + + keys2 = constant_op.constant([12, 15], dtypes.int64) + table.remove(keys2).run() self.assertAllEqual(3, table.size().eval()) self.assertAllEqual(32, len(table.export()[0].eval())) @@ -1043,6 +1179,7 @@ class MutableDenseHashTableOpTest(test.TestCase): dtypes.int64, default_value=default_value, empty_key=empty_key, + deleted_key=deleted_key, name="t1", checkpoint=True, initial_num_buckets=64) @@ -1062,7 +1199,7 @@ class MutableDenseHashTableOpTest(test.TestCase): input_string = constant_op.constant([10, 11, 12, 13, 14], dtypes.int64) output = table.lookup(input_string) - self.assertAllEqual([-1, 0, 1, 2, -1], output.eval()) + self.assertAllEqual([-1, 0, -1, 2, 3], output.eval()) @test_util.run_in_graph_and_eager_modes def testObjectSaveRestore(self): @@ -1071,6 +1208,7 @@ class MutableDenseHashTableOpTest(test.TestCase): default_value = -1 empty_key = 0 + deleted_key = -1 keys = constant_op.constant([11, 12, 13], dtypes.int64) values = constant_op.constant([0, 1, 2], dtypes.int64) save_table = lookup.MutableDenseHashTable( @@ -1078,6 +1216,7 @@ class MutableDenseHashTableOpTest(test.TestCase): dtypes.int64, default_value=default_value, empty_key=empty_key, + deleted_key=deleted_key, name="t1", checkpoint=True, initial_num_buckets=32) @@ -1097,6 +1236,7 @@ class MutableDenseHashTableOpTest(test.TestCase): dtypes.int64, default_value=default_value, empty_key=empty_key, + deleted_key=deleted_key, name="t1", checkpoint=True, initial_num_buckets=64) @@ -1124,14 +1264,18 @@ class MutableDenseHashTableOpTest(test.TestCase): with self.session(graph=ops.Graph()) as sess: empty_key = constant_op.constant([11, 13], dtypes.int64) + deleted_key = constant_op.constant([-2, -3], dtypes.int64) default_value = constant_op.constant([-1, -2], dtypes.int64) - keys = constant_op.constant([[11, 12], [11, 14], [13, 14]], dtypes.int64) - values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64) + keys = constant_op.constant([[11, 12], [11, 14], [12, 13], [13, 14]], + dtypes.int64) + values = constant_op.constant([[0, 1], [2, 3], [2, 4], [4, 5]], + dtypes.int64) table = lookup.MutableDenseHashTable( dtypes.int64, dtypes.int64, default_value=default_value, empty_key=empty_key, + deleted_key=deleted_key, name="t1", checkpoint=True, initial_num_buckets=32) @@ -1140,6 +1284,11 @@ class MutableDenseHashTableOpTest(test.TestCase): self.assertAllEqual(0, table.size().eval()) table.insert(keys, values).run() + self.assertAllEqual(4, table.size().eval()) + self.assertAllEqual(32, len(table.export()[0].eval())) + + keys2 = constant_op.constant([[12, 13], [16, 17]], dtypes.int64) + table.remove(keys2).run() self.assertAllEqual(3, table.size().eval()) self.assertAllEqual(32, len(table.export()[0].eval())) @@ -1149,12 +1298,14 @@ class MutableDenseHashTableOpTest(test.TestCase): with self.session(graph=ops.Graph()) as sess: empty_key = constant_op.constant([11, 13], dtypes.int64) + deleted_key = constant_op.constant([-2, -3], dtypes.int64) default_value = constant_op.constant([-1, -2], dtypes.int64) table = lookup.MutableDenseHashTable( dtypes.int64, dtypes.int64, default_value=default_value, empty_key=empty_key, + deleted_key=deleted_key, name="t1", checkpoint=True, initial_num_buckets=64) @@ -1184,14 +1335,17 @@ class MutableDenseHashTableOpTest(test.TestCase): with self.session(graph=ops.Graph()) as sess: empty_key = constant_op.constant([11, 13], dtypes.int64) + deleted_key = constant_op.constant([-1, -1], dtypes.int64) default_value = constant_op.constant(-1, dtypes.int64) - keys = constant_op.constant([[11, 12], [11, 14], [13, 14]], dtypes.int64) - values = constant_op.constant([0, 1, 2], dtypes.int64) + keys = constant_op.constant([[11, 12], [11, 14], [12, 13], [13, 14]], + dtypes.int64) + values = constant_op.constant([0, 1, 2, 3], dtypes.int64) table = lookup.MutableDenseHashTable( dtypes.int64, dtypes.int64, default_value=default_value, empty_key=empty_key, + deleted_key=deleted_key, name="t2", checkpoint=True, initial_num_buckets=32) @@ -1200,6 +1354,11 @@ class MutableDenseHashTableOpTest(test.TestCase): self.assertAllEqual(0, table.size().eval()) table.insert(keys, values).run() + self.assertAllEqual(4, table.size().eval()) + self.assertAllEqual(32, len(table.export()[0].eval())) + + keys2 = constant_op.constant([[12, 13], [15, 16]], dtypes.int64) + table.remove(keys2).run() self.assertAllEqual(3, table.size().eval()) self.assertAllEqual(32, len(table.export()[0].eval())) @@ -1209,12 +1368,14 @@ class MutableDenseHashTableOpTest(test.TestCase): with self.session(graph=ops.Graph()) as sess: empty_key = constant_op.constant([11, 13], dtypes.int64) + deleted_key = constant_op.constant([-1, -1], dtypes.int64) default_value = constant_op.constant(-1, dtypes.int64) table = lookup.MutableDenseHashTable( dtypes.int64, dtypes.int64, default_value=default_value, empty_key=empty_key, + deleted_key=deleted_key, name="t2", checkpoint=True, initial_num_buckets=64) @@ -1235,7 +1396,7 @@ class MutableDenseHashTableOpTest(test.TestCase): input_string = constant_op.constant( [[11, 12], [11, 14], [11, 15], [13, 14], [13, 15]], dtypes.int64) output = table.lookup(input_string) - self.assertAllEqual([0, 1, -1, 2, -1], output.eval()) + self.assertAllEqual([0, 1, -1, 3, -1], output.eval()) def testReprobe(self): with self.cached_session(): @@ -1248,6 +1409,7 @@ class MutableDenseHashTableOpTest(test.TestCase): dtypes.int64, default_value=-1, empty_key=0, + deleted_key=-1, initial_num_buckets=8) self.assertAllEqual(0, table.size().eval()) @@ -1267,7 +1429,11 @@ class MutableDenseHashTableOpTest(test.TestCase): keys = constant_op.constant([11, 0, 13], dtypes.int64) values = constant_op.constant([0, 1, 2], dtypes.int64) table = lookup.MutableDenseHashTable( - dtypes.int64, dtypes.int64, default_value=-1, empty_key=12) + dtypes.int64, + dtypes.int64, + default_value=-1, + empty_key=12, + deleted_key=-1) self.assertAllEqual(0, table.size().eval()) table.insert(keys, values).run() @@ -1283,19 +1449,35 @@ class MutableDenseHashTableOpTest(test.TestCase): def testErrors(self): with self.cached_session(): table = lookup.MutableDenseHashTable( - dtypes.int64, dtypes.int64, default_value=-1, empty_key=0) + dtypes.int64, + dtypes.int64, + default_value=-1, + empty_key=0, + deleted_key=-1) # Inserting the empty key returns an error - keys = constant_op.constant([11, 0], dtypes.int64) - values = constant_op.constant([0, 1], dtypes.int64) + keys1 = constant_op.constant([11, 0], dtypes.int64) + values1 = constant_op.constant([0, 1], dtypes.int64) with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, "empty_key"): - table.insert(keys, values).run() + table.insert(keys1, values1).run() # Looking up the empty key returns an error with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, "empty_key"): - table.lookup(keys).eval() + table.lookup(keys1).eval() + + # Inserting the deleted key returns an error + keys2 = constant_op.constant([11, -1], dtypes.int64) + values2 = constant_op.constant([0, 1], dtypes.int64) + with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, + "deleted_key"): + table.insert(keys2, values2).run() + + # Looking up the empty key returns an error + with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, + "deleted_key"): + table.lookup(keys2).eval() # Arbitrary tensors of keys are not supported keys = constant_op.constant([[11, 0], [12, 1]], dtypes.int64) @@ -1312,11 +1494,43 @@ class MutableDenseHashTableOpTest(test.TestCase): dtypes.int64, default_value=-1, empty_key=17, + deleted_key=-1, initial_num_buckets=12) with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, "Number of buckets must be"): self.assertAllEqual(0, table2.size().eval()) + with self.assertRaisesRegexp( + errors_impl.InvalidArgumentError, + "Empty and deleted keys must have same shape"): + table3 = lookup.MutableDenseHashTable( + dtypes.int64, + dtypes.int64, + default_value=-1, + empty_key=42, + deleted_key=[1, 2]) + self.assertAllEqual(0, table3.size().eval()) + + with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, + "Empty and deleted keys cannot be equal"): + table4 = lookup.MutableDenseHashTable( + dtypes.int64, + dtypes.int64, + default_value=-1, + empty_key=42, + deleted_key=42) + self.assertAllEqual(0, table4.size().eval()) + + with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, + "Empty and deleted keys cannot be equal"): + table5 = lookup.MutableDenseHashTable( + dtypes.int64, + dtypes.int64, + default_value=-1, + empty_key=[1, 2, 3], + deleted_key=[1, 2, 3]) + self.assertAllEqual(0, table5.size().eval()) + class IndexTableFromFile(test.TestCase): @@ -2558,7 +2772,11 @@ class MutableDenseHashTableBenchmark(MutableHashTableBenchmark): def _create_table(self): return lookup.MutableDenseHashTable( - dtypes.int64, dtypes.float32, default_value=0.0, empty_key=-1) + dtypes.int64, + dtypes.float32, + default_value=0.0, + empty_key=-1, + deleted_key=-2) if __name__ == "__main__": diff --git a/tensorflow/contrib/timeseries/python/timeseries/math_utils.py b/tensorflow/contrib/timeseries/python/timeseries/math_utils.py index 03da2b82e5..9c585fe6a7 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/math_utils.py +++ b/tensorflow/contrib/timeseries/python/timeseries/math_utils.py @@ -543,20 +543,25 @@ class TupleOfTensorsLookup(lookup.LookupInterface): overhead. """ - def __init__( - self, key_dtype, default_values, empty_key, name, checkpoint=True): + def __init__(self, + key_dtype, + default_values, + empty_key, + deleted_key, + name, + checkpoint=True): default_values_flat = nest.flatten(default_values) - self._hash_tables = nest.pack_sequence_as( - default_values, - [TensorValuedMutableDenseHashTable( + self._hash_tables = nest.pack_sequence_as(default_values, [ + TensorValuedMutableDenseHashTable( key_dtype=key_dtype, value_dtype=default_value.dtype.base_dtype, default_value=default_value, empty_key=empty_key, + deleted_key=deleted_key, name=name + "_{}".format(table_number), checkpoint=checkpoint) - for table_number, default_value - in enumerate(default_values_flat)]) + for table_number, default_value in enumerate(default_values_flat) + ]) self._name = name def lookup(self, keys): diff --git a/tensorflow/contrib/timeseries/python/timeseries/math_utils_test.py b/tensorflow/contrib/timeseries/python/timeseries/math_utils_test.py index c0de42b15b..91265b9b2e 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/math_utils_test.py +++ b/tensorflow/contrib/timeseries/python/timeseries/math_utils_test.py @@ -223,10 +223,12 @@ class TestLookupTable(test.TestCase): hash_table = math_utils.TupleOfTensorsLookup( key_dtype=dtypes.int64, default_values=[[ - array_ops.ones([3, 2], dtype=dtypes.float32), array_ops.zeros( - [5], dtype=dtypes.float64) - ], array_ops.ones([7, 7], dtype=dtypes.int64)], + array_ops.ones([3, 2], dtype=dtypes.float32), + array_ops.zeros([5], dtype=dtypes.float64) + ], + array_ops.ones([7, 7], dtype=dtypes.int64)], empty_key=-1, + deleted_key=-2, name="test_lookup") def stack_tensor(base_tensor): return array_ops.stack([base_tensor + 1, base_tensor + 2]) diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_management.py b/tensorflow/contrib/timeseries/python/timeseries/state_management.py index 13eecd4d82..138406c616 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/state_management.py +++ b/tensorflow/contrib/timeseries/python/timeseries/state_management.py @@ -149,6 +149,7 @@ class ChainingStateManager(_OverridableStateManager): key_dtype=dtypes.int64, default_values=self._start_state, empty_key=-1, + deleted_key=-2, name="cached_states", checkpoint=self._checkpoint_state) diff --git a/tensorflow/core/api_def/base_api/api_def_LookupTableRemoveV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_LookupTableRemoveV2.pbtxt new file mode 100644 index 0000000000..333fe6f4b2 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_LookupTableRemoveV2.pbtxt @@ -0,0 +1,24 @@ +op { + graph_op_name: "LookupTableRemoveV2" + visibility: HIDDEN + endpoint { + name: "LookupTableRemove" + } + in_arg { + name: "table_handle" + description: <(); + + mutex_lock l(mu_); + for (int64 i = 0; i < key_values.size(); ++i) { + table_.erase(SubtleMustCopyIfIntegral(key_values(i))); + } + return Status::OK(); + } + Status ImportValues(OpKernelContext* ctx, const Tensor& keys, const Tensor& values) override { return DoInsert(true, keys, values); @@ -212,6 +222,16 @@ class MutableHashTableOfTensors final : public LookupInterface { return DoInsert(false, keys, values); } + Status Remove(OpKernelContext* ctx, const Tensor& keys) override { + const auto key_values = keys.flat(); + + mutex_lock l(mu_); + for (int64 i = 0; i < key_values.size(); ++i) { + table_.erase(SubtleMustCopyIfIntegral(key_values(i))); + } + return Status::OK(); + } + Status ImportValues(OpKernelContext* ctx, const Tensor& keys, const Tensor& values) override { return DoInsert(true, keys, values); @@ -326,6 +346,29 @@ class MutableDenseHashTable final : public LookupInterface { empty_key_input->template shaped({1, key_shape_.num_elements()}), 0); + const Tensor* deleted_key_input; + OP_REQUIRES_OK(ctx, ctx->input("deleted_key", &deleted_key_input)); + OP_REQUIRES(ctx, key_shape_.IsSameSize(deleted_key_input->shape()), + errors::InvalidArgument( + "Empty and deleted keys must have same shape, got shapes: ", + key_shape_.DebugString(), " and ", + deleted_key_input->shape().DebugString())); + deleted_key_ = PersistentTensor(*deleted_key_input); + deleted_key_hash_ = HashKey(deleted_key_input->template shaped( + {1, key_shape_.num_elements()}), + 0); + + if (empty_key_hash_ == deleted_key_hash_) { + const int64 key_size = key_shape_.num_elements(); + const auto empty_key_matrix = + empty_key_.AccessTensor(ctx)->template shaped({1, key_size}); + const auto deleted_key_matrix = + deleted_key_.AccessTensor(ctx)->template shaped({1, key_size}); + OP_REQUIRES( + ctx, !IsEqualKey(empty_key_matrix, 0, deleted_key_matrix, 0), + errors::InvalidArgument("Empty and deleted keys cannot be equal")); + } + int64 initial_num_buckets; OP_REQUIRES_OK(ctx, GetNodeAttr(kernel->def(), "initial_num_buckets", &initial_num_buckets)); @@ -360,6 +403,8 @@ class MutableDenseHashTable final : public LookupInterface { value_buckets_.AccessTensor(ctx)->template matrix(); const auto empty_key_matrix = empty_key_.AccessTensor(ctx)->template shaped({1, key_size}); + const auto deleted_key_matrix = + deleted_key_.AccessTensor(ctx)->template shaped({1, key_size}); const int64 bit_mask = num_buckets_ - 1; // TODO(andreasst): parallelize using work_sharder for (int64 i = 0; i < num_elements; ++i) { @@ -369,6 +414,11 @@ class MutableDenseHashTable final : public LookupInterface { return errors::InvalidArgument( "Using the empty_key as a table key is not allowed"); } + if (deleted_key_hash_ == key_hash && + IsEqualKey(deleted_key_matrix, 0, key_matrix, i)) { + return errors::InvalidArgument( + "Using the deleted_key as a table key is not allowed"); + } int64 bucket_index = key_hash & bit_mask; int64 num_probes = 0; while (true) { @@ -425,23 +475,40 @@ class MutableDenseHashTable final : public LookupInterface { return DoInsert(ctx, key, value, false); } + Status Remove(OpKernelContext* ctx, const Tensor& key) override + LOCKS_EXCLUDED(mu_) { + if (key.NumElements() != key.dim_size(0) * key_shape_.num_elements()) { + TensorShape expected_shape({key.dim_size(0)}); + expected_shape.AppendShape(key_shape_); + return errors::InvalidArgument("Expected key shape ", + expected_shape.DebugString(), " got ", + key.shape().DebugString()); + } + mutex_lock l(mu_); + return DoRemove(ctx, key); + } + Status ImportValues(OpKernelContext* ctx, const Tensor& keys, const Tensor& values) override LOCKS_EXCLUDED(mu_) { mutex_lock l(mu_); num_buckets_ = keys.dim_size(0); key_buckets_ = PersistentTensor(keys); value_buckets_ = PersistentTensor(values); - // Count the number of keys that are not the empty_key. This requires - // iterating through the whole table but that is OK as we only execute it - // during checkpoint restore. + // Count the number of keys that are not the empty_key or deleted_key. + // This requires iterating through the whole table but that is OK as we + // only execute it during checkpoint restore. num_entries_ = 0; const auto empty_key_tensor = empty_key_.AccessTensor(ctx)->template shaped( {1, key_shape_.num_elements()}); + const auto deleted_key_tensor = + deleted_key_.AccessTensor(ctx)->template shaped( + {1, key_shape_.num_elements()}); const auto key_buckets_tensor = key_buckets_.AccessTensor(ctx)->template matrix(); for (int64 i = 0; i < num_buckets_; ++i) { - if (!IsEqualKey(key_buckets_tensor, i, empty_key_tensor, 0)) { + if (!IsEqualKey(key_buckets_tensor, i, empty_key_tensor, 0) && + !IsEqualKey(key_buckets_tensor, i, deleted_key_tensor, 0)) { ++num_entries_; } } @@ -498,7 +565,8 @@ class MutableDenseHashTable final : public LookupInterface { private: Status DoInsert(OpKernelContext* ctx, const Tensor& key, const Tensor& value, - bool ignore_empty_key) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + bool ignore_empty_and_deleted_key) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { const int64 num_elements = (key.dims() == 0) ? 1 : key.dim_size(0); const int64 value_size = value_shape_.num_elements(); const int64 key_size = key_shape_.num_elements(); @@ -511,17 +579,27 @@ class MutableDenseHashTable final : public LookupInterface { value_buckets_.AccessTensor(ctx)->template matrix(); const auto empty_key_tensor = empty_key_.AccessTensor(ctx)->template shaped({1, key_size}); + const auto deleted_key_tensor = + deleted_key_.AccessTensor(ctx)->template shaped({1, key_size}); const int64 bit_mask = num_buckets_ - 1; for (int64 i = 0; i < num_elements; ++i) { const uint64 key_hash = HashKey(key_matrix, i); if (empty_key_hash_ == key_hash && IsEqualKey(empty_key_tensor, 0, key_matrix, i)) { - if (ignore_empty_key) { + if (ignore_empty_and_deleted_key) { continue; } return errors::InvalidArgument( "Using the empty_key as a table key is not allowed"); } + if (deleted_key_hash_ == key_hash && + IsEqualKey(deleted_key_tensor, 0, key_matrix, i)) { + if (ignore_empty_and_deleted_key) { + continue; + } + return errors::InvalidArgument( + "Using the deleted_key as a table key is not allowed"); + } int64 bucket_index = key_hash & bit_mask; int64 num_probes = 0; while (true) { @@ -532,7 +610,9 @@ class MutableDenseHashTable final : public LookupInterface { } break; } - if (IsEqualKey(key_buckets_matrix, bucket_index, empty_key_tensor, 0)) { + if (IsEqualKey(key_buckets_matrix, bucket_index, empty_key_tensor, 0) || + IsEqualKey(key_buckets_matrix, bucket_index, deleted_key_tensor, + 0)) { ++num_entries_; for (int64 j = 0; j < key_size; ++j) { key_buckets_matrix(bucket_index, j) = @@ -556,6 +636,59 @@ class MutableDenseHashTable final : public LookupInterface { return Status::OK(); } + Status DoRemove(OpKernelContext* ctx, const Tensor& key) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + const int64 num_elements = key.dim_size(0); + const int64 key_size = key_shape_.num_elements(); + const auto key_matrix = key.shaped({num_elements, key_size}); + + auto key_buckets_matrix = + key_buckets_.AccessTensor(ctx)->template matrix(); + const auto empty_key_tensor = + empty_key_.AccessTensor(ctx)->template shaped({1, key_size}); + const auto deleted_key_tensor = + deleted_key_.AccessTensor(ctx)->template shaped({1, key_size}); + const auto deleted_key_flat = + deleted_key_.AccessTensor(ctx)->template flat(); + const int64 bit_mask = num_buckets_ - 1; + for (int64 i = 0; i < num_elements; ++i) { + const uint64 key_hash = HashKey(key_matrix, i); + if (empty_key_hash_ == key_hash && + IsEqualKey(empty_key_tensor, 0, key_matrix, i)) { + return errors::InvalidArgument( + "Using the empty_key as a table key is not allowed"); + } + if (deleted_key_hash_ == key_hash && + IsEqualKey(deleted_key_tensor, 0, key_matrix, i)) { + return errors::InvalidArgument( + "Using the deleted_key as a table key is not allowed"); + } + int64 bucket_index = key_hash & bit_mask; + int64 num_probes = 0; + while (true) { + if (IsEqualKey(key_buckets_matrix, bucket_index, key_matrix, i)) { + --num_entries_; + for (int64 j = 0; j < key_size; ++j) { + key_buckets_matrix(bucket_index, j) = + SubtleMustCopyIfIntegral(deleted_key_flat(j)); + } + break; + } + if (IsEqualKey(key_buckets_matrix, bucket_index, empty_key_tensor, 0)) { + break; + } + ++num_probes; + bucket_index = + (bucket_index + num_probes) & bit_mask; // quadratic probing + if (num_probes >= num_buckets_) { + return errors::Internal( + "Internal error in MutableDenseHashTable remove"); + } + } + } + return Status::OK(); + } + Status AllocateBuckets(OpKernelContext* ctx, int64 new_num_buckets) EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (new_num_buckets < 4 || @@ -639,7 +772,9 @@ class MutableDenseHashTable final : public LookupInterface { PersistentTensor value_buckets_ GUARDED_BY(mu_); PersistentTensor empty_key_; uint64 empty_key_hash_; -}; + PersistentTensor deleted_key_; + uint64 deleted_key_hash_; +}; // namespace lookup } // namespace lookup @@ -717,6 +852,39 @@ REGISTER_KERNEL_BUILDER(Name("LookupTableInsert").Device(DEVICE_CPU), REGISTER_KERNEL_BUILDER(Name("LookupTableInsertV2").Device(DEVICE_CPU), LookupTableInsertOp); +// Table remove op. +class LookupTableRemoveOp : public OpKernel { + public: + explicit LookupTableRemoveOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + lookup::LookupInterface* table; + OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table)); + core::ScopedUnref unref_me(table); + + DataType expected_input_0 = + (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF; + DataTypeVector expected_inputs = {expected_input_0, table->key_dtype()}; + OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {})); + + const Tensor& key = ctx->input(1); + OP_REQUIRES_OK(ctx, table->CheckKeyTensorForRemove(key)); + + int64 memory_used_before = 0; + if (ctx->track_allocations()) { + memory_used_before = table->MemoryUsed(); + } + OP_REQUIRES_OK(ctx, table->Remove(ctx, key)); + if (ctx->track_allocations()) { + ctx->record_persistent_memory_allocation(table->MemoryUsed() - + memory_used_before); + } + } +}; + +REGISTER_KERNEL_BUILDER(Name("LookupTableRemoveV2").Device(DEVICE_CPU), + LookupTableRemoveOp); + // Op that returns the size of the given table. class LookupTableSizeOp : public OpKernel { public: diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt index cfb1055d3c..415e15b720 100644 --- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt +++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt @@ -30320,6 +30320,22 @@ op { } is_stateful: true } +op { + name: "LookupTableRemoveV2" + input_arg { + name: "table_handle" + type: DT_RESOURCE + } + input_arg { + name: "keys" + type_attr: "Tin" + } + attr { + name: "Tin" + type: "type" + } + is_stateful: true +} op { name: "LookupTableSize" input_arg { @@ -36706,6 +36722,10 @@ op { name: "empty_key" type_attr: "key_dtype" } + input_arg { + name: "deleted_key" + type_attr: "key_dtype" + } output_arg { name: "table_handle" type: DT_RESOURCE diff --git a/tensorflow/core/ops/lookup_ops.cc b/tensorflow/core/ops/lookup_ops.cc index 72a77be70d..a0987cd982 100644 --- a/tensorflow/core/ops/lookup_ops.cc +++ b/tensorflow/core/ops/lookup_ops.cc @@ -214,6 +214,19 @@ REGISTER_OP("LookupTableInsertV2") return Status::OK(); }); +REGISTER_OP("LookupTableRemoveV2") + .Input("table_handle: resource") + .Input("keys: Tin") + .Attr("Tin: type") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &handle)); + + // TODO(turboale): Validate keys shape. + return Status::OK(); + }); + REGISTER_OP("LookupTableSize") .Input("table_handle: Ref(string)") .Output("size: int64") @@ -407,6 +420,7 @@ REGISTER_OP("MutableDenseHashTable") REGISTER_OP("MutableDenseHashTableV2") .Input("empty_key: key_dtype") + .Input("deleted_key: key_dtype") .Output("table_handle: resource") .Attr("container: string = ''") .Attr("shared_name: string = ''") -- cgit v1.2.3