aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lookup
diff options
context:
space:
mode:
authorGravatar RJ Ryan <rjryan@google.com>2018-01-14 22:32:17 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-14 22:36:12 -0800
commit4e1c807999234b45a33c7fb3dabb0c34ab6ac185 (patch)
treeaf1cdd71ca60cb0f007d4dc50985ae180effb9f5 /tensorflow/contrib/lookup
parenta79e97be8460ce3e1a7de2ddbc78b76151e0035a (diff)
Support ref types as arguments to all instances of tf.contrib.lookup.LookupInterface.lookup.
PiperOrigin-RevId: 181928781
Diffstat (limited to 'tensorflow/contrib/lookup')
-rw-r--r--tensorflow/contrib/lookup/lookup_ops.py4
-rw-r--r--tensorflow/contrib/lookup/lookup_ops_test.py16
2 files changed, 18 insertions, 2 deletions
diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py
index 66caa6a2e5..a430dac4ec 100644
--- a/tensorflow/contrib/lookup/lookup_ops.py
+++ b/tensorflow/contrib/lookup/lookup_ops.py
@@ -399,7 +399,7 @@ class MutableHashTable(LookupInterface):
Raises:
TypeError: when `keys` do not match the table data types.
"""
- if keys.dtype != self._key_dtype:
+ if keys.dtype.base_dtype != self._key_dtype:
raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
(self._key_dtype, keys.dtype))
@@ -600,7 +600,7 @@ class MutableDenseHashTable(LookupInterface):
Raises:
TypeError: when `keys` do not match the table data types.
"""
- if keys.dtype != self._key_dtype:
+ if keys.dtype.base_dtype != self._key_dtype:
raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
(self._key_dtype, keys.dtype))
diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py
index f0499010d4..65aaaf85c3 100644
--- a/tensorflow/contrib/lookup/lookup_ops_test.py
+++ b/tensorflow/contrib/lookup/lookup_ops_test.py
@@ -187,6 +187,11 @@ class HashTableOpTest(test.TestCase):
lookup.KeyValueTensorInitializer(keys, values), default_val)
table.init.run()
+ # Ref types do not produce a lookup signature mismatch.
+ input_string_ref = variables.Variable("brain")
+ variables.global_variables_initializer().run()
+ self.assertEqual(0, table.lookup(input_string_ref).eval())
+
input_string = constant_op.constant([1, 2, 3], dtypes.int64)
with self.assertRaises(TypeError):
table.lookup(input_string)
@@ -629,6 +634,17 @@ class MutableHashTableOpTest(test.TestCase):
table.insert(keys, values).run()
self.assertAllEqual(3, table.size().eval())
+ input_string_ref = variables.Variable("brain")
+ input_int64_ref = variables.Variable(-1, dtype=dtypes.int64)
+ variables.global_variables_initializer().run()
+
+ # Ref types do not produce an insert signature mismatch.
+ table.insert(input_string_ref, input_int64_ref).run()
+ self.assertAllEqual(3, table.size().eval())
+
+ # Ref types do not produce a lookup signature mismatch.
+ self.assertEqual(-1, table.lookup(input_string_ref).eval())
+
# lookup with keys of the wrong type
input_string = constant_op.constant([1, 2, 3], dtypes.int64)
with self.assertRaises(TypeError):