path: root/tensorflow/core/kernels/lookup_table_op.cc
Diffstat (limited to 'tensorflow/core/kernels/lookup_table_op.cc')
-rw-r--r--  tensorflow/core/kernels/lookup_table_op.cc | 99
1 file changed, 63 insertions(+), 36 deletions(-)
diff --git a/tensorflow/core/kernels/lookup_table_op.cc b/tensorflow/core/kernels/lookup_table_op.cc
index 07e754a6ef..a495758861 100644
--- a/tensorflow/core/kernels/lookup_table_op.cc
+++ b/tensorflow/core/kernels/lookup_table_op.cc
@@ -50,7 +50,7 @@ class MutableHashTableOfScalars final : public LookupInterface {
MutableHashTableOfScalars(OpKernelContext* ctx, OpKernel* kernel) {}
size_t size() const override {
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
return table_.size();
}
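This and the following hunks switch the read-only accessors (size, Find, ExportValues, MemoryUsed) from the exclusive mutex_lock to tf_shared_lock, so concurrent lookups on the same table no longer serialize against each other; mutating paths such as Insert keep taking the exclusive lock. Below is a minimal sketch of that reader/writer pattern, assuming the same TensorFlow mutex headers and namespace as this file; the SharedCounter class and its members are illustrative and not part of the patch.

#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"

// Illustrative type (not from the patch): reads share the lock, writes
// take it exclusively, mirroring the change applied to the tables here.
class SharedCounter {
 public:
  // Concurrent readers only block while a writer holds the mutex.
  int64 Get() const {
    tf_shared_lock l(mu_);
    return value_;
  }

  // Writers take the exclusive lock and wait for all shared holders.
  void Add(int64 delta) {
    mutex_lock l(mu_);
    value_ += delta;
  }

 private:
  mutable mutex mu_;
  int64 value_ GUARDED_BY(mu_) = 0;
};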
@@ -60,7 +60,7 @@ class MutableHashTableOfScalars final : public LookupInterface {
const auto key_values = key.flat<K>();
auto value_values = value->flat<V>();
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
for (int64 i = 0; i < key_values.size(); ++i) {
value_values(i) = gtl::FindWithDefault(
table_, SubtleMustCopyIfIntegral(key_values(i)), default_val);
@@ -95,7 +95,7 @@ class MutableHashTableOfScalars final : public LookupInterface {
}
Status ExportValues(OpKernelContext* ctx) override {
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
int64 size = table_.size();
Tensor* keys;
@@ -125,7 +125,7 @@ class MutableHashTableOfScalars final : public LookupInterface {
int64 MemoryUsed() const override {
int64 ret = 0;
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
for (unsigned i = 0; i < table_.bucket_count(); ++i) {
size_t bucket_size = table_.bucket_size(i);
if (bucket_size == 0) {
@@ -138,7 +138,6 @@ class MutableHashTableOfScalars final : public LookupInterface {
}
private:
- // TODO(andreasst): consider using a read/write lock or a concurrent map
mutable mutex mu_;
std::unordered_map<K, V> table_ GUARDED_BY(mu_);
};
@@ -158,7 +157,7 @@ class MutableHashTableOfTensors final : public LookupInterface {
}
size_t size() const override {
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
return table_.size();
}
@@ -169,7 +168,7 @@ class MutableHashTableOfTensors final : public LookupInterface {
auto value_values = value->flat_inner_dims<V, 2>();
int64 value_dim = value_shape_.dim_size(0);
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
for (int64 i = 0; i < key_values.size(); ++i) {
ValueArray* value_vec =
gtl::FindOrNull(table_, SubtleMustCopyIfIntegral(key_values(i)));
@@ -219,7 +218,7 @@ class MutableHashTableOfTensors final : public LookupInterface {
}
Status ExportValues(OpKernelContext* ctx) override {
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
int64 size = table_.size();
int64 value_dim = value_shape_.dim_size(0);
@@ -254,7 +253,7 @@ class MutableHashTableOfTensors final : public LookupInterface {
int64 MemoryUsed() const override {
int64 ret = 0;
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
for (unsigned i = 0; i < table_.bucket_count(); ++i) {
size_t bucket_size = table_.bucket_size(i);
if (bucket_size == 0) {
@@ -268,7 +267,6 @@ class MutableHashTableOfTensors final : public LookupInterface {
private:
TensorShape value_shape_;
- // TODO(andreasst): consider using a read/write lock or a concurrent map
mutable mutex mu_;
typedef gtl::InlinedVector<V, 4> ValueArray;
std::unordered_map<K, ValueArray> table_ GUARDED_BY(mu_);
@@ -335,13 +333,13 @@ class MutableDenseHashTable final : public LookupInterface {
}
size_t size() const override LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
return num_entries_;
}
Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value,
const Tensor& default_value) override LOCKS_EXCLUDED(mu_) {
- const int64 num_elements = key.dim_size(0);
+ const int64 num_elements = (key.dims() == 0) ? 1 : key.dim_size(0);
const int64 key_size = key_shape_.num_elements();
const int64 value_size = value_shape_.num_elements();
if (key.NumElements() != num_elements * key_size) {
@@ -355,7 +353,7 @@ class MutableDenseHashTable final : public LookupInterface {
auto value_matrix = value->shaped<V, 2>({num_elements, value_size});
const auto default_flat = default_value.flat<V>();
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
const auto key_buckets_matrix =
key_buckets_.AccessTensor(ctx)->template matrix<K>();
const auto value_buckets_matrix =
@@ -403,8 +401,9 @@ class MutableDenseHashTable final : public LookupInterface {
Status Insert(OpKernelContext* ctx, const Tensor& key,
const Tensor& value) override LOCKS_EXCLUDED(mu_) {
- if (key.NumElements() != key.dim_size(0) * key_shape_.num_elements()) {
- TensorShape expected_shape({key.dim_size(0)});
+ const int64 batch_size = (key.dims() == 0) ? 1 : key.dim_size(0);
+ if (key.NumElements() != batch_size * key_shape_.num_elements()) {
+ TensorShape expected_shape({batch_size});
expected_shape.AppendShape(key_shape_);
return errors::InvalidArgument("Expected key shape ",
expected_shape.DebugString(), " got ",
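MutableDenseHashTable previously derived the batch size from key.dim_size(0), which is invalid for a 0-D (scalar) key tensor; these hunks treat a scalar key as a batch of one in Find, Insert, and DoInsert. The sketch below restates the patched shape check as a free-standing helper, assuming the same headers as this file; the function name is illustrative and not part of the patch.

// Illustrative helper mirroring the patched check: a scalar key tensor
// counts as a batch of one, otherwise dimension 0 is the batch size.
Status CheckKeyShape(const Tensor& key, const TensorShape& key_shape) {
  const int64 batch_size = (key.dims() == 0) ? 1 : key.dim_size(0);
  if (key.NumElements() != batch_size * key_shape.num_elements()) {
    TensorShape expected_shape({batch_size});
    expected_shape.AppendShape(key_shape);
    return errors::InvalidArgument("Expected key shape ",
                                   expected_shape.DebugString(), " got ",
                                   key.shape().DebugString());
  }
  return Status::OK();
}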
@@ -415,7 +414,7 @@ class MutableDenseHashTable final : public LookupInterface {
// rather than updates. That means we may grow the table even though we
// don't need to. As long as the number of keys inserted in one call is
// small compared to the size of the map, the impact of this is minimal.
- const int64 pending_num_entries = num_entries_ + key.dim_size(0);
+ const int64 pending_num_entries = num_entries_ + batch_size;
if (pending_num_entries > num_buckets_ * max_load_factor_) {
int64 new_num_buckets = num_buckets_;
do {
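The comment above notes that Insert assumes every key is new, so pending_num_entries can overestimate and the table may grow earlier than strictly necessary; growth is bounded by max_load_factor_. A rough sketch of that growth policy follows. The patch does not change it, and the loop body is not visible in this hunk, so the doubling step is an assumption.

// Assumed growth policy: enlarge the bucket count until the pending
// number of entries fits under the load factor.
int64 new_num_buckets = num_buckets_;
while (pending_num_entries > new_num_buckets * max_load_factor_) {
  new_num_buckets <<= 1;  // assumption: buckets grow by doubling
}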
@@ -450,7 +449,7 @@ class MutableDenseHashTable final : public LookupInterface {
}
Status ExportValues(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
Tensor key_buckets_tensor = *key_buckets_.AccessTensor(ctx);
Tensor value_buckets_tensor = *value_buckets_.AccessTensor(ctx);
TF_RETURN_IF_ERROR(ctx->set_output("keys", key_buckets_tensor));
@@ -492,7 +491,7 @@ class MutableDenseHashTable final : public LookupInterface {
TensorShape value_shape() const override { return value_shape_; }
int64 MemoryUsed() const override {
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
return sizeof(MutableDenseHashTable) + key_buckets_.AllocatedBytes() +
value_buckets_.AllocatedBytes() + empty_key_.AllocatedBytes();
}
@@ -500,7 +499,7 @@ class MutableDenseHashTable final : public LookupInterface {
private:
Status DoInsert(OpKernelContext* ctx, const Tensor& key, const Tensor& value,
bool ignore_empty_key) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- const int64 num_elements = key.dim_size(0);
+ const int64 num_elements = (key.dims() == 0) ? 1 : key.dim_size(0);
const int64 value_size = value_shape_.num_elements();
const int64 key_size = key_shape_.num_elements();
const auto key_matrix = key.shaped<K, 2>({num_elements, key_size});
@@ -812,17 +811,21 @@ REGISTER_KERNEL_BUILDER(Name("LookupTableImportV2").Device(DEVICE_CPU),
LookupTableOp<lookup::HashTable<key_dtype, value_dtype>, key_dtype, \
value_dtype>)
+REGISTER_KERNEL(int32, double);
+REGISTER_KERNEL(int32, float);
+REGISTER_KERNEL(int32, int32);
+REGISTER_KERNEL(int32, string);
+REGISTER_KERNEL(int64, double);
+REGISTER_KERNEL(int64, float);
+REGISTER_KERNEL(int64, int32);
+REGISTER_KERNEL(int64, int64);
+REGISTER_KERNEL(int64, string);
+REGISTER_KERNEL(string, bool);
REGISTER_KERNEL(string, double);
REGISTER_KERNEL(string, float);
REGISTER_KERNEL(string, int32);
REGISTER_KERNEL(string, int64);
-REGISTER_KERNEL(int64, string);
-REGISTER_KERNEL(int64, int64);
-REGISTER_KERNEL(int64, float);
REGISTER_KERNEL(string, string);
-REGISTER_KERNEL(string, bool);
-REGISTER_KERNEL(int32, int32);
-REGISTER_KERNEL(int32, string);
#undef REGISTER_KERNEL
@@ -843,12 +846,20 @@ REGISTER_KERNEL(int32, string);
LookupTableOp<lookup::MutableHashTableOfScalars<key_dtype, value_dtype>, \
key_dtype, value_dtype>)
-REGISTER_KERNEL(string, float);
-REGISTER_KERNEL(string, int64);
-REGISTER_KERNEL(int64, string);
-REGISTER_KERNEL(string, bool);
+REGISTER_KERNEL(int32, double);
+REGISTER_KERNEL(int32, float);
+REGISTER_KERNEL(int32, int32);
+REGISTER_KERNEL(int64, double);
REGISTER_KERNEL(int64, float);
+REGISTER_KERNEL(int64, int32);
+REGISTER_KERNEL(int64, int64);
+REGISTER_KERNEL(int64, string);
REGISTER_KERNEL(int64, Variant);
+REGISTER_KERNEL(string, bool);
+REGISTER_KERNEL(string, double);
+REGISTER_KERNEL(string, float);
+REGISTER_KERNEL(string, int32);
+REGISTER_KERNEL(string, int64);
#undef REGISTER_KERNEL
@@ -869,10 +880,19 @@ REGISTER_KERNEL(int64, Variant);
LookupTableOp<lookup::MutableHashTableOfTensors<key_dtype, value_dtype>, \
key_dtype, value_dtype>)
-REGISTER_KERNEL(string, float);
-REGISTER_KERNEL(string, int64);
+REGISTER_KERNEL(int32, double);
+REGISTER_KERNEL(int32, float);
+REGISTER_KERNEL(int32, int32);
+REGISTER_KERNEL(int64, double);
+REGISTER_KERNEL(int64, float);
+REGISTER_KERNEL(int64, int32);
+REGISTER_KERNEL(int64, int64);
REGISTER_KERNEL(int64, string);
REGISTER_KERNEL(string, bool);
+REGISTER_KERNEL(string, double);
+REGISTER_KERNEL(string, float);
+REGISTER_KERNEL(string, int32);
+REGISTER_KERNEL(string, int64);
#undef REGISTER_KERNEL
@@ -893,13 +913,20 @@ REGISTER_KERNEL(string, bool);
LookupTableOp<lookup::MutableDenseHashTable<key_dtype, value_dtype>, \
key_dtype, value_dtype>)
-REGISTER_KERNEL(int64, int64);
-REGISTER_KERNEL(int64, float);
-REGISTER_KERNEL(int64, double);
-REGISTER_KERNEL(string, float);
-REGISTER_KERNEL(string, bool);
+REGISTER_KERNEL(int32, double);
+REGISTER_KERNEL(int32, float);
+REGISTER_KERNEL(int32, int32);
REGISTER_KERNEL(int64, bool);
+REGISTER_KERNEL(int64, double);
+REGISTER_KERNEL(int64, float);
+REGISTER_KERNEL(int64, int32);
+REGISTER_KERNEL(int64, int64);
REGISTER_KERNEL(int64, Variant);
+REGISTER_KERNEL(string, bool);
+REGISTER_KERNEL(string, double);
+REGISTER_KERNEL(string, float);
+REGISTER_KERNEL(string, int32);
+REGISTER_KERNEL(string, int64);
#undef REGISTER_KERNEL
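Each registration block above follows the same X-macro pattern: define REGISTER_KERNEL for one table implementation, invoke it once per (key_dtype, value_dtype) pair, then #undef it so the next block can redefine it. The hunks sort each list and fill in previously missing combinations (notably int32 keys and double/int32 values). Below is a simplified, hypothetical expansion of that pattern; the op name is illustrative and only the overall shape matches the macros in this file.

// Hypothetical, simplified registration macro: one CPU kernel per
// (key, value) dtype pair for a given lookup-table implementation.
// "ExampleMutableHashTable" is an illustrative op name, not a real op.
#define REGISTER_KERNEL(key_dtype, value_dtype)                            \
  REGISTER_KERNEL_BUILDER(Name("ExampleMutableHashTable")                  \
                              .Device(DEVICE_CPU)                          \
                              .TypeConstraint<key_dtype>("key_dtype")      \
                              .TypeConstraint<value_dtype>("value_dtype"), \
                          LookupTableOp<                                   \
                              lookup::MutableHashTableOfScalars<           \
                                  key_dtype, value_dtype>,                 \
                              key_dtype, value_dtype>)

REGISTER_KERNEL(int32, float);
REGISTER_KERNEL(string, int64);

#undef REGISTER_KERNEL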