aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-10-09 16:23:35 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-09 16:39:23 -0700
commit6c391166b8b6ba43d2b0151e6fb9cf14864131a2 (patch)
treec8c3c9eadade00f1a4e6cec2024e2a15bfd0b948 /tensorflow/core
parent2f5ebc0ea5e6d500ea8cd925234c569d6b32fd4e (diff)
Add 'remove' operation to MutableHashTable and MutableDenseHashTable.
PiperOrigin-RevId: 216443201
Diffstat (limited to 'tensorflow/core')
-rw-r--r--tensorflow/core/api_def/base_api/api_def_LookupTableRemoveV2.pbtxt24
-rw-r--r--tensorflow/core/framework/lookup_interface.cc8
-rw-r--r--tensorflow/core/framework/lookup_interface.h17
-rw-r--r--tensorflow/core/kernels/initializable_lookup_table.h6
-rw-r--r--tensorflow/core/kernels/lookup_table_op.cc184
-rw-r--r--tensorflow/core/ops/compat/ops_history.v1.pbtxt20
-rw-r--r--tensorflow/core/ops/lookup_ops.cc14
7 files changed, 265 insertions, 8 deletions
diff --git a/tensorflow/core/api_def/base_api/api_def_LookupTableRemoveV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_LookupTableRemoveV2.pbtxt
new file mode 100644
index 0000000000..333fe6f4b2
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_LookupTableRemoveV2.pbtxt
@@ -0,0 +1,24 @@
+op {
+ graph_op_name: "LookupTableRemoveV2"
+ visibility: HIDDEN
+ endpoint {
+ name: "LookupTableRemove"
+ }
+ in_arg {
+ name: "table_handle"
+ description: <<END
+Handle to the table.
+END
+ }
+ in_arg {
+ name: "keys"
+ description: <<END
+Any shape. Keys of the elements to remove.
+END
+ }
+ summary: "Removes keys and its associated values from a table."
+ description: <<END
+The tensor `keys` must of the same type as the keys of the table. Keys not
+already in the table are silently ignored.
+END
+}
diff --git a/tensorflow/core/framework/lookup_interface.cc b/tensorflow/core/framework/lookup_interface.cc
index bf3204ea6e..117adbf65c 100644
--- a/tensorflow/core/framework/lookup_interface.cc
+++ b/tensorflow/core/framework/lookup_interface.cc
@@ -71,6 +71,14 @@ Status LookupInterface::CheckKeyAndValueTensorsForImport(const Tensor& keys,
return CheckKeyAndValueTensorsHelper(keys, values);
}
+Status LookupInterface::CheckKeyTensorForRemove(const Tensor& keys) {
+ if (keys.dtype() != key_dtype()) {
+ return errors::InvalidArgument("Key must be type ", key_dtype(),
+ " but got ", keys.dtype());
+ }
+ return CheckKeyShape(keys.shape());
+}
+
Status LookupInterface::CheckFindArguments(const Tensor& key,
const Tensor& default_value) {
TF_RETURN_IF_ERROR(CheckKeyAndValueTypes(key, default_value));
diff --git a/tensorflow/core/framework/lookup_interface.h b/tensorflow/core/framework/lookup_interface.h
index 0622dd06cb..d33945fd1b 100644
--- a/tensorflow/core/framework/lookup_interface.h
+++ b/tensorflow/core/framework/lookup_interface.h
@@ -64,6 +64,17 @@ class LookupInterface : public ResourceBase {
virtual Status Insert(OpKernelContext* ctx, const Tensor& keys,
const Tensor& values) = 0;
+ // Removes elements from the table.
+ // This method is only implemented in mutable tables that can be updated over
+ // the execution of the graph. It returns Status::NotImplemented for read-only
+ // tables that are initialized once before they can be looked up.
+
+ // Returns the following statuses:
+ // - OK: when the remove finishes successfully.
+ // - InvalidArgument: if any of the preconditions on the lookup key fails.
+ // - Unimplemented: if the table does not support removals.
+ virtual Status Remove(OpKernelContext* ctx, const Tensor& keys) = 0;
+
// Returns the number of elements in the table.
virtual size_t size() const = 0;
@@ -107,6 +118,12 @@ class LookupInterface : public ResourceBase {
virtual Status CheckKeyAndValueTensorsForImport(const Tensor& keys,
const Tensor& values);
+ // Check format of the key tensor for the Remove function.
+ // Returns OK if all the following requirements are satisfied, otherwise it
+ // returns InvalidArgument:
+ // - DataType of the tensor keys equals to the table key_dtype
+ virtual Status CheckKeyTensorForRemove(const Tensor& keys);
+
// Check the arguments of a find operation. Returns OK if all the following
// requirements are satisfied, otherwise it returns InvalidArgument:
// - DataType of the tensor keys equals to the table key_dtype
diff --git a/tensorflow/core/kernels/initializable_lookup_table.h b/tensorflow/core/kernels/initializable_lookup_table.h
index 424fe5df3c..a14d4967a5 100644
--- a/tensorflow/core/kernels/initializable_lookup_table.h
+++ b/tensorflow/core/kernels/initializable_lookup_table.h
@@ -51,6 +51,12 @@ class InitializableLookupTable : public LookupInterface {
"Insert not supported by InitializableLookupTable implementations");
}
+ // Returns errors::Unimplemented.
+ Status Remove(OpKernelContext* ctx, const Tensor& keys) final {
+ return errors::Unimplemented(
+ "Remove not supported by InitializableLookupTable implementations");
+ }
+
Status ExportValues(OpKernelContext* context) override {
return errors::Unimplemented(
"ExportValues not supported by InitializableLookupTable "
diff --git a/tensorflow/core/kernels/lookup_table_op.cc b/tensorflow/core/kernels/lookup_table_op.cc
index a495758861..0bc1ea77d6 100644
--- a/tensorflow/core/kernels/lookup_table_op.cc
+++ b/tensorflow/core/kernels/lookup_table_op.cc
@@ -89,6 +89,16 @@ class MutableHashTableOfScalars final : public LookupInterface {
return DoInsert(false, keys, values);
}
+ Status Remove(OpKernelContext* ctx, const Tensor& keys) override {
+ const auto key_values = keys.flat<K>();
+
+ mutex_lock l(mu_);
+ for (int64 i = 0; i < key_values.size(); ++i) {
+ table_.erase(SubtleMustCopyIfIntegral(key_values(i)));
+ }
+ return Status::OK();
+ }
+
Status ImportValues(OpKernelContext* ctx, const Tensor& keys,
const Tensor& values) override {
return DoInsert(true, keys, values);
@@ -212,6 +222,16 @@ class MutableHashTableOfTensors final : public LookupInterface {
return DoInsert(false, keys, values);
}
+ Status Remove(OpKernelContext* ctx, const Tensor& keys) override {
+ const auto key_values = keys.flat<K>();
+
+ mutex_lock l(mu_);
+ for (int64 i = 0; i < key_values.size(); ++i) {
+ table_.erase(SubtleMustCopyIfIntegral(key_values(i)));
+ }
+ return Status::OK();
+ }
+
Status ImportValues(OpKernelContext* ctx, const Tensor& keys,
const Tensor& values) override {
return DoInsert(true, keys, values);
@@ -326,6 +346,29 @@ class MutableDenseHashTable final : public LookupInterface {
empty_key_input->template shaped<K, 2>({1, key_shape_.num_elements()}),
0);
+ const Tensor* deleted_key_input;
+ OP_REQUIRES_OK(ctx, ctx->input("deleted_key", &deleted_key_input));
+ OP_REQUIRES(ctx, key_shape_.IsSameSize(deleted_key_input->shape()),
+ errors::InvalidArgument(
+ "Empty and deleted keys must have same shape, got shapes: ",
+ key_shape_.DebugString(), " and ",
+ deleted_key_input->shape().DebugString()));
+ deleted_key_ = PersistentTensor(*deleted_key_input);
+ deleted_key_hash_ = HashKey(deleted_key_input->template shaped<K, 2>(
+ {1, key_shape_.num_elements()}),
+ 0);
+
+ if (empty_key_hash_ == deleted_key_hash_) {
+ const int64 key_size = key_shape_.num_elements();
+ const auto empty_key_matrix =
+ empty_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size});
+ const auto deleted_key_matrix =
+ deleted_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size});
+ OP_REQUIRES(
+ ctx, !IsEqualKey(empty_key_matrix, 0, deleted_key_matrix, 0),
+ errors::InvalidArgument("Empty and deleted keys cannot be equal"));
+ }
+
int64 initial_num_buckets;
OP_REQUIRES_OK(ctx, GetNodeAttr(kernel->def(), "initial_num_buckets",
&initial_num_buckets));
@@ -360,6 +403,8 @@ class MutableDenseHashTable final : public LookupInterface {
value_buckets_.AccessTensor(ctx)->template matrix<V>();
const auto empty_key_matrix =
empty_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size});
+ const auto deleted_key_matrix =
+ deleted_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size});
const int64 bit_mask = num_buckets_ - 1;
// TODO(andreasst): parallelize using work_sharder
for (int64 i = 0; i < num_elements; ++i) {
@@ -369,6 +414,11 @@ class MutableDenseHashTable final : public LookupInterface {
return errors::InvalidArgument(
"Using the empty_key as a table key is not allowed");
}
+ if (deleted_key_hash_ == key_hash &&
+ IsEqualKey(deleted_key_matrix, 0, key_matrix, i)) {
+ return errors::InvalidArgument(
+ "Using the deleted_key as a table key is not allowed");
+ }
int64 bucket_index = key_hash & bit_mask;
int64 num_probes = 0;
while (true) {
@@ -425,23 +475,40 @@ class MutableDenseHashTable final : public LookupInterface {
return DoInsert(ctx, key, value, false);
}
+ Status Remove(OpKernelContext* ctx, const Tensor& key) override
+ LOCKS_EXCLUDED(mu_) {
+ if (key.NumElements() != key.dim_size(0) * key_shape_.num_elements()) {
+ TensorShape expected_shape({key.dim_size(0)});
+ expected_shape.AppendShape(key_shape_);
+ return errors::InvalidArgument("Expected key shape ",
+ expected_shape.DebugString(), " got ",
+ key.shape().DebugString());
+ }
+ mutex_lock l(mu_);
+ return DoRemove(ctx, key);
+ }
+
Status ImportValues(OpKernelContext* ctx, const Tensor& keys,
const Tensor& values) override LOCKS_EXCLUDED(mu_) {
mutex_lock l(mu_);
num_buckets_ = keys.dim_size(0);
key_buckets_ = PersistentTensor(keys);
value_buckets_ = PersistentTensor(values);
- // Count the number of keys that are not the empty_key. This requires
- // iterating through the whole table but that is OK as we only execute it
- // during checkpoint restore.
+ // Count the number of keys that are not the empty_key or deleted_key.
+ // This requires iterating through the whole table but that is OK as we
+ // only execute it during checkpoint restore.
num_entries_ = 0;
const auto empty_key_tensor =
empty_key_.AccessTensor(ctx)->template shaped<K, 2>(
{1, key_shape_.num_elements()});
+ const auto deleted_key_tensor =
+ deleted_key_.AccessTensor(ctx)->template shaped<K, 2>(
+ {1, key_shape_.num_elements()});
const auto key_buckets_tensor =
key_buckets_.AccessTensor(ctx)->template matrix<K>();
for (int64 i = 0; i < num_buckets_; ++i) {
- if (!IsEqualKey(key_buckets_tensor, i, empty_key_tensor, 0)) {
+ if (!IsEqualKey(key_buckets_tensor, i, empty_key_tensor, 0) &&
+ !IsEqualKey(key_buckets_tensor, i, deleted_key_tensor, 0)) {
++num_entries_;
}
}
@@ -498,7 +565,8 @@ class MutableDenseHashTable final : public LookupInterface {
private:
Status DoInsert(OpKernelContext* ctx, const Tensor& key, const Tensor& value,
- bool ignore_empty_key) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ bool ignore_empty_and_deleted_key)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
const int64 num_elements = (key.dims() == 0) ? 1 : key.dim_size(0);
const int64 value_size = value_shape_.num_elements();
const int64 key_size = key_shape_.num_elements();
@@ -511,17 +579,27 @@ class MutableDenseHashTable final : public LookupInterface {
value_buckets_.AccessTensor(ctx)->template matrix<V>();
const auto empty_key_tensor =
empty_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size});
+ const auto deleted_key_tensor =
+ deleted_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size});
const int64 bit_mask = num_buckets_ - 1;
for (int64 i = 0; i < num_elements; ++i) {
const uint64 key_hash = HashKey(key_matrix, i);
if (empty_key_hash_ == key_hash &&
IsEqualKey(empty_key_tensor, 0, key_matrix, i)) {
- if (ignore_empty_key) {
+ if (ignore_empty_and_deleted_key) {
continue;
}
return errors::InvalidArgument(
"Using the empty_key as a table key is not allowed");
}
+ if (deleted_key_hash_ == key_hash &&
+ IsEqualKey(deleted_key_tensor, 0, key_matrix, i)) {
+ if (ignore_empty_and_deleted_key) {
+ continue;
+ }
+ return errors::InvalidArgument(
+ "Using the deleted_key as a table key is not allowed");
+ }
int64 bucket_index = key_hash & bit_mask;
int64 num_probes = 0;
while (true) {
@@ -532,7 +610,9 @@ class MutableDenseHashTable final : public LookupInterface {
}
break;
}
- if (IsEqualKey(key_buckets_matrix, bucket_index, empty_key_tensor, 0)) {
+ if (IsEqualKey(key_buckets_matrix, bucket_index, empty_key_tensor, 0) ||
+ IsEqualKey(key_buckets_matrix, bucket_index, deleted_key_tensor,
+ 0)) {
++num_entries_;
for (int64 j = 0; j < key_size; ++j) {
key_buckets_matrix(bucket_index, j) =
@@ -556,6 +636,59 @@ class MutableDenseHashTable final : public LookupInterface {
return Status::OK();
}
+ Status DoRemove(OpKernelContext* ctx, const Tensor& key)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ const int64 num_elements = key.dim_size(0);
+ const int64 key_size = key_shape_.num_elements();
+ const auto key_matrix = key.shaped<K, 2>({num_elements, key_size});
+
+ auto key_buckets_matrix =
+ key_buckets_.AccessTensor(ctx)->template matrix<K>();
+ const auto empty_key_tensor =
+ empty_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size});
+ const auto deleted_key_tensor =
+ deleted_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size});
+ const auto deleted_key_flat =
+ deleted_key_.AccessTensor(ctx)->template flat<K>();
+ const int64 bit_mask = num_buckets_ - 1;
+ for (int64 i = 0; i < num_elements; ++i) {
+ const uint64 key_hash = HashKey(key_matrix, i);
+ if (empty_key_hash_ == key_hash &&
+ IsEqualKey(empty_key_tensor, 0, key_matrix, i)) {
+ return errors::InvalidArgument(
+ "Using the empty_key as a table key is not allowed");
+ }
+ if (deleted_key_hash_ == key_hash &&
+ IsEqualKey(deleted_key_tensor, 0, key_matrix, i)) {
+ return errors::InvalidArgument(
+ "Using the deleted_key as a table key is not allowed");
+ }
+ int64 bucket_index = key_hash & bit_mask;
+ int64 num_probes = 0;
+ while (true) {
+ if (IsEqualKey(key_buckets_matrix, bucket_index, key_matrix, i)) {
+ --num_entries_;
+ for (int64 j = 0; j < key_size; ++j) {
+ key_buckets_matrix(bucket_index, j) =
+ SubtleMustCopyIfIntegral(deleted_key_flat(j));
+ }
+ break;
+ }
+ if (IsEqualKey(key_buckets_matrix, bucket_index, empty_key_tensor, 0)) {
+ break;
+ }
+ ++num_probes;
+ bucket_index =
+ (bucket_index + num_probes) & bit_mask; // quadratic probing
+ if (num_probes >= num_buckets_) {
+ return errors::Internal(
+ "Internal error in MutableDenseHashTable remove");
+ }
+ }
+ }
+ return Status::OK();
+ }
+
Status AllocateBuckets(OpKernelContext* ctx, int64 new_num_buckets)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (new_num_buckets < 4 ||
@@ -639,7 +772,9 @@ class MutableDenseHashTable final : public LookupInterface {
PersistentTensor value_buckets_ GUARDED_BY(mu_);
PersistentTensor empty_key_;
uint64 empty_key_hash_;
-};
+ PersistentTensor deleted_key_;
+ uint64 deleted_key_hash_;
+}; // namespace lookup
} // namespace lookup
@@ -717,6 +852,39 @@ REGISTER_KERNEL_BUILDER(Name("LookupTableInsert").Device(DEVICE_CPU),
REGISTER_KERNEL_BUILDER(Name("LookupTableInsertV2").Device(DEVICE_CPU),
LookupTableInsertOp);
+// Table remove op.
+class LookupTableRemoveOp : public OpKernel {
+ public:
+ explicit LookupTableRemoveOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ lookup::LookupInterface* table;
+ OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table));
+ core::ScopedUnref unref_me(table);
+
+ DataType expected_input_0 =
+ (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF;
+ DataTypeVector expected_inputs = {expected_input_0, table->key_dtype()};
+ OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {}));
+
+ const Tensor& key = ctx->input(1);
+ OP_REQUIRES_OK(ctx, table->CheckKeyTensorForRemove(key));
+
+ int64 memory_used_before = 0;
+ if (ctx->track_allocations()) {
+ memory_used_before = table->MemoryUsed();
+ }
+ OP_REQUIRES_OK(ctx, table->Remove(ctx, key));
+ if (ctx->track_allocations()) {
+ ctx->record_persistent_memory_allocation(table->MemoryUsed() -
+ memory_used_before);
+ }
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("LookupTableRemoveV2").Device(DEVICE_CPU),
+ LookupTableRemoveOp);
+
// Op that returns the size of the given table.
class LookupTableSizeOp : public OpKernel {
public:
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index cfb1055d3c..415e15b720 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -30321,6 +30321,22 @@ op {
is_stateful: true
}
op {
+ name: "LookupTableRemoveV2"
+ input_arg {
+ name: "table_handle"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "keys"
+ type_attr: "Tin"
+ }
+ attr {
+ name: "Tin"
+ type: "type"
+ }
+ is_stateful: true
+}
+op {
name: "LookupTableSize"
input_arg {
name: "table_handle"
@@ -36706,6 +36722,10 @@ op {
name: "empty_key"
type_attr: "key_dtype"
}
+ input_arg {
+ name: "deleted_key"
+ type_attr: "key_dtype"
+ }
output_arg {
name: "table_handle"
type: DT_RESOURCE
diff --git a/tensorflow/core/ops/lookup_ops.cc b/tensorflow/core/ops/lookup_ops.cc
index 72a77be70d..a0987cd982 100644
--- a/tensorflow/core/ops/lookup_ops.cc
+++ b/tensorflow/core/ops/lookup_ops.cc
@@ -214,6 +214,19 @@ REGISTER_OP("LookupTableInsertV2")
return Status::OK();
});
+REGISTER_OP("LookupTableRemoveV2")
+ .Input("table_handle: resource")
+ .Input("keys: Tin")
+ .Attr("Tin: type")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle handle;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
+ TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &handle));
+
+ // TODO(turboale): Validate keys shape.
+ return Status::OK();
+ });
+
REGISTER_OP("LookupTableSize")
.Input("table_handle: Ref(string)")
.Output("size: int64")
@@ -407,6 +420,7 @@ REGISTER_OP("MutableDenseHashTable")
REGISTER_OP("MutableDenseHashTableV2")
.Input("empty_key: key_dtype")
+ .Input("deleted_key: key_dtype")
.Output("table_handle: resource")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")