diff options
Diffstat (limited to 'tensorflow/core/kernels/lookup_table_op.cc')
-rw-r--r-- | tensorflow/core/kernels/lookup_table_op.cc | 99 |
1 file changed, 63 insertions, 36 deletions
diff --git a/tensorflow/core/kernels/lookup_table_op.cc b/tensorflow/core/kernels/lookup_table_op.cc index 07e754a6ef..a495758861 100644 --- a/tensorflow/core/kernels/lookup_table_op.cc +++ b/tensorflow/core/kernels/lookup_table_op.cc @@ -50,7 +50,7 @@ class MutableHashTableOfScalars final : public LookupInterface { MutableHashTableOfScalars(OpKernelContext* ctx, OpKernel* kernel) {} size_t size() const override { - mutex_lock l(mu_); + tf_shared_lock l(mu_); return table_.size(); } @@ -60,7 +60,7 @@ class MutableHashTableOfScalars final : public LookupInterface { const auto key_values = key.flat<K>(); auto value_values = value->flat<V>(); - mutex_lock l(mu_); + tf_shared_lock l(mu_); for (int64 i = 0; i < key_values.size(); ++i) { value_values(i) = gtl::FindWithDefault( table_, SubtleMustCopyIfIntegral(key_values(i)), default_val); @@ -95,7 +95,7 @@ class MutableHashTableOfScalars final : public LookupInterface { } Status ExportValues(OpKernelContext* ctx) override { - mutex_lock l(mu_); + tf_shared_lock l(mu_); int64 size = table_.size(); Tensor* keys; @@ -125,7 +125,7 @@ class MutableHashTableOfScalars final : public LookupInterface { int64 MemoryUsed() const override { int64 ret = 0; - mutex_lock l(mu_); + tf_shared_lock l(mu_); for (unsigned i = 0; i < table_.bucket_count(); ++i) { size_t bucket_size = table_.bucket_size(i); if (bucket_size == 0) { @@ -138,7 +138,6 @@ class MutableHashTableOfScalars final : public LookupInterface { } private: - // TODO(andreasst): consider using a read/write lock or a concurrent map mutable mutex mu_; std::unordered_map<K, V> table_ GUARDED_BY(mu_); }; @@ -158,7 +157,7 @@ class MutableHashTableOfTensors final : public LookupInterface { } size_t size() const override { - mutex_lock l(mu_); + tf_shared_lock l(mu_); return table_.size(); } @@ -169,7 +168,7 @@ class MutableHashTableOfTensors final : public LookupInterface { auto value_values = value->flat_inner_dims<V, 2>(); int64 value_dim = value_shape_.dim_size(0); - mutex_lock 
l(mu_); + tf_shared_lock l(mu_); for (int64 i = 0; i < key_values.size(); ++i) { ValueArray* value_vec = gtl::FindOrNull(table_, SubtleMustCopyIfIntegral(key_values(i))); @@ -219,7 +218,7 @@ class MutableHashTableOfTensors final : public LookupInterface { } Status ExportValues(OpKernelContext* ctx) override { - mutex_lock l(mu_); + tf_shared_lock l(mu_); int64 size = table_.size(); int64 value_dim = value_shape_.dim_size(0); @@ -254,7 +253,7 @@ class MutableHashTableOfTensors final : public LookupInterface { int64 MemoryUsed() const override { int64 ret = 0; - mutex_lock l(mu_); + tf_shared_lock l(mu_); for (unsigned i = 0; i < table_.bucket_count(); ++i) { size_t bucket_size = table_.bucket_size(i); if (bucket_size == 0) { @@ -268,7 +267,6 @@ class MutableHashTableOfTensors final : public LookupInterface { private: TensorShape value_shape_; - // TODO(andreasst): consider using a read/write lock or a concurrent map mutable mutex mu_; typedef gtl::InlinedVector<V, 4> ValueArray; std::unordered_map<K, ValueArray> table_ GUARDED_BY(mu_); @@ -335,13 +333,13 @@ class MutableDenseHashTable final : public LookupInterface { } size_t size() const override LOCKS_EXCLUDED(mu_) { - mutex_lock l(mu_); + tf_shared_lock l(mu_); return num_entries_; } Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value, const Tensor& default_value) override LOCKS_EXCLUDED(mu_) { - const int64 num_elements = key.dim_size(0); + const int64 num_elements = (key.dims() == 0) ? 
1 : key.dim_size(0); const int64 key_size = key_shape_.num_elements(); const int64 value_size = value_shape_.num_elements(); if (key.NumElements() != num_elements * key_size) { @@ -355,7 +353,7 @@ class MutableDenseHashTable final : public LookupInterface { auto value_matrix = value->shaped<V, 2>({num_elements, value_size}); const auto default_flat = default_value.flat<V>(); - mutex_lock l(mu_); + tf_shared_lock l(mu_); const auto key_buckets_matrix = key_buckets_.AccessTensor(ctx)->template matrix<K>(); const auto value_buckets_matrix = @@ -403,8 +401,9 @@ class MutableDenseHashTable final : public LookupInterface { Status Insert(OpKernelContext* ctx, const Tensor& key, const Tensor& value) override LOCKS_EXCLUDED(mu_) { - if (key.NumElements() != key.dim_size(0) * key_shape_.num_elements()) { - TensorShape expected_shape({key.dim_size(0)}); + const int64 batch_size = (key.dims() == 0) ? 1 : key.dim_size(0); + if (key.NumElements() != batch_size * key_shape_.num_elements()) { + TensorShape expected_shape({batch_size}); expected_shape.AppendShape(key_shape_); return errors::InvalidArgument("Expected key shape ", expected_shape.DebugString(), " got ", @@ -415,7 +414,7 @@ class MutableDenseHashTable final : public LookupInterface { // rather than updates. That means we may grow the table even though we // don't need to. As long as the number of keys inserted in one call is // small compared to the size of the map, the impact of this is minimal. 
- const int64 pending_num_entries = num_entries_ + key.dim_size(0); + const int64 pending_num_entries = num_entries_ + batch_size; if (pending_num_entries > num_buckets_ * max_load_factor_) { int64 new_num_buckets = num_buckets_; do { @@ -450,7 +449,7 @@ class MutableDenseHashTable final : public LookupInterface { } Status ExportValues(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) { - mutex_lock l(mu_); + tf_shared_lock l(mu_); Tensor key_buckets_tensor = *key_buckets_.AccessTensor(ctx); Tensor value_buckets_tensor = *value_buckets_.AccessTensor(ctx); TF_RETURN_IF_ERROR(ctx->set_output("keys", key_buckets_tensor)); @@ -492,7 +491,7 @@ class MutableDenseHashTable final : public LookupInterface { TensorShape value_shape() const override { return value_shape_; } int64 MemoryUsed() const override { - mutex_lock l(mu_); + tf_shared_lock l(mu_); return sizeof(MutableDenseHashTable) + key_buckets_.AllocatedBytes() + value_buckets_.AllocatedBytes() + empty_key_.AllocatedBytes(); } @@ -500,7 +499,7 @@ class MutableDenseHashTable final : public LookupInterface { private: Status DoInsert(OpKernelContext* ctx, const Tensor& key, const Tensor& value, bool ignore_empty_key) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - const int64 num_elements = key.dim_size(0); + const int64 num_elements = (key.dims() == 0) ? 
1 : key.dim_size(0); const int64 value_size = value_shape_.num_elements(); const int64 key_size = key_shape_.num_elements(); const auto key_matrix = key.shaped<K, 2>({num_elements, key_size}); @@ -812,17 +811,21 @@ REGISTER_KERNEL_BUILDER(Name("LookupTableImportV2").Device(DEVICE_CPU), LookupTableOp<lookup::HashTable<key_dtype, value_dtype>, key_dtype, \ value_dtype>) +REGISTER_KERNEL(int32, double); +REGISTER_KERNEL(int32, float); +REGISTER_KERNEL(int32, int32); +REGISTER_KERNEL(int32, string); +REGISTER_KERNEL(int64, double); +REGISTER_KERNEL(int64, float); +REGISTER_KERNEL(int64, int32); +REGISTER_KERNEL(int64, int64); +REGISTER_KERNEL(int64, string); +REGISTER_KERNEL(string, bool); REGISTER_KERNEL(string, double); REGISTER_KERNEL(string, float); REGISTER_KERNEL(string, int32); REGISTER_KERNEL(string, int64); -REGISTER_KERNEL(int64, string); -REGISTER_KERNEL(int64, int64); -REGISTER_KERNEL(int64, float); REGISTER_KERNEL(string, string); -REGISTER_KERNEL(string, bool); -REGISTER_KERNEL(int32, int32); -REGISTER_KERNEL(int32, string); #undef REGISTER_KERNEL @@ -843,12 +846,20 @@ REGISTER_KERNEL(int32, string); LookupTableOp<lookup::MutableHashTableOfScalars<key_dtype, value_dtype>, \ key_dtype, value_dtype>) -REGISTER_KERNEL(string, float); -REGISTER_KERNEL(string, int64); -REGISTER_KERNEL(int64, string); -REGISTER_KERNEL(string, bool); +REGISTER_KERNEL(int32, double); +REGISTER_KERNEL(int32, float); +REGISTER_KERNEL(int32, int32); +REGISTER_KERNEL(int64, double); REGISTER_KERNEL(int64, float); +REGISTER_KERNEL(int64, int32); +REGISTER_KERNEL(int64, int64); +REGISTER_KERNEL(int64, string); REGISTER_KERNEL(int64, Variant); +REGISTER_KERNEL(string, bool); +REGISTER_KERNEL(string, double); +REGISTER_KERNEL(string, float); +REGISTER_KERNEL(string, int32); +REGISTER_KERNEL(string, int64); #undef REGISTER_KERNEL @@ -869,10 +880,19 @@ REGISTER_KERNEL(int64, Variant); LookupTableOp<lookup::MutableHashTableOfTensors<key_dtype, value_dtype>, \ key_dtype, value_dtype>) 
-REGISTER_KERNEL(string, float); -REGISTER_KERNEL(string, int64); +REGISTER_KERNEL(int32, double); +REGISTER_KERNEL(int32, float); +REGISTER_KERNEL(int32, int32); +REGISTER_KERNEL(int64, double); +REGISTER_KERNEL(int64, float); +REGISTER_KERNEL(int64, int32); +REGISTER_KERNEL(int64, int64); REGISTER_KERNEL(int64, string); REGISTER_KERNEL(string, bool); +REGISTER_KERNEL(string, double); +REGISTER_KERNEL(string, float); +REGISTER_KERNEL(string, int32); +REGISTER_KERNEL(string, int64); #undef REGISTER_KERNEL @@ -893,13 +913,20 @@ REGISTER_KERNEL(string, bool); LookupTableOp<lookup::MutableDenseHashTable<key_dtype, value_dtype>, \ key_dtype, value_dtype>) -REGISTER_KERNEL(int64, int64); -REGISTER_KERNEL(int64, float); -REGISTER_KERNEL(int64, double); -REGISTER_KERNEL(string, float); -REGISTER_KERNEL(string, bool); +REGISTER_KERNEL(int32, double); +REGISTER_KERNEL(int32, float); +REGISTER_KERNEL(int32, int32); REGISTER_KERNEL(int64, bool); +REGISTER_KERNEL(int64, double); +REGISTER_KERNEL(int64, float); +REGISTER_KERNEL(int64, int32); +REGISTER_KERNEL(int64, int64); REGISTER_KERNEL(int64, Variant); +REGISTER_KERNEL(string, bool); +REGISTER_KERNEL(string, double); +REGISTER_KERNEL(string, float); +REGISTER_KERNEL(string, int32); +REGISTER_KERNEL(string, int64); #undef REGISTER_KERNEL |