diff options
Diffstat (limited to 'tensorflow/contrib/lookup/lookup_ops.py')
-rw-r--r-- | tensorflow/contrib/lookup/lookup_ops.py | 81 |
1 files changed, 73 insertions, 8 deletions
diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py index f83765a48d..5abef822e8 100644 --- a/tensorflow/contrib/lookup/lookup_ops.py +++ b/tensorflow/contrib/lookup/lookup_ops.py @@ -292,8 +292,8 @@ def index_to_string(tensor, mapping, default_value="UNK", name=None): class MutableHashTable(LookupInterface, checkpointable.CheckpointableBase): """A generic mutable hash table implementation. - Data can be inserted by calling the insert method. It does not support - initialization via the init method. + Data can be inserted by calling the insert method and removed by calling the + remove method. It does not support initialization via the init method. Example usage: @@ -391,6 +391,34 @@ class MutableHashTable(LookupInterface, checkpointable.CheckpointableBase): with ops.colocate_with(self._table_ref): return gen_lookup_ops.lookup_table_size_v2(self._table_ref, name=name) + def remove(self, keys, name=None): + """Removes `keys` and its associated values from the table. + + If a key is not present in the table, it is silently ignored. + + Args: + keys: Keys to remove. Can be a tensor of any shape. Must match the table's + key type. + name: A name for the operation (optional). + + Returns: + The created Operation. + + Raises: + TypeError: when `keys` do not match the table data types. + """ + if keys.dtype != self._key_dtype: + raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." % + (self._key_dtype, keys.dtype)) + + with ops.name_scope(name, "%s_lookup_table_remove" % self._name, + (self._table_ref, keys, self._default_value)) as name: + # pylint: disable=protected-access + op = gen_lookup_ops.lookup_table_remove_v2( + self._table_ref, keys, name=name) + + return op + def lookup(self, keys, name=None): """Looks up `keys` in a table, outputs the corresponding values. @@ -487,11 +515,11 @@ class MutableHashTable(LookupInterface, checkpointable.CheckpointableBase): class MutableDenseHashTable(LookupInterface, checkpointable.CheckpointableBase): """A generic mutable hash table implementation using tensors as backing store. - Data can be inserted by calling the insert method. It does not support - initialization via the init method. + Data can be inserted by calling the insert method and removed by calling the + remove method. It does not support initialization via the init method. It uses "open addressing" with quadratic reprobing to resolve collisions. - Compared to `MutableHashTable` the insert and lookup operations in a + Compared to `MutableHashTable` the insert, remove and lookup operations in a `MutableDenseHashTable` are typically faster, but memory usage can be higher. However, `MutableDenseHashTable` does not require additional memory for temporary tensors created during checkpointing and restore operations. @@ -502,7 +530,9 @@ class MutableDenseHashTable(LookupInterface, checkpointable.CheckpointableBase): table = tf.contrib.lookup.MutableDenseHashTable(key_dtype=tf.int64, value_dtype=tf.int64, default_value=-1, - empty_key=0) + empty_key=0, + deleted_key=-1) + sess.run(table.insert(keys, values)) out = table.lookup(query_keys) print(out.eval()) @@ -516,6 +546,7 @@ class MutableDenseHashTable(LookupInterface, checkpointable.CheckpointableBase): value_dtype, default_value, empty_key, + deleted_key, initial_num_buckets=None, shared_name=None, name="MutableDenseHashTable", @@ -530,7 +561,7 @@ class MutableDenseHashTable(LookupInterface, checkpointable.CheckpointableBase): value_dtype: the type of the value tensors. default_value: The value to use if a key is missing in the table. empty_key: the key to use to represent empty buckets internally. Must not - be used in insert or lookup operations. + be used in insert, remove or lookup operations. initial_num_buckets: the initial number of buckets. shared_name: If non-empty, this table will be shared under the given name across multiple sessions. @@ -538,9 +569,12 @@ class MutableDenseHashTable(LookupInterface, checkpointable.CheckpointableBase): checkpoint: if True, the contents of the table are saved to and restored from checkpoints. If `shared_name` is empty for a checkpointed table, it is shared using the table node name. + deleted_key: the key to use to represent deleted buckets internally. Must + not be used in insert, remove or lookup operations and be different from + the empty_key. Returns: - A `MutableHashTable` object. + A `MutableDenseHashTable` object. Raises: ValueError: If checkpoint is True and no name was specified. @@ -555,6 +589,8 @@ class MutableDenseHashTable(LookupInterface, checkpointable.CheckpointableBase): use_node_name_sharing = checkpoint and shared_name is None empty_key = ops.convert_to_tensor( empty_key, dtype=key_dtype, name="empty_key") + deleted_key = ops.convert_to_tensor( + deleted_key, dtype=key_dtype, name="deleted_key") executing_eagerly = context.executing_eagerly() if executing_eagerly and shared_name is None: # TODO(allenl): This will leak memory due to kernel caching by the @@ -564,6 +600,7 @@ class MutableDenseHashTable(LookupInterface, checkpointable.CheckpointableBase): shared_name = "table_%d" % (ops.uid(),) self._table_ref = gen_lookup_ops.mutable_dense_hash_table_v2( empty_key=empty_key, + deleted_key=deleted_key, shared_name=shared_name, use_node_name_sharing=use_node_name_sharing, value_dtype=value_dtype, @@ -648,6 +685,34 @@ class MutableDenseHashTable(LookupInterface, checkpointable.CheckpointableBase): self._table_ref, keys, values, name=name) return op + def remove(self, keys, name=None): + """Removes `keys` and its associated values from the table. + + If a key is not present in the table, it is silently ignored. + + Args: + keys: Keys to remove. Can be a tensor of any shape. Must match the table's + key type. + name: A name for the operation (optional). + + Returns: + The created Operation. + + Raises: + TypeError: when `keys` do not match the table data types. + """ + if keys.dtype != self._key_dtype: + raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." % + (self._key_dtype, keys.dtype)) + + with ops.name_scope(name, "%s_lookup_table_remove" % self._name, + (self._table_ref, keys, self._default_value)) as name: + # pylint: disable=protected-access + op = gen_lookup_ops.lookup_table_remove_v2( + self._table_ref, keys, name=name) + + return op + def export(self, name=None): """Returns tensors of all keys and values in the table. |