diff options
author | RJ Ryan <rjryan@google.com> | 2018-01-14 22:32:17 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-01-14 22:36:12 -0800 |
commit | 4e1c807999234b45a33c7fb3dabb0c34ab6ac185 (patch) | |
tree | af1cdd71ca60cb0f007d4dc50985ae180effb9f5 /tensorflow/contrib/lookup | |
parent | a79e97be8460ce3e1a7de2ddbc78b76151e0035a (diff) |
Support ref types as arguments to all instances of tf.contrib.lookup.LookupInterface.lookup.
PiperOrigin-RevId: 181928781
Diffstat (limited to 'tensorflow/contrib/lookup')
-rw-r--r-- | tensorflow/contrib/lookup/lookup_ops.py | 4 | ||||
-rw-r--r-- | tensorflow/contrib/lookup/lookup_ops_test.py | 16 |
2 files changed, 18 insertions, 2 deletions
diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py index 66caa6a2e5..a430dac4ec 100644 --- a/tensorflow/contrib/lookup/lookup_ops.py +++ b/tensorflow/contrib/lookup/lookup_ops.py @@ -399,7 +399,7 @@ class MutableHashTable(LookupInterface): Raises: TypeError: when `keys` do not match the table data types. """ - if keys.dtype != self._key_dtype: + if keys.dtype.base_dtype != self._key_dtype: raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." % (self._key_dtype, keys.dtype)) @@ -600,7 +600,7 @@ class MutableDenseHashTable(LookupInterface): Raises: TypeError: when `keys` do not match the table data types. """ - if keys.dtype != self._key_dtype: + if keys.dtype.base_dtype != self._key_dtype: raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." % (self._key_dtype, keys.dtype)) diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py index f0499010d4..65aaaf85c3 100644 --- a/tensorflow/contrib/lookup/lookup_ops_test.py +++ b/tensorflow/contrib/lookup/lookup_ops_test.py @@ -187,6 +187,11 @@ class HashTableOpTest(test.TestCase): lookup.KeyValueTensorInitializer(keys, values), default_val) table.init.run() + # Ref types do not produce a lookup signature mismatch. + input_string_ref = variables.Variable("brain") + variables.global_variables_initializer().run() + self.assertEqual(0, table.lookup(input_string_ref).eval()) + input_string = constant_op.constant([1, 2, 3], dtypes.int64) with self.assertRaises(TypeError): table.lookup(input_string) @@ -629,6 +634,17 @@ class MutableHashTableOpTest(test.TestCase): table.insert(keys, values).run() self.assertAllEqual(3, table.size().eval()) + input_string_ref = variables.Variable("brain") + input_int64_ref = variables.Variable(-1, dtype=dtypes.int64) + variables.global_variables_initializer().run() + + # Ref types do not produce an insert signature mismatch. + table.insert(input_string_ref, input_int64_ref).run() + self.assertAllEqual(3, table.size().eval()) + + # Ref types do not produce a lookup signature mismatch. + self.assertEqual(-1, table.lookup(input_string_ref).eval()) + # lookup with keys of the wrong type input_string = constant_op.constant([1, 2, 3], dtypes.int64) with self.assertRaises(TypeError): |