diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-10-09 16:23:35 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-09 16:39:23 -0700 |
commit | 6c391166b8b6ba43d2b0151e6fb9cf14864131a2 (patch) | |
tree | c8c3c9eadade00f1a4e6cec2024e2a15bfd0b948 /tensorflow/core | |
parent | 2f5ebc0ea5e6d500ea8cd925234c569d6b32fd4e (diff) |
Add 'remove' operation to MutableHashTable and MutableDenseHashTable.
PiperOrigin-RevId: 216443201
Diffstat (limited to 'tensorflow/core')
-rw-r--r-- | tensorflow/core/api_def/base_api/api_def_LookupTableRemoveV2.pbtxt | 24 | ||||
-rw-r--r-- | tensorflow/core/framework/lookup_interface.cc | 8 | ||||
-rw-r--r-- | tensorflow/core/framework/lookup_interface.h | 17 | ||||
-rw-r--r-- | tensorflow/core/kernels/initializable_lookup_table.h | 6 | ||||
-rw-r--r-- | tensorflow/core/kernels/lookup_table_op.cc | 184 | ||||
-rw-r--r-- | tensorflow/core/ops/compat/ops_history.v1.pbtxt | 20 | ||||
-rw-r--r-- | tensorflow/core/ops/lookup_ops.cc | 14 |
7 files changed, 265 insertions, 8 deletions
diff --git a/tensorflow/core/api_def/base_api/api_def_LookupTableRemoveV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_LookupTableRemoveV2.pbtxt new file mode 100644 index 0000000000..333fe6f4b2 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_LookupTableRemoveV2.pbtxt @@ -0,0 +1,24 @@ +op { + graph_op_name: "LookupTableRemoveV2" + visibility: HIDDEN + endpoint { + name: "LookupTableRemove" + } + in_arg { + name: "table_handle" + description: <<END +Handle to the table. +END + } + in_arg { + name: "keys" + description: <<END +Any shape. Keys of the elements to remove. +END + } + summary: "Removes keys and its associated values from a table." + description: <<END +The tensor `keys` must of the same type as the keys of the table. Keys not +already in the table are silently ignored. +END +} diff --git a/tensorflow/core/framework/lookup_interface.cc b/tensorflow/core/framework/lookup_interface.cc index bf3204ea6e..117adbf65c 100644 --- a/tensorflow/core/framework/lookup_interface.cc +++ b/tensorflow/core/framework/lookup_interface.cc @@ -71,6 +71,14 @@ Status LookupInterface::CheckKeyAndValueTensorsForImport(const Tensor& keys, return CheckKeyAndValueTensorsHelper(keys, values); } +Status LookupInterface::CheckKeyTensorForRemove(const Tensor& keys) { + if (keys.dtype() != key_dtype()) { + return errors::InvalidArgument("Key must be type ", key_dtype(), + " but got ", keys.dtype()); + } + return CheckKeyShape(keys.shape()); +} + Status LookupInterface::CheckFindArguments(const Tensor& key, const Tensor& default_value) { TF_RETURN_IF_ERROR(CheckKeyAndValueTypes(key, default_value)); diff --git a/tensorflow/core/framework/lookup_interface.h b/tensorflow/core/framework/lookup_interface.h index 0622dd06cb..d33945fd1b 100644 --- a/tensorflow/core/framework/lookup_interface.h +++ b/tensorflow/core/framework/lookup_interface.h @@ -64,6 +64,17 @@ class LookupInterface : public ResourceBase { virtual Status Insert(OpKernelContext* ctx, const Tensor& keys, const Tensor& values) = 0; + // Removes elements from the table. + // This method is only implemented in mutable tables that can be updated over + // the execution of the graph. It returns Status::NotImplemented for read-only + // tables that are initialized once before they can be looked up. + + // Returns the following statuses: + // - OK: when the remove finishes successfully. + // - InvalidArgument: if any of the preconditions on the lookup key fails. + // - Unimplemented: if the table does not support removals. + virtual Status Remove(OpKernelContext* ctx, const Tensor& keys) = 0; + // Returns the number of elements in the table. virtual size_t size() const = 0; @@ -107,6 +118,12 @@ class LookupInterface : public ResourceBase { virtual Status CheckKeyAndValueTensorsForImport(const Tensor& keys, const Tensor& values); + // Check format of the key tensor for the Remove function. + // Returns OK if all the following requirements are satisfied, otherwise it + // returns InvalidArgument: + // - DataType of the tensor keys equals to the table key_dtype + virtual Status CheckKeyTensorForRemove(const Tensor& keys); + // Check the arguments of a find operation. Returns OK if all the following // requirements are satisfied, otherwise it returns InvalidArgument: // - DataType of the tensor keys equals to the table key_dtype diff --git a/tensorflow/core/kernels/initializable_lookup_table.h b/tensorflow/core/kernels/initializable_lookup_table.h index 424fe5df3c..a14d4967a5 100644 --- a/tensorflow/core/kernels/initializable_lookup_table.h +++ b/tensorflow/core/kernels/initializable_lookup_table.h @@ -51,6 +51,12 @@ class InitializableLookupTable : public LookupInterface { "Insert not supported by InitializableLookupTable implementations"); } + // Returns errors::Unimplemented. + Status Remove(OpKernelContext* ctx, const Tensor& keys) final { + return errors::Unimplemented( + "Remove not supported by InitializableLookupTable implementations"); + } + Status ExportValues(OpKernelContext* context) override { return errors::Unimplemented( "ExportValues not supported by InitializableLookupTable " diff --git a/tensorflow/core/kernels/lookup_table_op.cc b/tensorflow/core/kernels/lookup_table_op.cc index a495758861..0bc1ea77d6 100644 --- a/tensorflow/core/kernels/lookup_table_op.cc +++ b/tensorflow/core/kernels/lookup_table_op.cc @@ -89,6 +89,16 @@ class MutableHashTableOfScalars final : public LookupInterface { return DoInsert(false, keys, values); } + Status Remove(OpKernelContext* ctx, const Tensor& keys) override { + const auto key_values = keys.flat<K>(); + + mutex_lock l(mu_); + for (int64 i = 0; i < key_values.size(); ++i) { + table_.erase(SubtleMustCopyIfIntegral(key_values(i))); + } + return Status::OK(); + } + Status ImportValues(OpKernelContext* ctx, const Tensor& keys, const Tensor& values) override { return DoInsert(true, keys, values); @@ -212,6 +222,16 @@ class MutableHashTableOfTensors final : public LookupInterface { return DoInsert(false, keys, values); } + Status Remove(OpKernelContext* ctx, const Tensor& keys) override { + const auto key_values = keys.flat<K>(); + + mutex_lock l(mu_); + for (int64 i = 0; i < key_values.size(); ++i) { + table_.erase(SubtleMustCopyIfIntegral(key_values(i))); + } + return Status::OK(); + } + Status ImportValues(OpKernelContext* ctx, const Tensor& keys, const Tensor& values) override { return DoInsert(true, keys, values); @@ -326,6 +346,29 @@ class MutableDenseHashTable final : public LookupInterface { empty_key_input->template shaped<K, 2>({1, key_shape_.num_elements()}), 0); + const Tensor* deleted_key_input; + OP_REQUIRES_OK(ctx, ctx->input("deleted_key", &deleted_key_input)); + OP_REQUIRES(ctx, key_shape_.IsSameSize(deleted_key_input->shape()), + errors::InvalidArgument( + "Empty and deleted keys must have same shape, got shapes: ", + key_shape_.DebugString(), " and ", + deleted_key_input->shape().DebugString())); + deleted_key_ = PersistentTensor(*deleted_key_input); + deleted_key_hash_ = HashKey(deleted_key_input->template shaped<K, 2>( + {1, key_shape_.num_elements()}), + 0); + + if (empty_key_hash_ == deleted_key_hash_) { + const int64 key_size = key_shape_.num_elements(); + const auto empty_key_matrix = + empty_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size}); + const auto deleted_key_matrix = + deleted_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size}); + OP_REQUIRES( + ctx, !IsEqualKey(empty_key_matrix, 0, deleted_key_matrix, 0), + errors::InvalidArgument("Empty and deleted keys cannot be equal")); + } + int64 initial_num_buckets; OP_REQUIRES_OK(ctx, GetNodeAttr(kernel->def(), "initial_num_buckets", &initial_num_buckets)); @@ -360,6 +403,8 @@ class MutableDenseHashTable final : public LookupInterface { value_buckets_.AccessTensor(ctx)->template matrix<V>(); const auto empty_key_matrix = empty_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size}); + const auto deleted_key_matrix = + deleted_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size}); const int64 bit_mask = num_buckets_ - 1; // TODO(andreasst): parallelize using work_sharder for (int64 i = 0; i < num_elements; ++i) { @@ -369,6 +414,11 @@ class MutableDenseHashTable final : public LookupInterface { return errors::InvalidArgument( "Using the empty_key as a table key is not allowed"); } + if (deleted_key_hash_ == key_hash && + IsEqualKey(deleted_key_matrix, 0, key_matrix, i)) { + return errors::InvalidArgument( + "Using the deleted_key as a table key is not allowed"); + } int64 bucket_index = key_hash & bit_mask; int64 num_probes = 0; while (true) { @@ -425,23 +475,40 @@ class MutableDenseHashTable final : public LookupInterface { return DoInsert(ctx, key, value, false); } + Status Remove(OpKernelContext* ctx, const Tensor& key) override + LOCKS_EXCLUDED(mu_) { + if (key.NumElements() != key.dim_size(0) * key_shape_.num_elements()) { + TensorShape expected_shape({key.dim_size(0)}); + expected_shape.AppendShape(key_shape_); + return errors::InvalidArgument("Expected key shape ", + expected_shape.DebugString(), " got ", + key.shape().DebugString()); + } + mutex_lock l(mu_); + return DoRemove(ctx, key); + } + Status ImportValues(OpKernelContext* ctx, const Tensor& keys, const Tensor& values) override LOCKS_EXCLUDED(mu_) { mutex_lock l(mu_); num_buckets_ = keys.dim_size(0); key_buckets_ = PersistentTensor(keys); value_buckets_ = PersistentTensor(values); - // Count the number of keys that are not the empty_key. This requires - // iterating through the whole table but that is OK as we only execute it - // during checkpoint restore. + // Count the number of keys that are not the empty_key or deleted_key. + // This requires iterating through the whole table but that is OK as we + // only execute it during checkpoint restore. num_entries_ = 0; const auto empty_key_tensor = empty_key_.AccessTensor(ctx)->template shaped<K, 2>( {1, key_shape_.num_elements()}); + const auto deleted_key_tensor = + deleted_key_.AccessTensor(ctx)->template shaped<K, 2>( + {1, key_shape_.num_elements()}); const auto key_buckets_tensor = key_buckets_.AccessTensor(ctx)->template matrix<K>(); for (int64 i = 0; i < num_buckets_; ++i) { - if (!IsEqualKey(key_buckets_tensor, i, empty_key_tensor, 0)) { + if (!IsEqualKey(key_buckets_tensor, i, empty_key_tensor, 0) && + !IsEqualKey(key_buckets_tensor, i, deleted_key_tensor, 0)) { ++num_entries_; } } @@ -498,7 +565,8 @@ class MutableDenseHashTable final : public LookupInterface { private: Status DoInsert(OpKernelContext* ctx, const Tensor& key, const Tensor& value, - bool ignore_empty_key) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + bool ignore_empty_and_deleted_key) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { const int64 num_elements = (key.dims() == 0) ? 1 : key.dim_size(0); const int64 value_size = value_shape_.num_elements(); const int64 key_size = key_shape_.num_elements(); @@ -511,17 +579,27 @@ class MutableDenseHashTable final : public LookupInterface { value_buckets_.AccessTensor(ctx)->template matrix<V>(); const auto empty_key_tensor = empty_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size}); + const auto deleted_key_tensor = + deleted_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size}); const int64 bit_mask = num_buckets_ - 1; for (int64 i = 0; i < num_elements; ++i) { const uint64 key_hash = HashKey(key_matrix, i); if (empty_key_hash_ == key_hash && IsEqualKey(empty_key_tensor, 0, key_matrix, i)) { - if (ignore_empty_key) { + if (ignore_empty_and_deleted_key) { continue; } return errors::InvalidArgument( "Using the empty_key as a table key is not allowed"); } + if (deleted_key_hash_ == key_hash && + IsEqualKey(deleted_key_tensor, 0, key_matrix, i)) { + if (ignore_empty_and_deleted_key) { + continue; + } + return errors::InvalidArgument( + "Using the deleted_key as a table key is not allowed"); + } int64 bucket_index = key_hash & bit_mask; int64 num_probes = 0; while (true) { @@ -532,7 +610,9 @@ class MutableDenseHashTable final : public LookupInterface { } break; } - if (IsEqualKey(key_buckets_matrix, bucket_index, empty_key_tensor, 0)) { + if (IsEqualKey(key_buckets_matrix, bucket_index, empty_key_tensor, 0) || + IsEqualKey(key_buckets_matrix, bucket_index, deleted_key_tensor, + 0)) { ++num_entries_; for (int64 j = 0; j < key_size; ++j) { key_buckets_matrix(bucket_index, j) = @@ -556,6 +636,59 @@ class MutableDenseHashTable final : public LookupInterface { return Status::OK(); } + Status DoRemove(OpKernelContext* ctx, const Tensor& key) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + const int64 num_elements = key.dim_size(0); + const int64 key_size = key_shape_.num_elements(); + const auto key_matrix = key.shaped<K, 2>({num_elements, key_size}); + + auto key_buckets_matrix = + key_buckets_.AccessTensor(ctx)->template matrix<K>(); + const auto empty_key_tensor = + empty_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size}); + const auto deleted_key_tensor = + deleted_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size}); + const auto deleted_key_flat = + deleted_key_.AccessTensor(ctx)->template flat<K>(); + const int64 bit_mask = num_buckets_ - 1; + for (int64 i = 0; i < num_elements; ++i) { + const uint64 key_hash = HashKey(key_matrix, i); + if (empty_key_hash_ == key_hash && + IsEqualKey(empty_key_tensor, 0, key_matrix, i)) { + return errors::InvalidArgument( + "Using the empty_key as a table key is not allowed"); + } + if (deleted_key_hash_ == key_hash && + IsEqualKey(deleted_key_tensor, 0, key_matrix, i)) { + return errors::InvalidArgument( + "Using the deleted_key as a table key is not allowed"); + } + int64 bucket_index = key_hash & bit_mask; + int64 num_probes = 0; + while (true) { + if (IsEqualKey(key_buckets_matrix, bucket_index, key_matrix, i)) { + --num_entries_; + for (int64 j = 0; j < key_size; ++j) { + key_buckets_matrix(bucket_index, j) = + SubtleMustCopyIfIntegral(deleted_key_flat(j)); + } + break; + } + if (IsEqualKey(key_buckets_matrix, bucket_index, empty_key_tensor, 0)) { + break; + } + ++num_probes; + bucket_index = + (bucket_index + num_probes) & bit_mask; // quadratic probing + if (num_probes >= num_buckets_) { + return errors::Internal( + "Internal error in MutableDenseHashTable remove"); + } + } + } + return Status::OK(); + } + Status AllocateBuckets(OpKernelContext* ctx, int64 new_num_buckets) EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (new_num_buckets < 4 || @@ -639,7 +772,9 @@ class MutableDenseHashTable final : public LookupInterface { PersistentTensor value_buckets_ GUARDED_BY(mu_); PersistentTensor empty_key_; uint64 empty_key_hash_; -}; + PersistentTensor deleted_key_; + uint64 deleted_key_hash_; +}; // namespace lookup } // namespace lookup @@ -717,6 +852,39 @@ REGISTER_KERNEL_BUILDER(Name("LookupTableInsert").Device(DEVICE_CPU), REGISTER_KERNEL_BUILDER(Name("LookupTableInsertV2").Device(DEVICE_CPU), LookupTableInsertOp); +// Table remove op. +class LookupTableRemoveOp : public OpKernel { + public: + explicit LookupTableRemoveOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + lookup::LookupInterface* table; + OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table)); + core::ScopedUnref unref_me(table); + + DataType expected_input_0 = + (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF; + DataTypeVector expected_inputs = {expected_input_0, table->key_dtype()}; + OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {})); + + const Tensor& key = ctx->input(1); + OP_REQUIRES_OK(ctx, table->CheckKeyTensorForRemove(key)); + + int64 memory_used_before = 0; + if (ctx->track_allocations()) { + memory_used_before = table->MemoryUsed(); + } + OP_REQUIRES_OK(ctx, table->Remove(ctx, key)); + if (ctx->track_allocations()) { + ctx->record_persistent_memory_allocation(table->MemoryUsed() - + memory_used_before); + } + } +}; + +REGISTER_KERNEL_BUILDER(Name("LookupTableRemoveV2").Device(DEVICE_CPU), + LookupTableRemoveOp); + // Op that returns the size of the given table. class LookupTableSizeOp : public OpKernel { public: diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt index cfb1055d3c..415e15b720 100644 --- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt +++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt @@ -30321,6 +30321,22 @@ op { is_stateful: true } op { + name: "LookupTableRemoveV2" + input_arg { + name: "table_handle" + type: DT_RESOURCE + } + input_arg { + name: "keys" + type_attr: "Tin" + } + attr { + name: "Tin" + type: "type" + } + is_stateful: true +} +op { name: "LookupTableSize" input_arg { name: "table_handle" @@ -36706,6 +36722,10 @@ op { name: "empty_key" type_attr: "key_dtype" } + input_arg { + name: "deleted_key" + type_attr: "key_dtype" + } output_arg { name: "table_handle" type: DT_RESOURCE diff --git a/tensorflow/core/ops/lookup_ops.cc b/tensorflow/core/ops/lookup_ops.cc index 72a77be70d..a0987cd982 100644 --- a/tensorflow/core/ops/lookup_ops.cc +++ b/tensorflow/core/ops/lookup_ops.cc @@ -214,6 +214,19 @@ REGISTER_OP("LookupTableInsertV2") return Status::OK(); }); +REGISTER_OP("LookupTableRemoveV2") + .Input("table_handle: resource") + .Input("keys: Tin") + .Attr("Tin: type") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &handle)); + + // TODO(turboale): Validate keys shape. + return Status::OK(); + }); + REGISTER_OP("LookupTableSize") .Input("table_handle: Ref(string)") .Output("size: int64") @@ -407,6 +420,7 @@ REGISTER_OP("MutableDenseHashTable") REGISTER_OP("MutableDenseHashTableV2") .Input("empty_key: key_dtype") + .Input("deleted_key: key_dtype") .Output("table_handle: resource") .Attr("container: string = ''") .Attr("shared_name: string = ''") |