aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/lookup_table_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/lookup_table_op.cc')
-rw-r--r--tensorflow/core/kernels/lookup_table_op.cc73
1 files changed, 51 insertions, 22 deletions
diff --git a/tensorflow/core/kernels/lookup_table_op.cc b/tensorflow/core/kernels/lookup_table_op.cc
index 07e754a6ef..2e8d9c623c 100644
--- a/tensorflow/core/kernels/lookup_table_op.cc
+++ b/tensorflow/core/kernels/lookup_table_op.cc
@@ -341,7 +341,7 @@ class MutableDenseHashTable final : public LookupInterface {
Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value,
const Tensor& default_value) override LOCKS_EXCLUDED(mu_) {
- const int64 num_elements = key.dim_size(0);
+ const int64 num_elements = (key.dims() == 0) ? 1 : key.dim_size(0);
const int64 key_size = key_shape_.num_elements();
const int64 value_size = value_shape_.num_elements();
if (key.NumElements() != num_elements * key_size) {
@@ -403,8 +403,9 @@ class MutableDenseHashTable final : public LookupInterface {
Status Insert(OpKernelContext* ctx, const Tensor& key,
const Tensor& value) override LOCKS_EXCLUDED(mu_) {
- if (key.NumElements() != key.dim_size(0) * key_shape_.num_elements()) {
- TensorShape expected_shape({key.dim_size(0)});
+ const int64 batch_size = (key.dims() == 0) ? 1 : key.dim_size(0);
+ if (key.NumElements() != batch_size * key_shape_.num_elements()) {
+ TensorShape expected_shape({batch_size});
expected_shape.AppendShape(key_shape_);
return errors::InvalidArgument("Expected key shape ",
expected_shape.DebugString(), " got ",
@@ -415,7 +416,7 @@ class MutableDenseHashTable final : public LookupInterface {
// rather than updates. That means we may grow the table even though we
// don't need to. As long as the number of keys inserted in one call is
// small compared to the size of the map, the impact of this is minimal.
- const int64 pending_num_entries = num_entries_ + key.dim_size(0);
+ const int64 pending_num_entries = num_entries_ + batch_size;
if (pending_num_entries > num_buckets_ * max_load_factor_) {
int64 new_num_buckets = num_buckets_;
do {
@@ -500,7 +501,7 @@ class MutableDenseHashTable final : public LookupInterface {
private:
Status DoInsert(OpKernelContext* ctx, const Tensor& key, const Tensor& value,
bool ignore_empty_key) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- const int64 num_elements = key.dim_size(0);
+ const int64 num_elements = (key.dims() == 0) ? 1 : key.dim_size(0);
const int64 value_size = value_shape_.num_elements();
const int64 key_size = key_shape_.num_elements();
const auto key_matrix = key.shaped<K, 2>({num_elements, key_size});
@@ -812,17 +813,21 @@ REGISTER_KERNEL_BUILDER(Name("LookupTableImportV2").Device(DEVICE_CPU),
LookupTableOp<lookup::HashTable<key_dtype, value_dtype>, key_dtype, \
value_dtype>)
+REGISTER_KERNEL(int32, double);
+REGISTER_KERNEL(int32, float);
+REGISTER_KERNEL(int32, int32);
+REGISTER_KERNEL(int32, string);
+REGISTER_KERNEL(int64, double);
+REGISTER_KERNEL(int64, float);
+REGISTER_KERNEL(int64, int32);
+REGISTER_KERNEL(int64, int64);
+REGISTER_KERNEL(int64, string);
+REGISTER_KERNEL(string, bool);
REGISTER_KERNEL(string, double);
REGISTER_KERNEL(string, float);
REGISTER_KERNEL(string, int32);
REGISTER_KERNEL(string, int64);
-REGISTER_KERNEL(int64, string);
-REGISTER_KERNEL(int64, int64);
-REGISTER_KERNEL(int64, float);
REGISTER_KERNEL(string, string);
-REGISTER_KERNEL(string, bool);
-REGISTER_KERNEL(int32, int32);
-REGISTER_KERNEL(int32, string);
#undef REGISTER_KERNEL
@@ -843,12 +848,20 @@ REGISTER_KERNEL(int32, string);
LookupTableOp<lookup::MutableHashTableOfScalars<key_dtype, value_dtype>, \
key_dtype, value_dtype>)
-REGISTER_KERNEL(string, float);
-REGISTER_KERNEL(string, int64);
-REGISTER_KERNEL(int64, string);
-REGISTER_KERNEL(string, bool);
+REGISTER_KERNEL(int32, double);
+REGISTER_KERNEL(int32, float);
+REGISTER_KERNEL(int32, int32);
+REGISTER_KERNEL(int64, double);
REGISTER_KERNEL(int64, float);
+REGISTER_KERNEL(int64, int32);
+REGISTER_KERNEL(int64, int64);
+REGISTER_KERNEL(int64, string);
REGISTER_KERNEL(int64, Variant);
+REGISTER_KERNEL(string, bool);
+REGISTER_KERNEL(string, double);
+REGISTER_KERNEL(string, float);
+REGISTER_KERNEL(string, int32);
+REGISTER_KERNEL(string, int64);
#undef REGISTER_KERNEL
@@ -869,10 +882,19 @@ REGISTER_KERNEL(int64, Variant);
LookupTableOp<lookup::MutableHashTableOfTensors<key_dtype, value_dtype>, \
key_dtype, value_dtype>)
-REGISTER_KERNEL(string, float);
-REGISTER_KERNEL(string, int64);
+REGISTER_KERNEL(int32, double);
+REGISTER_KERNEL(int32, float);
+REGISTER_KERNEL(int32, int32);
+REGISTER_KERNEL(int64, double);
+REGISTER_KERNEL(int64, float);
+REGISTER_KERNEL(int64, int32);
+REGISTER_KERNEL(int64, int64);
REGISTER_KERNEL(int64, string);
REGISTER_KERNEL(string, bool);
+REGISTER_KERNEL(string, double);
+REGISTER_KERNEL(string, float);
+REGISTER_KERNEL(string, int32);
+REGISTER_KERNEL(string, int64);
#undef REGISTER_KERNEL
@@ -893,13 +915,20 @@ REGISTER_KERNEL(string, bool);
LookupTableOp<lookup::MutableDenseHashTable<key_dtype, value_dtype>, \
key_dtype, value_dtype>)
-REGISTER_KERNEL(int64, int64);
-REGISTER_KERNEL(int64, float);
-REGISTER_KERNEL(int64, double);
-REGISTER_KERNEL(string, float);
-REGISTER_KERNEL(string, bool);
+REGISTER_KERNEL(int32, double);
+REGISTER_KERNEL(int32, float);
+REGISTER_KERNEL(int32, int32);
REGISTER_KERNEL(int64, bool);
+REGISTER_KERNEL(int64, double);
+REGISTER_KERNEL(int64, float);
+REGISTER_KERNEL(int64, int32);
+REGISTER_KERNEL(int64, int64);
REGISTER_KERNEL(int64, Variant);
+REGISTER_KERNEL(string, bool);
+REGISTER_KERNEL(string, double);
+REGISTER_KERNEL(string, float);
+REGISTER_KERNEL(string, int32);
+REGISTER_KERNEL(string, int64);
#undef REGISTER_KERNEL