aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar RJ Ryan <rjryan@google.com>2018-01-14 22:32:17 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-14 22:36:12 -0800
commit4e1c807999234b45a33c7fb3dabb0c34ab6ac185 (patch)
treeaf1cdd71ca60cb0f007d4dc50985ae180effb9f5
parenta79e97be8460ce3e1a7de2ddbc78b76151e0035a (diff)
Support ref types as arguments to all instances of tf.contrib.lookup.LookupInterface.lookup.
PiperOrigin-RevId: 181928781
-rw-r--r--tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py2
-rw-r--r--tensorflow/contrib/lookup/lookup_ops.py4
-rw-r--r--tensorflow/contrib/lookup/lookup_ops_test.py16
-rw-r--r--tensorflow/python/ops/lookup_ops.py8
4 files changed, 23 insertions, 7 deletions
diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py
index 7e214905b1..ec726bbed4 100644
--- a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py
+++ b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py
@@ -102,7 +102,7 @@ class ShardedMutableDenseHashTable(lookup.LookupInterface):
keys.get_shape())
def lookup(self, keys, name=None):
- if keys.dtype != self._key_dtype:
+ if keys.dtype.base_dtype != self._key_dtype:
raise TypeError('Signature mismatch. Keys must be dtype %s, got %s.' %
(self._key_dtype, keys.dtype))
self._check_keys(keys)
diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py
index 66caa6a2e5..a430dac4ec 100644
--- a/tensorflow/contrib/lookup/lookup_ops.py
+++ b/tensorflow/contrib/lookup/lookup_ops.py
@@ -399,7 +399,7 @@ class MutableHashTable(LookupInterface):
Raises:
TypeError: when `keys` do not match the table data types.
"""
- if keys.dtype != self._key_dtype:
+ if keys.dtype.base_dtype != self._key_dtype:
raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
(self._key_dtype, keys.dtype))
@@ -600,7 +600,7 @@ class MutableDenseHashTable(LookupInterface):
Raises:
TypeError: when `keys` do not match the table data types.
"""
- if keys.dtype != self._key_dtype:
+ if keys.dtype.base_dtype != self._key_dtype:
raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
(self._key_dtype, keys.dtype))
diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py
index f0499010d4..65aaaf85c3 100644
--- a/tensorflow/contrib/lookup/lookup_ops_test.py
+++ b/tensorflow/contrib/lookup/lookup_ops_test.py
@@ -187,6 +187,11 @@ class HashTableOpTest(test.TestCase):
lookup.KeyValueTensorInitializer(keys, values), default_val)
table.init.run()
+ # Ref types do not produce a lookup signature mismatch.
+ input_string_ref = variables.Variable("brain")
+ variables.global_variables_initializer().run()
+ self.assertEqual(0, table.lookup(input_string_ref).eval())
+
input_string = constant_op.constant([1, 2, 3], dtypes.int64)
with self.assertRaises(TypeError):
table.lookup(input_string)
@@ -629,6 +634,17 @@ class MutableHashTableOpTest(test.TestCase):
table.insert(keys, values).run()
self.assertAllEqual(3, table.size().eval())
+ input_string_ref = variables.Variable("brain")
+ input_int64_ref = variables.Variable(-1, dtype=dtypes.int64)
+ variables.global_variables_initializer().run()
+
+ # Ref types do not produce an insert signature mismatch.
+ table.insert(input_string_ref, input_int64_ref).run()
+ self.assertAllEqual(3, table.size().eval())
+
+ # Ref types do not produce a lookup signature mismatch.
+ self.assertEqual(-1, table.lookup(input_string_ref).eval())
+
# lookup with keys of the wrong type
input_string = constant_op.constant([1, 2, 3], dtypes.int64)
with self.assertRaises(TypeError):
diff --git a/tensorflow/python/ops/lookup_ops.py b/tensorflow/python/ops/lookup_ops.py
index d68b32cc6b..b2ad4ad2e8 100644
--- a/tensorflow/python/ops/lookup_ops.py
+++ b/tensorflow/python/ops/lookup_ops.py
@@ -84,10 +84,10 @@ def _check_table_dtypes(table, key_dtype, value_dtype):
TypeError: when 'key_dtype' or 'value_dtype' doesn't match the table data
types.
"""
- if key_dtype != table.key_dtype:
+ if key_dtype.base_dtype != table.key_dtype:
raise TypeError("Invalid key dtype, expected %s but got %s." %
(table.key_dtype, key_dtype))
- if value_dtype != table.value_dtype:
+ if value_dtype.base_dtype != table.value_dtype:
raise TypeError("Invalid value dtype, expected %s but got %s." %
(table.value_dtype, value_dtype))
@@ -217,7 +217,7 @@ class InitializableLookupTableBase(LookupInterface):
if isinstance(keys, sparse_tensor.SparseTensor):
key_tensor = keys.values
- if keys.dtype != self._key_dtype:
+ if keys.dtype.base_dtype != self._key_dtype:
raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
(self._key_dtype, keys.dtype))
@@ -849,7 +849,7 @@ class IdTableWithHashBuckets(LookupInterface):
Raises:
TypeError: when `keys` doesn't match the table key data type.
"""
- if keys.dtype != self._key_dtype:
+ if keys.dtype.base_dtype != self._key_dtype:
raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
(self._key_dtype, keys.dtype))
values = keys