aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-10-09 16:23:35 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-09 16:39:23 -0700
commit6c391166b8b6ba43d2b0151e6fb9cf14864131a2 (patch)
treec8c3c9eadade00f1a4e6cec2024e2a15bfd0b948 /tensorflow
parent2f5ebc0ea5e6d500ea8cd925234c569d6b32fd4e (diff)
Add 'remove' operation to MutableHashTable and MutableDenseHashTable.
PiperOrigin-RevId: 216443201
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py3
-rw-r--r--tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py2
-rw-r--r--tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py6
-rw-r--r--tensorflow/contrib/lookup/lookup_ops.py81
-rw-r--r--tensorflow/contrib/lookup/lookup_ops_test.py336
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/math_utils.py19
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/math_utils_test.py8
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/state_management.py1
-rw-r--r--tensorflow/core/api_def/base_api/api_def_LookupTableRemoveV2.pbtxt24
-rw-r--r--tensorflow/core/framework/lookup_interface.cc8
-rw-r--r--tensorflow/core/framework/lookup_interface.h17
-rw-r--r--tensorflow/core/kernels/initializable_lookup_table.h6
-rw-r--r--tensorflow/core/kernels/lookup_table_op.cc184
-rw-r--r--tensorflow/core/ops/compat/ops_history.v1.pbtxt20
-rw-r--r--tensorflow/core/ops/lookup_ops.cc14
15 files changed, 643 insertions, 86 deletions
diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
index 48ac429701..b5099a0bf6 100644
--- a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
+++ b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
@@ -152,7 +152,8 @@ class SdcaModel(object):
default_value=[0.0, 0.0, 0.0, 0.0],
# SdcaFprint never returns 0 or 1 for the low64 bits, so this a safe
# empty_key (that will never collide with actual payloads).
- empty_key=[0, 0])
+ empty_key=[0, 0],
+ deleted_key=[1, 1])
summary.scalar('approximate_duality_gap', self.approximate_duality_gap())
summary.scalar('examples_seen', self._hashtable.size())
diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py
index 5015fb0848..44a869f7c2 100644
--- a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py
+++ b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py
@@ -48,6 +48,7 @@ class ShardedMutableDenseHashTable(lookup.LookupInterface):
value_dtype,
default_value,
empty_key,
+ deleted_key,
num_shards=1,
checkpoint=True,
name='ShardedMutableHashTable'):
@@ -62,6 +63,7 @@ class ShardedMutableDenseHashTable(lookup.LookupInterface):
value_dtype=value_dtype,
default_value=default_value,
empty_key=empty_key,
+ deleted_key=deleted_key,
checkpoint=checkpoint,
name='%s-%d-of-%d' % (name, i + 1, num_shards)))
self._table_shards = table_shards
diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py
index 553b116a3b..2b56d0fa3a 100644
--- a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py
+++ b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py
@@ -33,6 +33,7 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase):
with self.cached_session():
default_val = -1
empty_key = 0
+ deleted_key = -1
keys = constant_op.constant([11, 12, 13], dtypes.int64)
values = constant_op.constant([0, 1, 2], dtypes.int64)
table = ShardedMutableDenseHashTable(
@@ -40,6 +41,7 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase):
dtypes.int64,
default_val,
empty_key,
+ deleted_key,
num_shards=num_shards)
self.assertAllEqual(0, table.size().eval())
@@ -56,6 +58,7 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase):
with self.cached_session():
default_val = [-0.1, 0.2]
empty_key = [0, 1]
+ deleted_key = [1, 0]
keys = constant_op.constant([[11, 12], [13, 14], [15, 16]],
dtypes.int64)
values = constant_op.constant([[0.5, 0.6], [1.5, 1.6], [2.5, 2.6]],
@@ -65,6 +68,7 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase):
dtypes.float32,
default_val,
empty_key,
+ deleted_key,
num_shards=num_shards)
self.assertAllEqual(0, table.size().eval())
@@ -81,6 +85,7 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase):
def testExportSharded(self):
with self.cached_session():
empty_key = -2
+ deleted_key = -3
default_val = -1
num_shards = 2
keys = constant_op.constant([10, 11, 12], dtypes.int64)
@@ -90,6 +95,7 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase):
dtypes.int64,
default_val,
empty_key,
+ deleted_key,
num_shards=num_shards)
self.assertAllEqual(0, table.size().eval())
diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py
index f83765a48d..5abef822e8 100644
--- a/tensorflow/contrib/lookup/lookup_ops.py
+++ b/tensorflow/contrib/lookup/lookup_ops.py
@@ -292,8 +292,8 @@ def index_to_string(tensor, mapping, default_value="UNK", name=None):
class MutableHashTable(LookupInterface, checkpointable.CheckpointableBase):
"""A generic mutable hash table implementation.
- Data can be inserted by calling the insert method. It does not support
- initialization via the init method.
+ Data can be inserted by calling the insert method and removed by calling the
+ remove method. It does not support initialization via the init method.
Example usage:
@@ -391,6 +391,34 @@ class MutableHashTable(LookupInterface, checkpointable.CheckpointableBase):
with ops.colocate_with(self._table_ref):
return gen_lookup_ops.lookup_table_size_v2(self._table_ref, name=name)
+ def remove(self, keys, name=None):
+ """Removes `keys` and its associated values from the table.
+
+ If a key is not present in the table, it is silently ignored.
+
+ Args:
+ keys: Keys to remove. Can be a tensor of any shape. Must match the table's
+ key type.
+ name: A name for the operation (optional).
+
+ Returns:
+ The created Operation.
+
+ Raises:
+ TypeError: when `keys` do not match the table data types.
+ """
+ if keys.dtype != self._key_dtype:
+ raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
+ (self._key_dtype, keys.dtype))
+
+ with ops.name_scope(name, "%s_lookup_table_remove" % self._name,
+ (self._table_ref, keys, self._default_value)) as name:
+ # pylint: disable=protected-access
+ op = gen_lookup_ops.lookup_table_remove_v2(
+ self._table_ref, keys, name=name)
+
+ return op
+
def lookup(self, keys, name=None):
"""Looks up `keys` in a table, outputs the corresponding values.
@@ -487,11 +515,11 @@ class MutableHashTable(LookupInterface, checkpointable.CheckpointableBase):
class MutableDenseHashTable(LookupInterface, checkpointable.CheckpointableBase):
"""A generic mutable hash table implementation using tensors as backing store.
- Data can be inserted by calling the insert method. It does not support
- initialization via the init method.
+ Data can be inserted by calling the insert method and removed by calling the
+ remove method. It does not support initialization via the init method.
It uses "open addressing" with quadratic reprobing to resolve collisions.
- Compared to `MutableHashTable` the insert and lookup operations in a
+ Compared to `MutableHashTable` the insert, remove and lookup operations in a
`MutableDenseHashTable` are typically faster, but memory usage can be higher.
However, `MutableDenseHashTable` does not require additional memory for
temporary tensors created during checkpointing and restore operations.
@@ -502,7 +530,9 @@ class MutableDenseHashTable(LookupInterface, checkpointable.CheckpointableBase):
table = tf.contrib.lookup.MutableDenseHashTable(key_dtype=tf.int64,
value_dtype=tf.int64,
default_value=-1,
- empty_key=0)
+ empty_key=0,
+ deleted_key=-1)
+
sess.run(table.insert(keys, values))
out = table.lookup(query_keys)
print(out.eval())
@@ -516,6 +546,7 @@ class MutableDenseHashTable(LookupInterface, checkpointable.CheckpointableBase):
value_dtype,
default_value,
empty_key,
+ deleted_key,
initial_num_buckets=None,
shared_name=None,
name="MutableDenseHashTable",
@@ -530,7 +561,7 @@ class MutableDenseHashTable(LookupInterface, checkpointable.CheckpointableBase):
value_dtype: the type of the value tensors.
default_value: The value to use if a key is missing in the table.
empty_key: the key to use to represent empty buckets internally. Must not
- be used in insert or lookup operations.
+ be used in insert, remove or lookup operations.
initial_num_buckets: the initial number of buckets.
shared_name: If non-empty, this table will be shared under
the given name across multiple sessions.
@@ -538,9 +569,12 @@ class MutableDenseHashTable(LookupInterface, checkpointable.CheckpointableBase):
checkpoint: if True, the contents of the table are saved to and restored
from checkpoints. If `shared_name` is empty for a checkpointed table, it
is shared using the table node name.
+ deleted_key: the key to use to represent deleted buckets internally. Must
+ not be used in insert, remove or lookup operations and be different from
+ the empty_key.
Returns:
- A `MutableHashTable` object.
+ A `MutableDenseHashTable` object.
Raises:
ValueError: If checkpoint is True and no name was specified.
@@ -555,6 +589,8 @@ class MutableDenseHashTable(LookupInterface, checkpointable.CheckpointableBase):
use_node_name_sharing = checkpoint and shared_name is None
empty_key = ops.convert_to_tensor(
empty_key, dtype=key_dtype, name="empty_key")
+ deleted_key = ops.convert_to_tensor(
+ deleted_key, dtype=key_dtype, name="deleted_key")
executing_eagerly = context.executing_eagerly()
if executing_eagerly and shared_name is None:
# TODO(allenl): This will leak memory due to kernel caching by the
@@ -564,6 +600,7 @@ class MutableDenseHashTable(LookupInterface, checkpointable.CheckpointableBase):
shared_name = "table_%d" % (ops.uid(),)
self._table_ref = gen_lookup_ops.mutable_dense_hash_table_v2(
empty_key=empty_key,
+ deleted_key=deleted_key,
shared_name=shared_name,
use_node_name_sharing=use_node_name_sharing,
value_dtype=value_dtype,
@@ -648,6 +685,34 @@ class MutableDenseHashTable(LookupInterface, checkpointable.CheckpointableBase):
self._table_ref, keys, values, name=name)
return op
+ def remove(self, keys, name=None):
+ """Removes `keys` and its associated values from the table.
+
+ If a key is not present in the table, it is silently ignored.
+
+ Args:
+ keys: Keys to remove. Can be a tensor of any shape. Must match the table's
+ key type.
+ name: A name for the operation (optional).
+
+ Returns:
+ The created Operation.
+
+ Raises:
+ TypeError: when `keys` do not match the table data types.
+ """
+ if keys.dtype != self._key_dtype:
+ raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
+ (self._key_dtype, keys.dtype))
+
+ with ops.name_scope(name, "%s_lookup_table_remove" % self._name,
+ (self._table_ref, keys, self._default_value)) as name:
+ # pylint: disable=protected-access
+ op = gen_lookup_ops.lookup_table_remove_v2(
+ self._table_ref, keys, name=name)
+
+ return op
+
def export(self, name=None):
"""Returns tensors of all keys and values in the table.
diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py
index 9e9345e875..35b0d1bc44 100644
--- a/tensorflow/contrib/lookup/lookup_ops_test.py
+++ b/tensorflow/contrib/lookup/lookup_ops_test.py
@@ -303,13 +303,17 @@ class MutableHashTableOpTest(test.TestCase):
def testMutableHashTable(self):
with self.cached_session():
default_val = -1
- keys = constant_op.constant(["brain", "salad", "surgery"])
- values = constant_op.constant([0, 1, 2], dtypes.int64)
+ keys = constant_op.constant(["brain", "salad", "surgery", "tarkus"])
+ values = constant_op.constant([0, 1, 2, 3], dtypes.int64)
table = lookup.MutableHashTable(dtypes.string, dtypes.int64,
default_val)
self.assertAllEqual(0, table.size().eval())
table.insert(keys, values).run()
+ self.assertAllEqual(4, table.size().eval())
+
+ remove_string = constant_op.constant(["tarkus", "tank"])
+ table.remove(remove_string).run()
self.assertAllEqual(3, table.size().eval())
input_string = constant_op.constant(["brain", "salad", "tank"])
@@ -472,13 +476,18 @@ class MutableHashTableOpTest(test.TestCase):
def testMutableHashTableOfTensors(self):
with self.cached_session():
default_val = constant_op.constant([-1, -1], dtypes.int64)
- keys = constant_op.constant(["brain", "salad", "surgery"])
- values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64)
+ keys = constant_op.constant(["brain", "salad", "surgery", "tarkus"])
+ values = constant_op.constant([[0, 1], [2, 3], [4, 5], [6, 7]],
+ dtypes.int64)
table = lookup.MutableHashTable(dtypes.string, dtypes.int64,
default_val)
self.assertAllEqual(0, table.size().eval())
table.insert(keys, values).run()
+ self.assertAllEqual(4, table.size().eval())
+
+ remove_string = constant_op.constant(["tarkus", "tank"])
+ table.remove(remove_string).run()
self.assertAllEqual(3, table.size().eval())
input_string = constant_op.constant(["brain", "salad", "tank"])
@@ -624,6 +633,26 @@ class MutableHashTableOpTest(test.TestCase):
result = output.eval()
self.assertAllEqual([0, 1, 3, -1], result)
+ def testMutableHashTableRemoveHighRank(self):
+ with self.test_session():
+ default_val = -1
+ keys = constant_op.constant([["brain", "salad"], ["surgery", "tank"]])
+ values = constant_op.constant([[0, 1], [2, 3]], dtypes.int64)
+ table = lookup.MutableHashTable(dtypes.string, dtypes.int64, default_val)
+
+ table.insert(keys, values).run()
+ self.assertAllEqual(4, table.size().eval())
+
+ remove_string = constant_op.constant(["salad", "tarkus"])
+ table.remove(remove_string).run()
+ self.assertAllEqual(3, table.size().eval())
+
+ input_string = constant_op.constant(["brain", "salad", "tank", "tarkus"])
+ output = table.lookup(input_string)
+
+ result = output.eval()
+ self.assertAllEqual([0, -1, 3, -1], result)
+
def testMutableHashTableOfTensorsFindHighRank(self):
with self.cached_session():
default_val = constant_op.constant([-1, -1, -1], dtypes.int64)
@@ -645,6 +674,30 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllEqual(
[[[0, 1, 2], [2, 3, 4]], [[-1, -1, -1], [-1, -1, -1]]], result)
+ def testMutableHashTableOfTensorsRemoveHighRank(self):
+ with self.test_session():
+ default_val = constant_op.constant([-1, -1, -1], dtypes.int64)
+ keys = constant_op.constant(["brain", "salad", "surgery"])
+ values = constant_op.constant([[0, 1, 2], [2, 3, 4], [4, 5, 6]],
+ dtypes.int64)
+ table = lookup.MutableHashTable(dtypes.string, dtypes.int64, default_val)
+
+ table.insert(keys, values).run()
+ self.assertAllEqual(3, table.size().eval())
+
+ remove_string = constant_op.constant([["brain", "tank"]])
+ table.remove(remove_string).run()
+ self.assertAllEqual(2, table.size().eval())
+
+ input_string = constant_op.constant([["brain", "salad"],
+ ["surgery", "tank"]])
+ output = table.lookup(input_string)
+ self.assertAllEqual([2, 2, 3], output.get_shape())
+
+ result = output.eval()
+ self.assertAllEqual(
+ [[[-1, -1, -1], [2, 3, 4]], [[4, 5, 6], [-1, -1, -1]]], result)
+
def testMultipleMutableHashTables(self):
with self.cached_session() as sess:
default_val = -1
@@ -792,13 +845,22 @@ class MutableDenseHashTableOpTest(test.TestCase):
def testBasic(self):
with self.cached_session():
- keys = constant_op.constant([11, 12, 13], dtypes.int64)
- values = constant_op.constant([0, 1, 2], dtypes.int64)
+
+ keys = constant_op.constant([11, 12, 13, 14], dtypes.int64)
+ values = constant_op.constant([0, 1, 2, 3], dtypes.int64)
table = lookup.MutableDenseHashTable(
- dtypes.int64, dtypes.int64, default_value=-1, empty_key=0)
+ dtypes.int64,
+ dtypes.int64,
+ default_value=-1,
+ empty_key=0,
+ deleted_key=-1)
self.assertAllEqual(0, table.size().eval())
table.insert(keys, values).run()
+ self.assertAllEqual(4, table.size().eval())
+
+ remove_string = constant_op.constant([12, 15], dtypes.int64)
+ table.remove(remove_string).run()
self.assertAllEqual(3, table.size().eval())
input_string = constant_op.constant([11, 12, 15], dtypes.int64)
@@ -806,17 +868,26 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllEqual([3], output.get_shape())
result = output.eval()
- self.assertAllEqual([0, 1, -1], result)
+ self.assertAllEqual([0, -1, -1], result)
def testBasicBool(self):
with self.cached_session():
- keys = constant_op.constant([11, 12, 13], dtypes.int64)
- values = constant_op.constant([True, True, True], dtypes.bool)
+
+ keys = constant_op.constant([11, 12, 13, 14], dtypes.int64)
+ values = constant_op.constant([True, True, True, True], dtypes.bool)
table = lookup.MutableDenseHashTable(
- dtypes.int64, dtypes.bool, default_value=False, empty_key=0)
+ dtypes.int64,
+ dtypes.bool,
+ default_value=False,
+ empty_key=0,
+ deleted_key=-1)
self.assertAllEqual(0, table.size().eval())
table.insert(keys, values).run()
+ self.assertAllEqual(4, table.size().eval())
+
+ remove_string = constant_op.constant([11, 15], dtypes.int64)
+ table.remove(remove_string).run()
self.assertAllEqual(3, table.size().eval())
input_string = constant_op.constant([11, 12, 15], dtypes.int64)
@@ -824,14 +895,30 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllEqual([3], output.get_shape())
result = output.eval()
- self.assertAllEqual([True, True, False], result)
+ self.assertAllEqual([False, True, False], result)
+
+ def testSameEmptyAndDeletedKey(self):
+ with self.cached_session():
+ with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+ "deleted_key"):
+ table = lookup.MutableDenseHashTable(
+ dtypes.int64,
+ dtypes.int64,
+ default_value=-1,
+ empty_key=42,
+ deleted_key=42)
+ self.assertAllEqual(0, table.size().eval())
def testLookupUnknownShape(self):
with self.cached_session():
keys = constant_op.constant([11, 12, 13], dtypes.int64)
values = constant_op.constant([0, 1, 2], dtypes.int64)
table = lookup.MutableDenseHashTable(
- dtypes.int64, dtypes.int64, default_value=-1, empty_key=0)
+ dtypes.int64,
+ dtypes.int64,
+ default_value=-1,
+ empty_key=0,
+ deleted_key=-1)
table.insert(keys, values).run()
self.assertAllEqual(3, table.size().eval())
@@ -844,45 +931,60 @@ class MutableDenseHashTableOpTest(test.TestCase):
def testMapStringToFloat(self):
with self.cached_session():
- keys = constant_op.constant(["a", "b", "c"], dtypes.string)
- values = constant_op.constant([0.0, 1.1, 2.2], dtypes.float32)
+
+ keys = constant_op.constant(["a", "b", "c", "d"], dtypes.string)
+ values = constant_op.constant([0.0, 1.1, 2.2, 3.3], dtypes.float32)
default_value = constant_op.constant(-1.5, dtypes.float32)
table = lookup.MutableDenseHashTable(
dtypes.string,
dtypes.float32,
default_value=default_value,
- empty_key="")
+ empty_key="",
+ deleted_key="$")
self.assertAllEqual(0, table.size().eval())
table.insert(keys, values).run()
+ self.assertAllEqual(4, table.size().eval())
+
+ remove_string = constant_op.constant(["b", "e"])
+ table.remove(remove_string).run()
self.assertAllEqual(3, table.size().eval())
- input_string = constant_op.constant(["a", "b", "d"], dtypes.string)
+ input_string = constant_op.constant(["a", "b", "d", "e"], dtypes.string)
output = table.lookup(input_string)
- self.assertAllEqual([3], output.get_shape())
+ self.assertAllEqual([4], output.get_shape())
result = output.eval()
- self.assertAllClose([0, 1.1, -1.5], result)
+ self.assertAllClose([0, -1.5, 3.3, -1.5], result)
def testMapInt64ToFloat(self):
for float_dtype in [dtypes.float32, dtypes.float64]:
with self.cached_session():
- keys = constant_op.constant([11, 12, 13], dtypes.int64)
- values = constant_op.constant([0.0, 1.1, 2.2], float_dtype)
+
+ keys = constant_op.constant([11, 12, 13, 14], dtypes.int64)
+ values = constant_op.constant([0.0, 1.1, 2.2, 3.3], float_dtype)
default_value = constant_op.constant(-1.5, float_dtype)
table = lookup.MutableDenseHashTable(
- dtypes.int64, float_dtype, default_value=default_value, empty_key=0)
+ dtypes.int64,
+ float_dtype,
+ default_value=default_value,
+ empty_key=0,
+ deleted_key=-1)
self.assertAllEqual(0, table.size().eval())
table.insert(keys, values).run()
+ self.assertAllEqual(4, table.size().eval())
+
+ remove_string = constant_op.constant([12, 15], dtypes.int64)
+ table.remove(remove_string).run()
self.assertAllEqual(3, table.size().eval())
- input_string = constant_op.constant([11, 12, 15], dtypes.int64)
+ input_string = constant_op.constant([11, 12, 14, 15], dtypes.int64)
output = table.lookup(input_string)
- self.assertAllEqual([3], output.get_shape())
+ self.assertAllEqual([4], output.get_shape())
result = output.eval()
- self.assertAllClose([0, 1.1, -1.5], result)
+ self.assertAllClose([0, -1.5, 3.3, -1.5], result)
def testVectorValues(self):
with self.cached_session():
@@ -895,6 +997,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
dtypes.int64,
default_value=default_value,
empty_key=0,
+ deleted_key=-1,
initial_num_buckets=4)
self.assertAllEqual(0, table.size().eval())
@@ -908,26 +1011,35 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllEqual(4, table.size().eval())
self.assertAllEqual(8, len(table.export()[0].eval()))
- input_string = constant_op.constant([11, 12, 15], dtypes.int64)
+ remove_string = constant_op.constant([12, 16], dtypes.int64)
+ table.remove(remove_string).run()
+ self.assertAllEqual(3, table.size().eval())
+ self.assertAllEqual(8, len(table.export()[0].eval()))
+
+ input_string = constant_op.constant([11, 12, 14, 15], dtypes.int64)
output = table.lookup(input_string)
- self.assertAllEqual(
- [3, 4], output.shape, msg="Saw shape: %s" % output.shape)
+ self.assertAllEqual([4, 4],
+ output.shape,
+ msg="Saw shape: %s" % output.shape)
result = output.eval()
- self.assertAllEqual([[0, 1, 2, 3], [3, 4, 5, 6], [-1, -2, -3, -4]],
- result)
+ self.assertAllEqual(
+ [[0, 1, 2, 3], [-1, -2, -3, -4], [2, 3, 4, 5], [-1, -2, -3, -4]],
+ result)
def testVectorKeys(self):
with self.cached_session():
keys = constant_op.constant([[0, 1], [1, 2], [1, 3]], dtypes.int64)
values = constant_op.constant([10, 11, 12], dtypes.int64)
empty_key = constant_op.constant([0, 3], dtypes.int64)
+ deleted_key = constant_op.constant([-1, -1], dtypes.int64)
default_value = constant_op.constant(-1, dtypes.int64)
table = lookup.MutableDenseHashTable(
dtypes.int64,
dtypes.int64,
default_value=default_value,
empty_key=empty_key,
+ deleted_key=deleted_key,
initial_num_buckets=8)
self.assertAllEqual(0, table.size().eval())
@@ -940,13 +1052,18 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllEqual(4, table.size().eval())
self.assertAllEqual(8, len(table.export()[0].eval()))
- input_string = constant_op.constant([[0, 1], [1, 2], [0, 2]],
+ remove_string = constant_op.constant([[1, 2], [7, 8]], dtypes.int64)
+ table.remove(remove_string).run()
+ self.assertAllEqual(3, table.size().eval())
+ self.assertAllEqual(8, len(table.export()[0].eval()))
+
+ input_string = constant_op.constant([[0, 1], [1, 2], [1, 3], [0, 2]],
dtypes.int64)
output = table.lookup(input_string)
- self.assertAllEqual([3], output.get_shape())
+ self.assertAllEqual([4], output.get_shape())
result = output.eval()
- self.assertAllEqual([10, 11, -1], result)
+ self.assertAllEqual([10, -1, 12, -1], result)
def testResize(self):
with self.cached_session():
@@ -957,6 +1074,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
dtypes.int64,
default_value=-1,
empty_key=0,
+ deleted_key=-1,
initial_num_buckets=4)
self.assertAllEqual(0, table.size().eval())
@@ -964,31 +1082,42 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllEqual(3, table.size().eval())
self.assertAllEqual(4, len(table.export()[0].eval()))
- keys2 = constant_op.constant([13, 14, 15, 16, 17], dtypes.int64)
- values2 = constant_op.constant([3, 4, 5, 6, 7], dtypes.int64)
+ keys2 = constant_op.constant([12, 99], dtypes.int64)
+ table.remove(keys2).run()
+ self.assertAllEqual(2, table.size().eval())
+ self.assertAllEqual(4, len(table.export()[0].eval()))
+
+ keys3 = constant_op.constant([13, 14, 15, 16, 17], dtypes.int64)
+ values3 = constant_op.constant([3, 4, 5, 6, 7], dtypes.int64)
- table.insert(keys2, values2).run()
- self.assertAllEqual(7, table.size().eval())
+ table.insert(keys3, values3).run()
+ self.assertAllEqual(6, table.size().eval())
self.assertAllEqual(16, len(table.export()[0].eval()))
- keys3 = constant_op.constant([10, 11, 12, 13, 14, 15, 16, 17, 18],
+ keys4 = constant_op.constant([10, 11, 12, 13, 14, 15, 16, 17, 18],
dtypes.int64)
- output = table.lookup(keys3)
- self.assertAllEqual([-1, 0, 1, 3, 4, 5, 6, 7, -1], output.eval())
+ output = table.lookup(keys4)
+ self.assertAllEqual([-1, 0, -1, 3, 4, 5, 6, 7, -1], output.eval())
def testExport(self):
with self.cached_session():
- keys = constant_op.constant([11, 12, 13], dtypes.int64)
- values = constant_op.constant([1, 2, 3], dtypes.int64)
+
+ keys = constant_op.constant([11, 12, 13, 14], dtypes.int64)
+ values = constant_op.constant([1, 2, 3, 4], dtypes.int64)
table = lookup.MutableDenseHashTable(
dtypes.int64,
dtypes.int64,
default_value=-1,
empty_key=100,
+ deleted_key=200,
initial_num_buckets=8)
self.assertAllEqual(0, table.size().eval())
table.insert(keys, values).run()
+ self.assertAllEqual(4, table.size().eval())
+
+ keys2 = constant_op.constant([12, 15], dtypes.int64)
+ table.remove(keys2).run()
self.assertAllEqual(3, table.size().eval())
exported_keys, exported_values = table.export()
@@ -1005,8 +1134,8 @@ class MutableDenseHashTableOpTest(test.TestCase):
pairs = np.dstack((np_keys.flatten(), np_values.flatten()))[0]
# sort by key
pairs = pairs[pairs[:, 0].argsort()]
- self.assertAllEqual([[11, 1], [12, 2], [13, 3], [100, 0], [100, 0],
- [100, 0], [100, 0], [100, 0]], pairs)
+ self.assertAllEqual([[11, 1], [13, 3], [14, 4], [100, 0], [100, 0],
+ [100, 0], [100, 0], [200, 2]], pairs)
def testSaveRestore(self):
save_dir = os.path.join(self.get_temp_dir(), "save_restore")
@@ -1015,13 +1144,15 @@ class MutableDenseHashTableOpTest(test.TestCase):
with self.session(graph=ops.Graph()) as sess:
default_value = -1
empty_key = 0
- keys = constant_op.constant([11, 12, 13], dtypes.int64)
- values = constant_op.constant([0, 1, 2], dtypes.int64)
+ deleted_key = -1
+ keys = constant_op.constant([11, 12, 13, 14], dtypes.int64)
+ values = constant_op.constant([0, 1, 2, 3], dtypes.int64)
table = lookup.MutableDenseHashTable(
dtypes.int64,
dtypes.int64,
default_value=default_value,
empty_key=empty_key,
+ deleted_key=deleted_key,
name="t1",
checkpoint=True,
initial_num_buckets=32)
@@ -1030,6 +1161,11 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllEqual(0, table.size().eval())
table.insert(keys, values).run()
+ self.assertAllEqual(4, table.size().eval())
+ self.assertAllEqual(32, len(table.export()[0].eval()))
+
+ keys2 = constant_op.constant([12, 15], dtypes.int64)
+ table.remove(keys2).run()
self.assertAllEqual(3, table.size().eval())
self.assertAllEqual(32, len(table.export()[0].eval()))
@@ -1043,6 +1179,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
dtypes.int64,
default_value=default_value,
empty_key=empty_key,
+ deleted_key=deleted_key,
name="t1",
checkpoint=True,
initial_num_buckets=64)
@@ -1062,7 +1199,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
input_string = constant_op.constant([10, 11, 12, 13, 14], dtypes.int64)
output = table.lookup(input_string)
- self.assertAllEqual([-1, 0, 1, 2, -1], output.eval())
+ self.assertAllEqual([-1, 0, -1, 2, 3], output.eval())
@test_util.run_in_graph_and_eager_modes
def testObjectSaveRestore(self):
@@ -1071,6 +1208,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
default_value = -1
empty_key = 0
+ deleted_key = -1
keys = constant_op.constant([11, 12, 13], dtypes.int64)
values = constant_op.constant([0, 1, 2], dtypes.int64)
save_table = lookup.MutableDenseHashTable(
@@ -1078,6 +1216,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
dtypes.int64,
default_value=default_value,
empty_key=empty_key,
+ deleted_key=deleted_key,
name="t1",
checkpoint=True,
initial_num_buckets=32)
@@ -1097,6 +1236,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
dtypes.int64,
default_value=default_value,
empty_key=empty_key,
+ deleted_key=deleted_key,
name="t1",
checkpoint=True,
initial_num_buckets=64)
@@ -1124,14 +1264,18 @@ class MutableDenseHashTableOpTest(test.TestCase):
with self.session(graph=ops.Graph()) as sess:
empty_key = constant_op.constant([11, 13], dtypes.int64)
+ deleted_key = constant_op.constant([-2, -3], dtypes.int64)
default_value = constant_op.constant([-1, -2], dtypes.int64)
- keys = constant_op.constant([[11, 12], [11, 14], [13, 14]], dtypes.int64)
- values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64)
+ keys = constant_op.constant([[11, 12], [11, 14], [12, 13], [13, 14]],
+ dtypes.int64)
+ values = constant_op.constant([[0, 1], [2, 3], [2, 4], [4, 5]],
+ dtypes.int64)
table = lookup.MutableDenseHashTable(
dtypes.int64,
dtypes.int64,
default_value=default_value,
empty_key=empty_key,
+ deleted_key=deleted_key,
name="t1",
checkpoint=True,
initial_num_buckets=32)
@@ -1140,6 +1284,11 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllEqual(0, table.size().eval())
table.insert(keys, values).run()
+ self.assertAllEqual(4, table.size().eval())
+ self.assertAllEqual(32, len(table.export()[0].eval()))
+
+ keys2 = constant_op.constant([[12, 13], [16, 17]], dtypes.int64)
+ table.remove(keys2).run()
self.assertAllEqual(3, table.size().eval())
self.assertAllEqual(32, len(table.export()[0].eval()))
@@ -1149,12 +1298,14 @@ class MutableDenseHashTableOpTest(test.TestCase):
with self.session(graph=ops.Graph()) as sess:
empty_key = constant_op.constant([11, 13], dtypes.int64)
+ deleted_key = constant_op.constant([-2, -3], dtypes.int64)
default_value = constant_op.constant([-1, -2], dtypes.int64)
table = lookup.MutableDenseHashTable(
dtypes.int64,
dtypes.int64,
default_value=default_value,
empty_key=empty_key,
+ deleted_key=deleted_key,
name="t1",
checkpoint=True,
initial_num_buckets=64)
@@ -1184,14 +1335,17 @@ class MutableDenseHashTableOpTest(test.TestCase):
with self.session(graph=ops.Graph()) as sess:
empty_key = constant_op.constant([11, 13], dtypes.int64)
+ deleted_key = constant_op.constant([-1, -1], dtypes.int64)
default_value = constant_op.constant(-1, dtypes.int64)
- keys = constant_op.constant([[11, 12], [11, 14], [13, 14]], dtypes.int64)
- values = constant_op.constant([0, 1, 2], dtypes.int64)
+ keys = constant_op.constant([[11, 12], [11, 14], [12, 13], [13, 14]],
+ dtypes.int64)
+ values = constant_op.constant([0, 1, 2, 3], dtypes.int64)
table = lookup.MutableDenseHashTable(
dtypes.int64,
dtypes.int64,
default_value=default_value,
empty_key=empty_key,
+ deleted_key=deleted_key,
name="t2",
checkpoint=True,
initial_num_buckets=32)
@@ -1200,6 +1354,11 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllEqual(0, table.size().eval())
table.insert(keys, values).run()
+ self.assertAllEqual(4, table.size().eval())
+ self.assertAllEqual(32, len(table.export()[0].eval()))
+
+ keys2 = constant_op.constant([[12, 13], [15, 16]], dtypes.int64)
+ table.remove(keys2).run()
self.assertAllEqual(3, table.size().eval())
self.assertAllEqual(32, len(table.export()[0].eval()))
@@ -1209,12 +1368,14 @@ class MutableDenseHashTableOpTest(test.TestCase):
with self.session(graph=ops.Graph()) as sess:
empty_key = constant_op.constant([11, 13], dtypes.int64)
+ deleted_key = constant_op.constant([-1, -1], dtypes.int64)
default_value = constant_op.constant(-1, dtypes.int64)
table = lookup.MutableDenseHashTable(
dtypes.int64,
dtypes.int64,
default_value=default_value,
empty_key=empty_key,
+ deleted_key=deleted_key,
name="t2",
checkpoint=True,
initial_num_buckets=64)
@@ -1235,7 +1396,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
input_string = constant_op.constant(
[[11, 12], [11, 14], [11, 15], [13, 14], [13, 15]], dtypes.int64)
output = table.lookup(input_string)
- self.assertAllEqual([0, 1, -1, 2, -1], output.eval())
+ self.assertAllEqual([0, 1, -1, 3, -1], output.eval())
def testReprobe(self):
with self.cached_session():
@@ -1248,6 +1409,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
dtypes.int64,
default_value=-1,
empty_key=0,
+ deleted_key=-1,
initial_num_buckets=8)
self.assertAllEqual(0, table.size().eval())
@@ -1267,7 +1429,11 @@ class MutableDenseHashTableOpTest(test.TestCase):
keys = constant_op.constant([11, 0, 13], dtypes.int64)
values = constant_op.constant([0, 1, 2], dtypes.int64)
table = lookup.MutableDenseHashTable(
- dtypes.int64, dtypes.int64, default_value=-1, empty_key=12)
+ dtypes.int64,
+ dtypes.int64,
+ default_value=-1,
+ empty_key=12,
+ deleted_key=-1)
self.assertAllEqual(0, table.size().eval())
table.insert(keys, values).run()
@@ -1283,19 +1449,35 @@ class MutableDenseHashTableOpTest(test.TestCase):
def testErrors(self):
with self.cached_session():
table = lookup.MutableDenseHashTable(
- dtypes.int64, dtypes.int64, default_value=-1, empty_key=0)
+ dtypes.int64,
+ dtypes.int64,
+ default_value=-1,
+ empty_key=0,
+ deleted_key=-1)
# Inserting the empty key returns an error
- keys = constant_op.constant([11, 0], dtypes.int64)
- values = constant_op.constant([0, 1], dtypes.int64)
+ keys1 = constant_op.constant([11, 0], dtypes.int64)
+ values1 = constant_op.constant([0, 1], dtypes.int64)
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"empty_key"):
- table.insert(keys, values).run()
+ table.insert(keys1, values1).run()
# Looking up the empty key returns an error
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"empty_key"):
- table.lookup(keys).eval()
+ table.lookup(keys1).eval()
+
+ # Inserting the deleted key returns an error
+ keys2 = constant_op.constant([11, -1], dtypes.int64)
+ values2 = constant_op.constant([0, 1], dtypes.int64)
+ with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+ "deleted_key"):
+ table.insert(keys2, values2).run()
+
+ # Looking up the empty key returns an error
+ with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+ "deleted_key"):
+ table.lookup(keys2).eval()
# Arbitrary tensors of keys are not supported
keys = constant_op.constant([[11, 0], [12, 1]], dtypes.int64)
@@ -1312,11 +1494,43 @@ class MutableDenseHashTableOpTest(test.TestCase):
dtypes.int64,
default_value=-1,
empty_key=17,
+ deleted_key=-1,
initial_num_buckets=12)
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"Number of buckets must be"):
self.assertAllEqual(0, table2.size().eval())
+ with self.assertRaisesRegexp(
+ errors_impl.InvalidArgumentError,
+ "Empty and deleted keys must have same shape"):
+ table3 = lookup.MutableDenseHashTable(
+ dtypes.int64,
+ dtypes.int64,
+ default_value=-1,
+ empty_key=42,
+ deleted_key=[1, 2])
+ self.assertAllEqual(0, table3.size().eval())
+
+ with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+ "Empty and deleted keys cannot be equal"):
+ table4 = lookup.MutableDenseHashTable(
+ dtypes.int64,
+ dtypes.int64,
+ default_value=-1,
+ empty_key=42,
+ deleted_key=42)
+ self.assertAllEqual(0, table4.size().eval())
+
+ with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+ "Empty and deleted keys cannot be equal"):
+ table5 = lookup.MutableDenseHashTable(
+ dtypes.int64,
+ dtypes.int64,
+ default_value=-1,
+ empty_key=[1, 2, 3],
+ deleted_key=[1, 2, 3])
+ self.assertAllEqual(0, table5.size().eval())
+
class IndexTableFromFile(test.TestCase):
@@ -2558,7 +2772,11 @@ class MutableDenseHashTableBenchmark(MutableHashTableBenchmark):
def _create_table(self):
return lookup.MutableDenseHashTable(
- dtypes.int64, dtypes.float32, default_value=0.0, empty_key=-1)
+ dtypes.int64,
+ dtypes.float32,
+ default_value=0.0,
+ empty_key=-1,
+ deleted_key=-2)
if __name__ == "__main__":
diff --git a/tensorflow/contrib/timeseries/python/timeseries/math_utils.py b/tensorflow/contrib/timeseries/python/timeseries/math_utils.py
index 03da2b82e5..9c585fe6a7 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/math_utils.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/math_utils.py
@@ -543,20 +543,25 @@ class TupleOfTensorsLookup(lookup.LookupInterface):
overhead.
"""
- def __init__(
- self, key_dtype, default_values, empty_key, name, checkpoint=True):
+ def __init__(self,
+ key_dtype,
+ default_values,
+ empty_key,
+ deleted_key,
+ name,
+ checkpoint=True):
default_values_flat = nest.flatten(default_values)
- self._hash_tables = nest.pack_sequence_as(
- default_values,
- [TensorValuedMutableDenseHashTable(
+ self._hash_tables = nest.pack_sequence_as(default_values, [
+ TensorValuedMutableDenseHashTable(
key_dtype=key_dtype,
value_dtype=default_value.dtype.base_dtype,
default_value=default_value,
empty_key=empty_key,
+ deleted_key=deleted_key,
name=name + "_{}".format(table_number),
checkpoint=checkpoint)
- for table_number, default_value
- in enumerate(default_values_flat)])
+ for table_number, default_value in enumerate(default_values_flat)
+ ])
self._name = name
def lookup(self, keys):
diff --git a/tensorflow/contrib/timeseries/python/timeseries/math_utils_test.py b/tensorflow/contrib/timeseries/python/timeseries/math_utils_test.py
index c0de42b15b..91265b9b2e 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/math_utils_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/math_utils_test.py
@@ -223,10 +223,12 @@ class TestLookupTable(test.TestCase):
hash_table = math_utils.TupleOfTensorsLookup(
key_dtype=dtypes.int64,
default_values=[[
- array_ops.ones([3, 2], dtype=dtypes.float32), array_ops.zeros(
- [5], dtype=dtypes.float64)
- ], array_ops.ones([7, 7], dtype=dtypes.int64)],
+ array_ops.ones([3, 2], dtype=dtypes.float32),
+ array_ops.zeros([5], dtype=dtypes.float64)
+ ],
+ array_ops.ones([7, 7], dtype=dtypes.int64)],
empty_key=-1,
+ deleted_key=-2,
name="test_lookup")
def stack_tensor(base_tensor):
return array_ops.stack([base_tensor + 1, base_tensor + 2])
diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_management.py b/tensorflow/contrib/timeseries/python/timeseries/state_management.py
index 13eecd4d82..138406c616 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/state_management.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/state_management.py
@@ -149,6 +149,7 @@ class ChainingStateManager(_OverridableStateManager):
key_dtype=dtypes.int64,
default_values=self._start_state,
empty_key=-1,
+ deleted_key=-2,
name="cached_states",
checkpoint=self._checkpoint_state)
diff --git a/tensorflow/core/api_def/base_api/api_def_LookupTableRemoveV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_LookupTableRemoveV2.pbtxt
new file mode 100644
index 0000000000..333fe6f4b2
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_LookupTableRemoveV2.pbtxt
@@ -0,0 +1,24 @@
+op {
+ graph_op_name: "LookupTableRemoveV2"
+ visibility: HIDDEN
+ endpoint {
+ name: "LookupTableRemove"
+ }
+ in_arg {
+ name: "table_handle"
+ description: <<END
+Handle to the table.
+END
+ }
+ in_arg {
+ name: "keys"
+ description: <<END
+Any shape. Keys of the elements to remove.
+END
+ }
+ summary: "Removes keys and its associated values from a table."
+ description: <<END
+The tensor `keys` must of the same type as the keys of the table. Keys not
+already in the table are silently ignored.
+END
+}
diff --git a/tensorflow/core/framework/lookup_interface.cc b/tensorflow/core/framework/lookup_interface.cc
index bf3204ea6e..117adbf65c 100644
--- a/tensorflow/core/framework/lookup_interface.cc
+++ b/tensorflow/core/framework/lookup_interface.cc
@@ -71,6 +71,14 @@ Status LookupInterface::CheckKeyAndValueTensorsForImport(const Tensor& keys,
return CheckKeyAndValueTensorsHelper(keys, values);
}
+Status LookupInterface::CheckKeyTensorForRemove(const Tensor& keys) {
+ if (keys.dtype() != key_dtype()) {
+ return errors::InvalidArgument("Key must be type ", key_dtype(),
+ " but got ", keys.dtype());
+ }
+ return CheckKeyShape(keys.shape());
+}
+
Status LookupInterface::CheckFindArguments(const Tensor& key,
const Tensor& default_value) {
TF_RETURN_IF_ERROR(CheckKeyAndValueTypes(key, default_value));
diff --git a/tensorflow/core/framework/lookup_interface.h b/tensorflow/core/framework/lookup_interface.h
index 0622dd06cb..d33945fd1b 100644
--- a/tensorflow/core/framework/lookup_interface.h
+++ b/tensorflow/core/framework/lookup_interface.h
@@ -64,6 +64,17 @@ class LookupInterface : public ResourceBase {
virtual Status Insert(OpKernelContext* ctx, const Tensor& keys,
const Tensor& values) = 0;
+ // Removes elements from the table.
+ // This method is only implemented in mutable tables that can be updated over
+ // the execution of the graph. It returns Status::NotImplemented for read-only
+ // tables that are initialized once before they can be looked up.
+
+ // Returns the following statuses:
+ // - OK: when the remove finishes successfully.
+ // - InvalidArgument: if any of the preconditions on the lookup key fails.
+ // - Unimplemented: if the table does not support removals.
+ virtual Status Remove(OpKernelContext* ctx, const Tensor& keys) = 0;
+
// Returns the number of elements in the table.
virtual size_t size() const = 0;
@@ -107,6 +118,12 @@ class LookupInterface : public ResourceBase {
virtual Status CheckKeyAndValueTensorsForImport(const Tensor& keys,
const Tensor& values);
+ // Check format of the key tensor for the Remove function.
+ // Returns OK if all the following requirements are satisfied, otherwise it
+ // returns InvalidArgument:
+ // - DataType of the tensor keys equals to the table key_dtype
+ virtual Status CheckKeyTensorForRemove(const Tensor& keys);
+
// Check the arguments of a find operation. Returns OK if all the following
// requirements are satisfied, otherwise it returns InvalidArgument:
// - DataType of the tensor keys equals to the table key_dtype
diff --git a/tensorflow/core/kernels/initializable_lookup_table.h b/tensorflow/core/kernels/initializable_lookup_table.h
index 424fe5df3c..a14d4967a5 100644
--- a/tensorflow/core/kernels/initializable_lookup_table.h
+++ b/tensorflow/core/kernels/initializable_lookup_table.h
@@ -51,6 +51,12 @@ class InitializableLookupTable : public LookupInterface {
"Insert not supported by InitializableLookupTable implementations");
}
+ // Returns errors::Unimplemented.
+ Status Remove(OpKernelContext* ctx, const Tensor& keys) final {
+ return errors::Unimplemented(
+ "Remove not supported by InitializableLookupTable implementations");
+ }
+
Status ExportValues(OpKernelContext* context) override {
return errors::Unimplemented(
"ExportValues not supported by InitializableLookupTable "
diff --git a/tensorflow/core/kernels/lookup_table_op.cc b/tensorflow/core/kernels/lookup_table_op.cc
index a495758861..0bc1ea77d6 100644
--- a/tensorflow/core/kernels/lookup_table_op.cc
+++ b/tensorflow/core/kernels/lookup_table_op.cc
@@ -89,6 +89,16 @@ class MutableHashTableOfScalars final : public LookupInterface {
return DoInsert(false, keys, values);
}
+ Status Remove(OpKernelContext* ctx, const Tensor& keys) override {
+ const auto key_values = keys.flat<K>();
+
+ mutex_lock l(mu_);
+ for (int64 i = 0; i < key_values.size(); ++i) {
+ table_.erase(SubtleMustCopyIfIntegral(key_values(i)));
+ }
+ return Status::OK();
+ }
+
Status ImportValues(OpKernelContext* ctx, const Tensor& keys,
const Tensor& values) override {
return DoInsert(true, keys, values);
@@ -212,6 +222,16 @@ class MutableHashTableOfTensors final : public LookupInterface {
return DoInsert(false, keys, values);
}
+ Status Remove(OpKernelContext* ctx, const Tensor& keys) override {
+ const auto key_values = keys.flat<K>();
+
+ mutex_lock l(mu_);
+ for (int64 i = 0; i < key_values.size(); ++i) {
+ table_.erase(SubtleMustCopyIfIntegral(key_values(i)));
+ }
+ return Status::OK();
+ }
+
Status ImportValues(OpKernelContext* ctx, const Tensor& keys,
const Tensor& values) override {
return DoInsert(true, keys, values);
@@ -326,6 +346,29 @@ class MutableDenseHashTable final : public LookupInterface {
empty_key_input->template shaped<K, 2>({1, key_shape_.num_elements()}),
0);
+ const Tensor* deleted_key_input;
+ OP_REQUIRES_OK(ctx, ctx->input("deleted_key", &deleted_key_input));
+ OP_REQUIRES(ctx, key_shape_.IsSameSize(deleted_key_input->shape()),
+ errors::InvalidArgument(
+ "Empty and deleted keys must have same shape, got shapes: ",
+ key_shape_.DebugString(), " and ",
+ deleted_key_input->shape().DebugString()));
+ deleted_key_ = PersistentTensor(*deleted_key_input);
+ deleted_key_hash_ = HashKey(deleted_key_input->template shaped<K, 2>(
+ {1, key_shape_.num_elements()}),
+ 0);
+
+ if (empty_key_hash_ == deleted_key_hash_) {
+ const int64 key_size = key_shape_.num_elements();
+ const auto empty_key_matrix =
+ empty_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size});
+ const auto deleted_key_matrix =
+ deleted_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size});
+ OP_REQUIRES(
+ ctx, !IsEqualKey(empty_key_matrix, 0, deleted_key_matrix, 0),
+ errors::InvalidArgument("Empty and deleted keys cannot be equal"));
+ }
+
int64 initial_num_buckets;
OP_REQUIRES_OK(ctx, GetNodeAttr(kernel->def(), "initial_num_buckets",
&initial_num_buckets));
@@ -360,6 +403,8 @@ class MutableDenseHashTable final : public LookupInterface {
value_buckets_.AccessTensor(ctx)->template matrix<V>();
const auto empty_key_matrix =
empty_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size});
+ const auto deleted_key_matrix =
+ deleted_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size});
const int64 bit_mask = num_buckets_ - 1;
// TODO(andreasst): parallelize using work_sharder
for (int64 i = 0; i < num_elements; ++i) {
@@ -369,6 +414,11 @@ class MutableDenseHashTable final : public LookupInterface {
return errors::InvalidArgument(
"Using the empty_key as a table key is not allowed");
}
+ if (deleted_key_hash_ == key_hash &&
+ IsEqualKey(deleted_key_matrix, 0, key_matrix, i)) {
+ return errors::InvalidArgument(
+ "Using the deleted_key as a table key is not allowed");
+ }
int64 bucket_index = key_hash & bit_mask;
int64 num_probes = 0;
while (true) {
@@ -425,23 +475,40 @@ class MutableDenseHashTable final : public LookupInterface {
return DoInsert(ctx, key, value, false);
}
+ Status Remove(OpKernelContext* ctx, const Tensor& key) override
+ LOCKS_EXCLUDED(mu_) {
+ if (key.NumElements() != key.dim_size(0) * key_shape_.num_elements()) {
+ TensorShape expected_shape({key.dim_size(0)});
+ expected_shape.AppendShape(key_shape_);
+ return errors::InvalidArgument("Expected key shape ",
+ expected_shape.DebugString(), " got ",
+ key.shape().DebugString());
+ }
+ mutex_lock l(mu_);
+ return DoRemove(ctx, key);
+ }
+
Status ImportValues(OpKernelContext* ctx, const Tensor& keys,
const Tensor& values) override LOCKS_EXCLUDED(mu_) {
mutex_lock l(mu_);
num_buckets_ = keys.dim_size(0);
key_buckets_ = PersistentTensor(keys);
value_buckets_ = PersistentTensor(values);
- // Count the number of keys that are not the empty_key. This requires
- // iterating through the whole table but that is OK as we only execute it
- // during checkpoint restore.
+ // Count the number of keys that are not the empty_key or deleted_key.
+ // This requires iterating through the whole table but that is OK as we
+ // only execute it during checkpoint restore.
num_entries_ = 0;
const auto empty_key_tensor =
empty_key_.AccessTensor(ctx)->template shaped<K, 2>(
{1, key_shape_.num_elements()});
+ const auto deleted_key_tensor =
+ deleted_key_.AccessTensor(ctx)->template shaped<K, 2>(
+ {1, key_shape_.num_elements()});
const auto key_buckets_tensor =
key_buckets_.AccessTensor(ctx)->template matrix<K>();
for (int64 i = 0; i < num_buckets_; ++i) {
- if (!IsEqualKey(key_buckets_tensor, i, empty_key_tensor, 0)) {
+ if (!IsEqualKey(key_buckets_tensor, i, empty_key_tensor, 0) &&
+ !IsEqualKey(key_buckets_tensor, i, deleted_key_tensor, 0)) {
++num_entries_;
}
}
@@ -498,7 +565,8 @@ class MutableDenseHashTable final : public LookupInterface {
private:
Status DoInsert(OpKernelContext* ctx, const Tensor& key, const Tensor& value,
- bool ignore_empty_key) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ bool ignore_empty_and_deleted_key)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
const int64 num_elements = (key.dims() == 0) ? 1 : key.dim_size(0);
const int64 value_size = value_shape_.num_elements();
const int64 key_size = key_shape_.num_elements();
@@ -511,17 +579,27 @@ class MutableDenseHashTable final : public LookupInterface {
value_buckets_.AccessTensor(ctx)->template matrix<V>();
const auto empty_key_tensor =
empty_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size});
+ const auto deleted_key_tensor =
+ deleted_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size});
const int64 bit_mask = num_buckets_ - 1;
for (int64 i = 0; i < num_elements; ++i) {
const uint64 key_hash = HashKey(key_matrix, i);
if (empty_key_hash_ == key_hash &&
IsEqualKey(empty_key_tensor, 0, key_matrix, i)) {
- if (ignore_empty_key) {
+ if (ignore_empty_and_deleted_key) {
continue;
}
return errors::InvalidArgument(
"Using the empty_key as a table key is not allowed");
}
+ if (deleted_key_hash_ == key_hash &&
+ IsEqualKey(deleted_key_tensor, 0, key_matrix, i)) {
+ if (ignore_empty_and_deleted_key) {
+ continue;
+ }
+ return errors::InvalidArgument(
+ "Using the deleted_key as a table key is not allowed");
+ }
int64 bucket_index = key_hash & bit_mask;
int64 num_probes = 0;
while (true) {
@@ -532,7 +610,9 @@ class MutableDenseHashTable final : public LookupInterface {
}
break;
}
- if (IsEqualKey(key_buckets_matrix, bucket_index, empty_key_tensor, 0)) {
+ if (IsEqualKey(key_buckets_matrix, bucket_index, empty_key_tensor, 0) ||
+ IsEqualKey(key_buckets_matrix, bucket_index, deleted_key_tensor,
+ 0)) {
++num_entries_;
for (int64 j = 0; j < key_size; ++j) {
key_buckets_matrix(bucket_index, j) =
@@ -556,6 +636,59 @@ class MutableDenseHashTable final : public LookupInterface {
return Status::OK();
}
+ Status DoRemove(OpKernelContext* ctx, const Tensor& key)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ const int64 num_elements = key.dim_size(0);
+ const int64 key_size = key_shape_.num_elements();
+ const auto key_matrix = key.shaped<K, 2>({num_elements, key_size});
+
+ auto key_buckets_matrix =
+ key_buckets_.AccessTensor(ctx)->template matrix<K>();
+ const auto empty_key_tensor =
+ empty_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size});
+ const auto deleted_key_tensor =
+ deleted_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size});
+ const auto deleted_key_flat =
+ deleted_key_.AccessTensor(ctx)->template flat<K>();
+ const int64 bit_mask = num_buckets_ - 1;
+ for (int64 i = 0; i < num_elements; ++i) {
+ const uint64 key_hash = HashKey(key_matrix, i);
+ if (empty_key_hash_ == key_hash &&
+ IsEqualKey(empty_key_tensor, 0, key_matrix, i)) {
+ return errors::InvalidArgument(
+ "Using the empty_key as a table key is not allowed");
+ }
+ if (deleted_key_hash_ == key_hash &&
+ IsEqualKey(deleted_key_tensor, 0, key_matrix, i)) {
+ return errors::InvalidArgument(
+ "Using the deleted_key as a table key is not allowed");
+ }
+ int64 bucket_index = key_hash & bit_mask;
+ int64 num_probes = 0;
+ while (true) {
+ if (IsEqualKey(key_buckets_matrix, bucket_index, key_matrix, i)) {
+ --num_entries_;
+ for (int64 j = 0; j < key_size; ++j) {
+ key_buckets_matrix(bucket_index, j) =
+ SubtleMustCopyIfIntegral(deleted_key_flat(j));
+ }
+ break;
+ }
+ if (IsEqualKey(key_buckets_matrix, bucket_index, empty_key_tensor, 0)) {
+ break;
+ }
+ ++num_probes;
+ bucket_index =
+ (bucket_index + num_probes) & bit_mask; // quadratic probing
+ if (num_probes >= num_buckets_) {
+ return errors::Internal(
+ "Internal error in MutableDenseHashTable remove");
+ }
+ }
+ }
+ return Status::OK();
+ }
+
Status AllocateBuckets(OpKernelContext* ctx, int64 new_num_buckets)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (new_num_buckets < 4 ||
@@ -639,7 +772,9 @@ class MutableDenseHashTable final : public LookupInterface {
PersistentTensor value_buckets_ GUARDED_BY(mu_);
PersistentTensor empty_key_;
uint64 empty_key_hash_;
-};
+ PersistentTensor deleted_key_;
+ uint64 deleted_key_hash_;
+}; // namespace lookup
} // namespace lookup
@@ -717,6 +852,39 @@ REGISTER_KERNEL_BUILDER(Name("LookupTableInsert").Device(DEVICE_CPU),
REGISTER_KERNEL_BUILDER(Name("LookupTableInsertV2").Device(DEVICE_CPU),
LookupTableInsertOp);
+// Table remove op.
+class LookupTableRemoveOp : public OpKernel {
+ public:
+ explicit LookupTableRemoveOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ lookup::LookupInterface* table;
+ OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table));
+ core::ScopedUnref unref_me(table);
+
+ DataType expected_input_0 =
+ (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF;
+ DataTypeVector expected_inputs = {expected_input_0, table->key_dtype()};
+ OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {}));
+
+ const Tensor& key = ctx->input(1);
+ OP_REQUIRES_OK(ctx, table->CheckKeyTensorForRemove(key));
+
+ int64 memory_used_before = 0;
+ if (ctx->track_allocations()) {
+ memory_used_before = table->MemoryUsed();
+ }
+ OP_REQUIRES_OK(ctx, table->Remove(ctx, key));
+ if (ctx->track_allocations()) {
+ ctx->record_persistent_memory_allocation(table->MemoryUsed() -
+ memory_used_before);
+ }
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("LookupTableRemoveV2").Device(DEVICE_CPU),
+ LookupTableRemoveOp);
+
// Op that returns the size of the given table.
class LookupTableSizeOp : public OpKernel {
public:
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index cfb1055d3c..415e15b720 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -30321,6 +30321,22 @@ op {
is_stateful: true
}
op {
+ name: "LookupTableRemoveV2"
+ input_arg {
+ name: "table_handle"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "keys"
+ type_attr: "Tin"
+ }
+ attr {
+ name: "Tin"
+ type: "type"
+ }
+ is_stateful: true
+}
+op {
name: "LookupTableSize"
input_arg {
name: "table_handle"
@@ -36706,6 +36722,10 @@ op {
name: "empty_key"
type_attr: "key_dtype"
}
+ input_arg {
+ name: "deleted_key"
+ type_attr: "key_dtype"
+ }
output_arg {
name: "table_handle"
type: DT_RESOURCE
diff --git a/tensorflow/core/ops/lookup_ops.cc b/tensorflow/core/ops/lookup_ops.cc
index 72a77be70d..a0987cd982 100644
--- a/tensorflow/core/ops/lookup_ops.cc
+++ b/tensorflow/core/ops/lookup_ops.cc
@@ -214,6 +214,19 @@ REGISTER_OP("LookupTableInsertV2")
return Status::OK();
});
+REGISTER_OP("LookupTableRemoveV2")
+ .Input("table_handle: resource")
+ .Input("keys: Tin")
+ .Attr("Tin: type")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle handle;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
+ TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &handle));
+
+ // TODO(turboale): Validate keys shape.
+ return Status::OK();
+ });
+
REGISTER_OP("LookupTableSize")
.Input("table_handle: Ref(string)")
.Output("size: int64")
@@ -407,6 +420,7 @@ REGISTER_OP("MutableDenseHashTable")
REGISTER_OP("MutableDenseHashTableV2")
.Input("empty_key: key_dtype")
+ .Input("deleted_key: key_dtype")
.Output("table_handle: resource")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")