diff options
Diffstat (limited to 'tensorflow/core/kernels')
-rw-r--r-- | tensorflow/core/kernels/initializable_lookup_table.h | 6 | ||||
-rw-r--r-- | tensorflow/core/kernels/lookup_table_op.cc | 184 |
2 files changed, 182 insertions, 8 deletions
diff --git a/tensorflow/core/kernels/initializable_lookup_table.h b/tensorflow/core/kernels/initializable_lookup_table.h index 424fe5df3c..a14d4967a5 100644 --- a/tensorflow/core/kernels/initializable_lookup_table.h +++ b/tensorflow/core/kernels/initializable_lookup_table.h @@ -51,6 +51,12 @@ class InitializableLookupTable : public LookupInterface { "Insert not supported by InitializableLookupTable implementations"); } + // Returns errors::Unimplemented. + Status Remove(OpKernelContext* ctx, const Tensor& keys) final { + return errors::Unimplemented( + "Remove not supported by InitializableLookupTable implementations"); + } + Status ExportValues(OpKernelContext* context) override { return errors::Unimplemented( "ExportValues not supported by InitializableLookupTable " diff --git a/tensorflow/core/kernels/lookup_table_op.cc b/tensorflow/core/kernels/lookup_table_op.cc index a495758861..0bc1ea77d6 100644 --- a/tensorflow/core/kernels/lookup_table_op.cc +++ b/tensorflow/core/kernels/lookup_table_op.cc @@ -89,6 +89,16 @@ class MutableHashTableOfScalars final : public LookupInterface { return DoInsert(false, keys, values); } + Status Remove(OpKernelContext* ctx, const Tensor& keys) override { + const auto key_values = keys.flat<K>(); + + mutex_lock l(mu_); + for (int64 i = 0; i < key_values.size(); ++i) { + table_.erase(SubtleMustCopyIfIntegral(key_values(i))); + } + return Status::OK(); + } + Status ImportValues(OpKernelContext* ctx, const Tensor& keys, const Tensor& values) override { return DoInsert(true, keys, values); @@ -212,6 +222,16 @@ class MutableHashTableOfTensors final : public LookupInterface { return DoInsert(false, keys, values); } + Status Remove(OpKernelContext* ctx, const Tensor& keys) override { + const auto key_values = keys.flat<K>(); + + mutex_lock l(mu_); + for (int64 i = 0; i < key_values.size(); ++i) { + table_.erase(SubtleMustCopyIfIntegral(key_values(i))); + } + return Status::OK(); + } + Status ImportValues(OpKernelContext* ctx, const Tensor& keys, const Tensor& values) override { return DoInsert(true, keys, values); @@ -326,6 +346,29 @@ class MutableDenseHashTable final : public LookupInterface { empty_key_input->template shaped<K, 2>({1, key_shape_.num_elements()}), 0); + const Tensor* deleted_key_input; + OP_REQUIRES_OK(ctx, ctx->input("deleted_key", &deleted_key_input)); + OP_REQUIRES(ctx, key_shape_.IsSameSize(deleted_key_input->shape()), + errors::InvalidArgument( + "Empty and deleted keys must have same shape, got shapes: ", + key_shape_.DebugString(), " and ", + deleted_key_input->shape().DebugString())); + deleted_key_ = PersistentTensor(*deleted_key_input); + deleted_key_hash_ = HashKey(deleted_key_input->template shaped<K, 2>( + {1, key_shape_.num_elements()}), + 0); + + if (empty_key_hash_ == deleted_key_hash_) { + const int64 key_size = key_shape_.num_elements(); + const auto empty_key_matrix = + empty_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size}); + const auto deleted_key_matrix = + deleted_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size}); + OP_REQUIRES( + ctx, !IsEqualKey(empty_key_matrix, 0, deleted_key_matrix, 0), + errors::InvalidArgument("Empty and deleted keys cannot be equal")); + } + int64 initial_num_buckets; OP_REQUIRES_OK(ctx, GetNodeAttr(kernel->def(), "initial_num_buckets", &initial_num_buckets)); @@ -360,6 +403,8 @@ class MutableDenseHashTable final : public LookupInterface { value_buckets_.AccessTensor(ctx)->template matrix<V>(); const auto empty_key_matrix = empty_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size}); + const auto deleted_key_matrix = + deleted_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size}); const int64 bit_mask = num_buckets_ - 1; // TODO(andreasst): parallelize using work_sharder for (int64 i = 0; i < num_elements; ++i) { @@ -369,6 +414,11 @@ class MutableDenseHashTable final : public LookupInterface { return errors::InvalidArgument( "Using the empty_key as a table key is not allowed"); } + if (deleted_key_hash_ == key_hash && + IsEqualKey(deleted_key_matrix, 0, key_matrix, i)) { + return errors::InvalidArgument( + "Using the deleted_key as a table key is not allowed"); + } int64 bucket_index = key_hash & bit_mask; int64 num_probes = 0; while (true) { @@ -425,23 +475,40 @@ class MutableDenseHashTable final : public LookupInterface { return DoInsert(ctx, key, value, false); } + Status Remove(OpKernelContext* ctx, const Tensor& key) override + LOCKS_EXCLUDED(mu_) { + if (key.NumElements() != key.dim_size(0) * key_shape_.num_elements()) { + TensorShape expected_shape({key.dim_size(0)}); + expected_shape.AppendShape(key_shape_); + return errors::InvalidArgument("Expected key shape ", + expected_shape.DebugString(), " got ", + key.shape().DebugString()); + } + mutex_lock l(mu_); + return DoRemove(ctx, key); + } + Status ImportValues(OpKernelContext* ctx, const Tensor& keys, const Tensor& values) override LOCKS_EXCLUDED(mu_) { mutex_lock l(mu_); num_buckets_ = keys.dim_size(0); key_buckets_ = PersistentTensor(keys); value_buckets_ = PersistentTensor(values); - // Count the number of keys that are not the empty_key. This requires - // iterating through the whole table but that is OK as we only execute it - // during checkpoint restore. + // Count the number of keys that are not the empty_key or deleted_key. + // This requires iterating through the whole table but that is OK as we + // only execute it during checkpoint restore. num_entries_ = 0; const auto empty_key_tensor = empty_key_.AccessTensor(ctx)->template shaped<K, 2>( {1, key_shape_.num_elements()}); + const auto deleted_key_tensor = + deleted_key_.AccessTensor(ctx)->template shaped<K, 2>( + {1, key_shape_.num_elements()}); const auto key_buckets_tensor = key_buckets_.AccessTensor(ctx)->template matrix<K>(); for (int64 i = 0; i < num_buckets_; ++i) { - if (!IsEqualKey(key_buckets_tensor, i, empty_key_tensor, 0)) { + if (!IsEqualKey(key_buckets_tensor, i, empty_key_tensor, 0) && + !IsEqualKey(key_buckets_tensor, i, deleted_key_tensor, 0)) { ++num_entries_; } } @@ -498,7 +565,8 @@ class MutableDenseHashTable final : public LookupInterface { private: Status DoInsert(OpKernelContext* ctx, const Tensor& key, const Tensor& value, - bool ignore_empty_key) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + bool ignore_empty_and_deleted_key) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { const int64 num_elements = (key.dims() == 0) ? 1 : key.dim_size(0); const int64 value_size = value_shape_.num_elements(); const int64 key_size = key_shape_.num_elements(); @@ -511,17 +579,27 @@ class MutableDenseHashTable final : public LookupInterface { value_buckets_.AccessTensor(ctx)->template matrix<V>(); const auto empty_key_tensor = empty_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size}); + const auto deleted_key_tensor = + deleted_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size}); const int64 bit_mask = num_buckets_ - 1; for (int64 i = 0; i < num_elements; ++i) { const uint64 key_hash = HashKey(key_matrix, i); if (empty_key_hash_ == key_hash && IsEqualKey(empty_key_tensor, 0, key_matrix, i)) { - if (ignore_empty_key) { + if (ignore_empty_and_deleted_key) { continue; } return errors::InvalidArgument( "Using the empty_key as a table key is not allowed"); } + if (deleted_key_hash_ == key_hash && + IsEqualKey(deleted_key_tensor, 0, key_matrix, i)) { + if (ignore_empty_and_deleted_key) { + continue; + } + return errors::InvalidArgument( + "Using the deleted_key as a table key is not allowed"); + } int64 bucket_index = key_hash & bit_mask; int64 num_probes = 0; while (true) { @@ -532,7 +610,9 @@ class MutableDenseHashTable final : public LookupInterface { } break; } - if (IsEqualKey(key_buckets_matrix, bucket_index, empty_key_tensor, 0)) { + if (IsEqualKey(key_buckets_matrix, bucket_index, empty_key_tensor, 0) || + IsEqualKey(key_buckets_matrix, bucket_index, deleted_key_tensor, + 0)) { ++num_entries_; for (int64 j = 0; j < key_size; ++j) { key_buckets_matrix(bucket_index, j) = @@ -556,6 +636,59 @@ class MutableDenseHashTable final : public LookupInterface { return Status::OK(); } + Status DoRemove(OpKernelContext* ctx, const Tensor& key) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + const int64 num_elements = key.dim_size(0); + const int64 key_size = key_shape_.num_elements(); + const auto key_matrix = key.shaped<K, 2>({num_elements, key_size}); + + auto key_buckets_matrix = + key_buckets_.AccessTensor(ctx)->template matrix<K>(); + const auto empty_key_tensor = + empty_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size}); + const auto deleted_key_tensor = + deleted_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size}); + const auto deleted_key_flat = + deleted_key_.AccessTensor(ctx)->template flat<K>(); + const int64 bit_mask = num_buckets_ - 1; + for (int64 i = 0; i < num_elements; ++i) { + const uint64 key_hash = HashKey(key_matrix, i); + if (empty_key_hash_ == key_hash && + IsEqualKey(empty_key_tensor, 0, key_matrix, i)) { + return errors::InvalidArgument( + "Using the empty_key as a table key is not allowed"); + } + if (deleted_key_hash_ == key_hash && + IsEqualKey(deleted_key_tensor, 0, key_matrix, i)) { + return errors::InvalidArgument( + "Using the deleted_key as a table key is not allowed"); + } + int64 bucket_index = key_hash & bit_mask; + int64 num_probes = 0; + while (true) { + if (IsEqualKey(key_buckets_matrix, bucket_index, key_matrix, i)) { + --num_entries_; + for (int64 j = 0; j < key_size; ++j) { + key_buckets_matrix(bucket_index, j) = + SubtleMustCopyIfIntegral(deleted_key_flat(j)); + } + break; + } + if (IsEqualKey(key_buckets_matrix, bucket_index, empty_key_tensor, 0)) { + break; + } + ++num_probes; + bucket_index = + (bucket_index + num_probes) & bit_mask; // quadratic probing + if (num_probes >= num_buckets_) { + return errors::Internal( + "Internal error in MutableDenseHashTable remove"); + } + } + } + return Status::OK(); + } + Status AllocateBuckets(OpKernelContext* ctx, int64 new_num_buckets) EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (new_num_buckets < 4 || @@ -639,7 +772,9 @@ class MutableDenseHashTable final : public LookupInterface { PersistentTensor value_buckets_ GUARDED_BY(mu_); PersistentTensor empty_key_; uint64 empty_key_hash_; -}; + PersistentTensor deleted_key_; + uint64 deleted_key_hash_; +}; // namespace lookup } // namespace lookup @@ -717,6 +852,39 @@ REGISTER_KERNEL_BUILDER(Name("LookupTableInsert").Device(DEVICE_CPU), REGISTER_KERNEL_BUILDER(Name("LookupTableInsertV2").Device(DEVICE_CPU), LookupTableInsertOp); +// Table remove op. +class LookupTableRemoveOp : public OpKernel { + public: + explicit LookupTableRemoveOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + lookup::LookupInterface* table; + OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table)); + core::ScopedUnref unref_me(table); + + DataType expected_input_0 = + (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF; + DataTypeVector expected_inputs = {expected_input_0, table->key_dtype()}; + OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {})); + + const Tensor& key = ctx->input(1); + OP_REQUIRES_OK(ctx, table->CheckKeyTensorForRemove(key)); + + int64 memory_used_before = 0; + if (ctx->track_allocations()) { + memory_used_before = table->MemoryUsed(); + } + OP_REQUIRES_OK(ctx, table->Remove(ctx, key)); + if (ctx->track_allocations()) { + ctx->record_persistent_memory_allocation(table->MemoryUsed() - + memory_used_before); + } + } +}; + +REGISTER_KERNEL_BUILDER(Name("LookupTableRemoveV2").Device(DEVICE_CPU), + LookupTableRemoveOp); + // Op that returns the size of the given table. class LookupTableSizeOp : public OpKernel { public: |