diff options
4 files changed, 23 insertions, 7 deletions
diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py index 7e214905b1..ec726bbed4 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py @@ -102,7 +102,7 @@ class ShardedMutableDenseHashTable(lookup.LookupInterface): keys.get_shape()) def lookup(self, keys, name=None): - if keys.dtype != self._key_dtype: + if keys.dtype.base_dtype != self._key_dtype: raise TypeError('Signature mismatch. Keys must be dtype %s, got %s.' % (self._key_dtype, keys.dtype)) self._check_keys(keys) diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py index 66caa6a2e5..a430dac4ec 100644 --- a/tensorflow/contrib/lookup/lookup_ops.py +++ b/tensorflow/contrib/lookup/lookup_ops.py @@ -399,7 +399,7 @@ class MutableHashTable(LookupInterface): Raises: TypeError: when `keys` do not match the table data types. """ - if keys.dtype != self._key_dtype: + if keys.dtype.base_dtype != self._key_dtype: raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." % (self._key_dtype, keys.dtype)) @@ -600,7 +600,7 @@ class MutableDenseHashTable(LookupInterface): Raises: TypeError: when `keys` do not match the table data types. """ - if keys.dtype != self._key_dtype: + if keys.dtype.base_dtype != self._key_dtype: raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." % (self._key_dtype, keys.dtype)) diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py index f0499010d4..65aaaf85c3 100644 --- a/tensorflow/contrib/lookup/lookup_ops_test.py +++ b/tensorflow/contrib/lookup/lookup_ops_test.py @@ -187,6 +187,11 @@ class HashTableOpTest(test.TestCase): lookup.KeyValueTensorInitializer(keys, values), default_val) table.init.run() + # Ref types do not produce a lookup signature mismatch. + input_string_ref = variables.Variable("brain") + variables.global_variables_initializer().run() + self.assertEqual(0, table.lookup(input_string_ref).eval()) + input_string = constant_op.constant([1, 2, 3], dtypes.int64) with self.assertRaises(TypeError): table.lookup(input_string) @@ -629,6 +634,17 @@ class MutableHashTableOpTest(test.TestCase): table.insert(keys, values).run() self.assertAllEqual(3, table.size().eval()) + input_string_ref = variables.Variable("brain") + input_int64_ref = variables.Variable(-1, dtype=dtypes.int64) + variables.global_variables_initializer().run() + + # Ref types do not produce an insert signature mismatch. + table.insert(input_string_ref, input_int64_ref).run() + self.assertAllEqual(3, table.size().eval()) + + # Ref types do not produce a lookup signature mismatch. + self.assertEqual(-1, table.lookup(input_string_ref).eval()) + # lookup with keys of the wrong type input_string = constant_op.constant([1, 2, 3], dtypes.int64) with self.assertRaises(TypeError): diff --git a/tensorflow/python/ops/lookup_ops.py b/tensorflow/python/ops/lookup_ops.py index d68b32cc6b..b2ad4ad2e8 100644 --- a/tensorflow/python/ops/lookup_ops.py +++ b/tensorflow/python/ops/lookup_ops.py @@ -84,10 +84,10 @@ def _check_table_dtypes(table, key_dtype, value_dtype): TypeError: when 'key_dtype' or 'value_dtype' doesn't match the table data types. """ - if key_dtype != table.key_dtype: + if key_dtype.base_dtype != table.key_dtype: raise TypeError("Invalid key dtype, expected %s but got %s." % (table.key_dtype, key_dtype)) - if value_dtype != table.value_dtype: + if value_dtype.base_dtype != table.value_dtype: raise TypeError("Invalid value dtype, expected %s but got %s." % (table.value_dtype, value_dtype)) @@ -217,7 +217,7 @@ class InitializableLookupTableBase(LookupInterface): if isinstance(keys, sparse_tensor.SparseTensor): key_tensor = keys.values - if keys.dtype != self._key_dtype: + if keys.dtype.base_dtype != self._key_dtype: raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." % (self._key_dtype, keys.dtype)) @@ -849,7 +849,7 @@ class IdTableWithHashBuckets(LookupInterface): Raises: TypeError: when `keys` doesn't match the table key data type. """ - if keys.dtype != self._key_dtype: + if keys.dtype.base_dtype != self._key_dtype: raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." % (self._key_dtype, keys.dtype)) values = keys |