aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/lookup_table_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/lookup_table_op.cc')
-rw-r--r--tensorflow/core/kernels/lookup_table_op.cc184
1 files changed, 176 insertions, 8 deletions
diff --git a/tensorflow/core/kernels/lookup_table_op.cc b/tensorflow/core/kernels/lookup_table_op.cc
index a495758861..0bc1ea77d6 100644
--- a/tensorflow/core/kernels/lookup_table_op.cc
+++ b/tensorflow/core/kernels/lookup_table_op.cc
@@ -89,6 +89,16 @@ class MutableHashTableOfScalars final : public LookupInterface {
return DoInsert(false, keys, values);
}
+ Status Remove(OpKernelContext* ctx, const Tensor& keys) override {
+ const auto key_values = keys.flat<K>();
+
+ mutex_lock l(mu_);
+ for (int64 i = 0; i < key_values.size(); ++i) {
+ table_.erase(SubtleMustCopyIfIntegral(key_values(i)));
+ }
+ return Status::OK();
+ }
+
Status ImportValues(OpKernelContext* ctx, const Tensor& keys,
const Tensor& values) override {
return DoInsert(true, keys, values);
@@ -212,6 +222,16 @@ class MutableHashTableOfTensors final : public LookupInterface {
return DoInsert(false, keys, values);
}
+ Status Remove(OpKernelContext* ctx, const Tensor& keys) override {
+ const auto key_values = keys.flat<K>();
+
+ mutex_lock l(mu_);
+ for (int64 i = 0; i < key_values.size(); ++i) {
+ table_.erase(SubtleMustCopyIfIntegral(key_values(i)));
+ }
+ return Status::OK();
+ }
+
Status ImportValues(OpKernelContext* ctx, const Tensor& keys,
const Tensor& values) override {
return DoInsert(true, keys, values);
@@ -326,6 +346,29 @@ class MutableDenseHashTable final : public LookupInterface {
empty_key_input->template shaped<K, 2>({1, key_shape_.num_elements()}),
0);
+ const Tensor* deleted_key_input;
+ OP_REQUIRES_OK(ctx, ctx->input("deleted_key", &deleted_key_input));
+ OP_REQUIRES(ctx, key_shape_.IsSameSize(deleted_key_input->shape()),
+ errors::InvalidArgument(
+ "Empty and deleted keys must have same shape, got shapes: ",
+ key_shape_.DebugString(), " and ",
+ deleted_key_input->shape().DebugString()));
+ deleted_key_ = PersistentTensor(*deleted_key_input);
+ deleted_key_hash_ = HashKey(deleted_key_input->template shaped<K, 2>(
+ {1, key_shape_.num_elements()}),
+ 0);
+
+ if (empty_key_hash_ == deleted_key_hash_) {
+ const int64 key_size = key_shape_.num_elements();
+ const auto empty_key_matrix =
+ empty_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size});
+ const auto deleted_key_matrix =
+ deleted_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size});
+ OP_REQUIRES(
+ ctx, !IsEqualKey(empty_key_matrix, 0, deleted_key_matrix, 0),
+ errors::InvalidArgument("Empty and deleted keys cannot be equal"));
+ }
+
int64 initial_num_buckets;
OP_REQUIRES_OK(ctx, GetNodeAttr(kernel->def(), "initial_num_buckets",
&initial_num_buckets));
@@ -360,6 +403,8 @@ class MutableDenseHashTable final : public LookupInterface {
value_buckets_.AccessTensor(ctx)->template matrix<V>();
const auto empty_key_matrix =
empty_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size});
+ const auto deleted_key_matrix =
+ deleted_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size});
const int64 bit_mask = num_buckets_ - 1;
// TODO(andreasst): parallelize using work_sharder
for (int64 i = 0; i < num_elements; ++i) {
@@ -369,6 +414,11 @@ class MutableDenseHashTable final : public LookupInterface {
return errors::InvalidArgument(
"Using the empty_key as a table key is not allowed");
}
+ if (deleted_key_hash_ == key_hash &&
+ IsEqualKey(deleted_key_matrix, 0, key_matrix, i)) {
+ return errors::InvalidArgument(
+ "Using the deleted_key as a table key is not allowed");
+ }
int64 bucket_index = key_hash & bit_mask;
int64 num_probes = 0;
while (true) {
@@ -425,23 +475,40 @@ class MutableDenseHashTable final : public LookupInterface {
return DoInsert(ctx, key, value, false);
}
+ Status Remove(OpKernelContext* ctx, const Tensor& key) override
+ LOCKS_EXCLUDED(mu_) {
+ if (key.NumElements() != key.dim_size(0) * key_shape_.num_elements()) {
+ TensorShape expected_shape({key.dim_size(0)});
+ expected_shape.AppendShape(key_shape_);
+ return errors::InvalidArgument("Expected key shape ",
+ expected_shape.DebugString(), " got ",
+ key.shape().DebugString());
+ }
+ mutex_lock l(mu_);
+ return DoRemove(ctx, key);
+ }
+
Status ImportValues(OpKernelContext* ctx, const Tensor& keys,
const Tensor& values) override LOCKS_EXCLUDED(mu_) {
mutex_lock l(mu_);
num_buckets_ = keys.dim_size(0);
key_buckets_ = PersistentTensor(keys);
value_buckets_ = PersistentTensor(values);
- // Count the number of keys that are not the empty_key. This requires
- // iterating through the whole table but that is OK as we only execute it
- // during checkpoint restore.
+ // Count the number of keys that are not the empty_key or deleted_key.
+ // This requires iterating through the whole table but that is OK as we
+ // only execute it during checkpoint restore.
num_entries_ = 0;
const auto empty_key_tensor =
empty_key_.AccessTensor(ctx)->template shaped<K, 2>(
{1, key_shape_.num_elements()});
+ const auto deleted_key_tensor =
+ deleted_key_.AccessTensor(ctx)->template shaped<K, 2>(
+ {1, key_shape_.num_elements()});
const auto key_buckets_tensor =
key_buckets_.AccessTensor(ctx)->template matrix<K>();
for (int64 i = 0; i < num_buckets_; ++i) {
- if (!IsEqualKey(key_buckets_tensor, i, empty_key_tensor, 0)) {
+ if (!IsEqualKey(key_buckets_tensor, i, empty_key_tensor, 0) &&
+ !IsEqualKey(key_buckets_tensor, i, deleted_key_tensor, 0)) {
++num_entries_;
}
}
@@ -498,7 +565,8 @@ class MutableDenseHashTable final : public LookupInterface {
private:
Status DoInsert(OpKernelContext* ctx, const Tensor& key, const Tensor& value,
- bool ignore_empty_key) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ bool ignore_empty_and_deleted_key)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
const int64 num_elements = (key.dims() == 0) ? 1 : key.dim_size(0);
const int64 value_size = value_shape_.num_elements();
const int64 key_size = key_shape_.num_elements();
@@ -511,17 +579,27 @@ class MutableDenseHashTable final : public LookupInterface {
value_buckets_.AccessTensor(ctx)->template matrix<V>();
const auto empty_key_tensor =
empty_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size});
+ const auto deleted_key_tensor =
+ deleted_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size});
const int64 bit_mask = num_buckets_ - 1;
for (int64 i = 0; i < num_elements; ++i) {
const uint64 key_hash = HashKey(key_matrix, i);
if (empty_key_hash_ == key_hash &&
IsEqualKey(empty_key_tensor, 0, key_matrix, i)) {
- if (ignore_empty_key) {
+ if (ignore_empty_and_deleted_key) {
continue;
}
return errors::InvalidArgument(
"Using the empty_key as a table key is not allowed");
}
+ if (deleted_key_hash_ == key_hash &&
+ IsEqualKey(deleted_key_tensor, 0, key_matrix, i)) {
+ if (ignore_empty_and_deleted_key) {
+ continue;
+ }
+ return errors::InvalidArgument(
+ "Using the deleted_key as a table key is not allowed");
+ }
int64 bucket_index = key_hash & bit_mask;
int64 num_probes = 0;
while (true) {
@@ -532,7 +610,9 @@ class MutableDenseHashTable final : public LookupInterface {
}
break;
}
- if (IsEqualKey(key_buckets_matrix, bucket_index, empty_key_tensor, 0)) {
+ if (IsEqualKey(key_buckets_matrix, bucket_index, empty_key_tensor, 0) ||
+ IsEqualKey(key_buckets_matrix, bucket_index, deleted_key_tensor,
+ 0)) {
++num_entries_;
for (int64 j = 0; j < key_size; ++j) {
key_buckets_matrix(bucket_index, j) =
@@ -556,6 +636,59 @@ class MutableDenseHashTable final : public LookupInterface {
return Status::OK();
}
+ Status DoRemove(OpKernelContext* ctx, const Tensor& key)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ const int64 num_elements = key.dim_size(0);
+ const int64 key_size = key_shape_.num_elements();
+ const auto key_matrix = key.shaped<K, 2>({num_elements, key_size});
+
+ auto key_buckets_matrix =
+ key_buckets_.AccessTensor(ctx)->template matrix<K>();
+ const auto empty_key_tensor =
+ empty_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size});
+ const auto deleted_key_tensor =
+ deleted_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size});
+ const auto deleted_key_flat =
+ deleted_key_.AccessTensor(ctx)->template flat<K>();
+ const int64 bit_mask = num_buckets_ - 1;
+ for (int64 i = 0; i < num_elements; ++i) {
+ const uint64 key_hash = HashKey(key_matrix, i);
+ if (empty_key_hash_ == key_hash &&
+ IsEqualKey(empty_key_tensor, 0, key_matrix, i)) {
+ return errors::InvalidArgument(
+ "Using the empty_key as a table key is not allowed");
+ }
+ if (deleted_key_hash_ == key_hash &&
+ IsEqualKey(deleted_key_tensor, 0, key_matrix, i)) {
+ return errors::InvalidArgument(
+ "Using the deleted_key as a table key is not allowed");
+ }
+ int64 bucket_index = key_hash & bit_mask;
+ int64 num_probes = 0;
+ while (true) {
+ if (IsEqualKey(key_buckets_matrix, bucket_index, key_matrix, i)) {
+ --num_entries_;
+ for (int64 j = 0; j < key_size; ++j) {
+ key_buckets_matrix(bucket_index, j) =
+ SubtleMustCopyIfIntegral(deleted_key_flat(j));
+ }
+ break;
+ }
+ if (IsEqualKey(key_buckets_matrix, bucket_index, empty_key_tensor, 0)) {
+ break;
+ }
+ ++num_probes;
+ bucket_index =
+ (bucket_index + num_probes) & bit_mask; // quadratic probing
+ if (num_probes >= num_buckets_) {
+ return errors::Internal(
+ "Internal error in MutableDenseHashTable remove");
+ }
+ }
+ }
+ return Status::OK();
+ }
+
Status AllocateBuckets(OpKernelContext* ctx, int64 new_num_buckets)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (new_num_buckets < 4 ||
@@ -639,7 +772,9 @@ class MutableDenseHashTable final : public LookupInterface {
PersistentTensor value_buckets_ GUARDED_BY(mu_);
PersistentTensor empty_key_;
uint64 empty_key_hash_;
-};
+ PersistentTensor deleted_key_;
+ uint64 deleted_key_hash_;
+}; // namespace lookup
} // namespace lookup
@@ -717,6 +852,39 @@ REGISTER_KERNEL_BUILDER(Name("LookupTableInsert").Device(DEVICE_CPU),
REGISTER_KERNEL_BUILDER(Name("LookupTableInsertV2").Device(DEVICE_CPU),
LookupTableInsertOp);
+// Table remove op.
+class LookupTableRemoveOp : public OpKernel {
+ public:
+ explicit LookupTableRemoveOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ lookup::LookupInterface* table;
+ OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table));
+ core::ScopedUnref unref_me(table);
+
+ DataType expected_input_0 =
+ (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF;
+ DataTypeVector expected_inputs = {expected_input_0, table->key_dtype()};
+ OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {}));
+
+ const Tensor& key = ctx->input(1);
+ OP_REQUIRES_OK(ctx, table->CheckKeyTensorForRemove(key));
+
+ int64 memory_used_before = 0;
+ if (ctx->track_allocations()) {
+ memory_used_before = table->MemoryUsed();
+ }
+ OP_REQUIRES_OK(ctx, table->Remove(ctx, key));
+ if (ctx->track_allocations()) {
+ ctx->record_persistent_memory_allocation(table->MemoryUsed() -
+ memory_used_before);
+ }
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("LookupTableRemoveV2").Device(DEVICE_CPU),
+ LookupTableRemoveOp);
+
// Op that returns the size of the given table.
class LookupTableSizeOp : public OpKernel {
public: