aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lookup/lookup_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lookup/lookup_ops.py')
-rw-r--r--tensorflow/contrib/lookup/lookup_ops.py81
1 files changed, 73 insertions, 8 deletions
diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py
index f83765a48d..5abef822e8 100644
--- a/tensorflow/contrib/lookup/lookup_ops.py
+++ b/tensorflow/contrib/lookup/lookup_ops.py
@@ -292,8 +292,8 @@ def index_to_string(tensor, mapping, default_value="UNK", name=None):
class MutableHashTable(LookupInterface, checkpointable.CheckpointableBase):
"""A generic mutable hash table implementation.
- Data can be inserted by calling the insert method. It does not support
- initialization via the init method.
+ Data can be inserted by calling the insert method and removed by calling the
+ remove method. It does not support initialization via the init method.
Example usage:
@@ -391,6 +391,34 @@ class MutableHashTable(LookupInterface, checkpointable.CheckpointableBase):
with ops.colocate_with(self._table_ref):
return gen_lookup_ops.lookup_table_size_v2(self._table_ref, name=name)
+ def remove(self, keys, name=None):
+ """Removes `keys` and its associated values from the table.
+
+ If a key is not present in the table, it is silently ignored.
+
+ Args:
+ keys: Keys to remove. Can be a tensor of any shape. Must match the table's
+ key type.
+ name: A name for the operation (optional).
+
+ Returns:
+ The created Operation.
+
+ Raises:
+ TypeError: when `keys` do not match the table data types.
+ """
+ if keys.dtype != self._key_dtype:
+ raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
+ (self._key_dtype, keys.dtype))
+
+ with ops.name_scope(name, "%s_lookup_table_remove" % self._name,
+ (self._table_ref, keys, self._default_value)) as name:
+ # pylint: disable=protected-access
+ op = gen_lookup_ops.lookup_table_remove_v2(
+ self._table_ref, keys, name=name)
+
+ return op
+
def lookup(self, keys, name=None):
"""Looks up `keys` in a table, outputs the corresponding values.
@@ -487,11 +515,11 @@ class MutableHashTable(LookupInterface, checkpointable.CheckpointableBase):
class MutableDenseHashTable(LookupInterface, checkpointable.CheckpointableBase):
"""A generic mutable hash table implementation using tensors as backing store.
- Data can be inserted by calling the insert method. It does not support
- initialization via the init method.
+ Data can be inserted by calling the insert method and removed by calling the
+ remove method. It does not support initialization via the init method.
It uses "open addressing" with quadratic reprobing to resolve collisions.
- Compared to `MutableHashTable` the insert and lookup operations in a
+ Compared to `MutableHashTable` the insert, remove and lookup operations in a
`MutableDenseHashTable` are typically faster, but memory usage can be higher.
However, `MutableDenseHashTable` does not require additional memory for
temporary tensors created during checkpointing and restore operations.
@@ -502,7 +530,9 @@ class MutableDenseHashTable(LookupInterface, checkpointable.CheckpointableBase):
table = tf.contrib.lookup.MutableDenseHashTable(key_dtype=tf.int64,
value_dtype=tf.int64,
default_value=-1,
- empty_key=0)
+ empty_key=0,
+ deleted_key=-1)
+
sess.run(table.insert(keys, values))
out = table.lookup(query_keys)
print(out.eval())
@@ -516,6 +546,7 @@ class MutableDenseHashTable(LookupInterface, checkpointable.CheckpointableBase):
value_dtype,
default_value,
empty_key,
+ deleted_key,
initial_num_buckets=None,
shared_name=None,
name="MutableDenseHashTable",
@@ -530,7 +561,7 @@ class MutableDenseHashTable(LookupInterface, checkpointable.CheckpointableBase):
value_dtype: the type of the value tensors.
default_value: The value to use if a key is missing in the table.
empty_key: the key to use to represent empty buckets internally. Must not
- be used in insert or lookup operations.
+ be used in insert, remove or lookup operations.
initial_num_buckets: the initial number of buckets.
shared_name: If non-empty, this table will be shared under
the given name across multiple sessions.
@@ -538,9 +569,12 @@ class MutableDenseHashTable(LookupInterface, checkpointable.CheckpointableBase):
checkpoint: if True, the contents of the table are saved to and restored
from checkpoints. If `shared_name` is empty for a checkpointed table, it
is shared using the table node name.
+ deleted_key: the key to use to represent deleted buckets internally. Must
+ not be used in insert, remove or lookup operations and be different from
+ the empty_key.
Returns:
- A `MutableHashTable` object.
+ A `MutableDenseHashTable` object.
Raises:
ValueError: If checkpoint is True and no name was specified.
@@ -555,6 +589,8 @@ class MutableDenseHashTable(LookupInterface, checkpointable.CheckpointableBase):
use_node_name_sharing = checkpoint and shared_name is None
empty_key = ops.convert_to_tensor(
empty_key, dtype=key_dtype, name="empty_key")
+ deleted_key = ops.convert_to_tensor(
+ deleted_key, dtype=key_dtype, name="deleted_key")
executing_eagerly = context.executing_eagerly()
if executing_eagerly and shared_name is None:
# TODO(allenl): This will leak memory due to kernel caching by the
@@ -564,6 +600,7 @@ class MutableDenseHashTable(LookupInterface, checkpointable.CheckpointableBase):
shared_name = "table_%d" % (ops.uid(),)
self._table_ref = gen_lookup_ops.mutable_dense_hash_table_v2(
empty_key=empty_key,
+ deleted_key=deleted_key,
shared_name=shared_name,
use_node_name_sharing=use_node_name_sharing,
value_dtype=value_dtype,
@@ -648,6 +685,34 @@ class MutableDenseHashTable(LookupInterface, checkpointable.CheckpointableBase):
self._table_ref, keys, values, name=name)
return op
+ def remove(self, keys, name=None):
+ """Removes `keys` and its associated values from the table.
+
+ If a key is not present in the table, it is silently ignored.
+
+ Args:
+ keys: Keys to remove. Can be a tensor of any shape. Must match the table's
+ key type.
+ name: A name for the operation (optional).
+
+ Returns:
+ The created Operation.
+
+ Raises:
+ TypeError: when `keys` do not match the table data types.
+ """
+ if keys.dtype != self._key_dtype:
+ raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
+ (self._key_dtype, keys.dtype))
+
+ with ops.name_scope(name, "%s_lookup_table_remove" % self._name,
+ (self._table_ref, keys, self._default_value)) as name:
+ # pylint: disable=protected-access
+ op = gen_lookup_ops.lookup_table_remove_v2(
+ self._table_ref, keys, name=name)
+
+ return op
+
def export(self, name=None):
"""Returns tensors of all keys and values in the table.