diff options
Diffstat (limited to 'tensorflow/core/kernels/lookup_table_op.cc')
-rw-r--r-- | tensorflow/core/kernels/lookup_table_op.cc | 73 |
1 files changed, 51 insertions, 22 deletions
diff --git a/tensorflow/core/kernels/lookup_table_op.cc b/tensorflow/core/kernels/lookup_table_op.cc index 07e754a6ef..2e8d9c623c 100644 --- a/tensorflow/core/kernels/lookup_table_op.cc +++ b/tensorflow/core/kernels/lookup_table_op.cc @@ -341,7 +341,7 @@ class MutableDenseHashTable final : public LookupInterface { Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value, const Tensor& default_value) override LOCKS_EXCLUDED(mu_) { - const int64 num_elements = key.dim_size(0); + const int64 num_elements = (key.dims() == 0) ? 1 : key.dim_size(0); const int64 key_size = key_shape_.num_elements(); const int64 value_size = value_shape_.num_elements(); if (key.NumElements() != num_elements * key_size) { @@ -403,8 +403,9 @@ class MutableDenseHashTable final : public LookupInterface { Status Insert(OpKernelContext* ctx, const Tensor& key, const Tensor& value) override LOCKS_EXCLUDED(mu_) { - if (key.NumElements() != key.dim_size(0) * key_shape_.num_elements()) { - TensorShape expected_shape({key.dim_size(0)}); + const int64 batch_size = (key.dims() == 0) ? 1 : key.dim_size(0); + if (key.NumElements() != batch_size * key_shape_.num_elements()) { + TensorShape expected_shape({batch_size}); expected_shape.AppendShape(key_shape_); return errors::InvalidArgument("Expected key shape ", expected_shape.DebugString(), " got ", @@ -415,7 +416,7 @@ class MutableDenseHashTable final : public LookupInterface { // rather than updates. That means we may grow the table even though we // don't need to. As long as the number of keys inserted in one call is // small compared to the size of the map, the impact of this is minimal. - const int64 pending_num_entries = num_entries_ + key.dim_size(0); + const int64 pending_num_entries = num_entries_ + batch_size; if (pending_num_entries > num_buckets_ * max_load_factor_) { int64 new_num_buckets = num_buckets_; do { @@ -500,7 +501,7 @@ class MutableDenseHashTable final : public LookupInterface { private: Status DoInsert(OpKernelContext* ctx, const Tensor& key, const Tensor& value, bool ignore_empty_key) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - const int64 num_elements = key.dim_size(0); + const int64 num_elements = (key.dims() == 0) ? 1 : key.dim_size(0); const int64 value_size = value_shape_.num_elements(); const int64 key_size = key_shape_.num_elements(); const auto key_matrix = key.shaped<K, 2>({num_elements, key_size}); @@ -812,17 +813,21 @@ REGISTER_KERNEL_BUILDER(Name("LookupTableImportV2").Device(DEVICE_CPU), LookupTableOp<lookup::HashTable<key_dtype, value_dtype>, key_dtype, \ value_dtype>) +REGISTER_KERNEL(int32, double); +REGISTER_KERNEL(int32, float); +REGISTER_KERNEL(int32, int32); +REGISTER_KERNEL(int32, string); +REGISTER_KERNEL(int64, double); +REGISTER_KERNEL(int64, float); +REGISTER_KERNEL(int64, int32); +REGISTER_KERNEL(int64, int64); +REGISTER_KERNEL(int64, string); +REGISTER_KERNEL(string, bool); REGISTER_KERNEL(string, double); REGISTER_KERNEL(string, float); REGISTER_KERNEL(string, int32); REGISTER_KERNEL(string, int64); -REGISTER_KERNEL(int64, string); -REGISTER_KERNEL(int64, int64); -REGISTER_KERNEL(int64, float); REGISTER_KERNEL(string, string); -REGISTER_KERNEL(string, bool); -REGISTER_KERNEL(int32, int32); -REGISTER_KERNEL(int32, string); #undef REGISTER_KERNEL @@ -843,12 +848,20 @@ REGISTER_KERNEL(int32, string); LookupTableOp<lookup::MutableHashTableOfScalars<key_dtype, value_dtype>, \ key_dtype, value_dtype>) -REGISTER_KERNEL(string, float); -REGISTER_KERNEL(string, int64); -REGISTER_KERNEL(int64, string); -REGISTER_KERNEL(string, bool); +REGISTER_KERNEL(int32, double); +REGISTER_KERNEL(int32, float); +REGISTER_KERNEL(int32, int32); +REGISTER_KERNEL(int64, double); REGISTER_KERNEL(int64, float); +REGISTER_KERNEL(int64, int32); +REGISTER_KERNEL(int64, int64); +REGISTER_KERNEL(int64, string); REGISTER_KERNEL(int64, Variant); +REGISTER_KERNEL(string, bool); +REGISTER_KERNEL(string, double); +REGISTER_KERNEL(string, float); +REGISTER_KERNEL(string, int32); +REGISTER_KERNEL(string, int64); #undef REGISTER_KERNEL @@ -869,10 +882,19 @@ REGISTER_KERNEL(int64, Variant); LookupTableOp<lookup::MutableHashTableOfTensors<key_dtype, value_dtype>, \ key_dtype, value_dtype>) -REGISTER_KERNEL(string, float); -REGISTER_KERNEL(string, int64); +REGISTER_KERNEL(int32, double); +REGISTER_KERNEL(int32, float); +REGISTER_KERNEL(int32, int32); +REGISTER_KERNEL(int64, double); +REGISTER_KERNEL(int64, float); +REGISTER_KERNEL(int64, int32); +REGISTER_KERNEL(int64, int64); REGISTER_KERNEL(int64, string); REGISTER_KERNEL(string, bool); +REGISTER_KERNEL(string, double); +REGISTER_KERNEL(string, float); +REGISTER_KERNEL(string, int32); +REGISTER_KERNEL(string, int64); #undef REGISTER_KERNEL @@ -893,13 +915,20 @@ REGISTER_KERNEL(string, bool); LookupTableOp<lookup::MutableDenseHashTable<key_dtype, value_dtype>, \ key_dtype, value_dtype>) -REGISTER_KERNEL(int64, int64); -REGISTER_KERNEL(int64, float); -REGISTER_KERNEL(int64, double); -REGISTER_KERNEL(string, float); -REGISTER_KERNEL(string, bool); +REGISTER_KERNEL(int32, double); +REGISTER_KERNEL(int32, float); +REGISTER_KERNEL(int32, int32); REGISTER_KERNEL(int64, bool); +REGISTER_KERNEL(int64, double); +REGISTER_KERNEL(int64, float); +REGISTER_KERNEL(int64, int32); +REGISTER_KERNEL(int64, int64); REGISTER_KERNEL(int64, Variant); +REGISTER_KERNEL(string, bool); +REGISTER_KERNEL(string, double); +REGISTER_KERNEL(string, float); +REGISTER_KERNEL(string, int32); +REGISTER_KERNEL(string, int64); #undef REGISTER_KERNEL |