aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lookup
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@google.com>2018-08-13 13:03:13 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-13 13:07:41 -0700
commit35535e6313c7c35a851466efd67be0ec1df14c9e (patch)
tree199000b261fdc9b3331650580e7883e768a36c07 /tensorflow/contrib/lookup
parent959f075558b33674c201367aef4bfc9c2dc116c4 (diff)
[tf.contrib.lookup] Clean up shape inference for lookup ops.
More of the shape inference can be done in C++-land, which may help grappler do its thing. Also fix a bug where keys.dim_size(0) was being requested even when keys.dims() == 0 [this should probably lead to DCHECK failure, but doesn't seem to]. PiperOrigin-RevId: 208529368
Diffstat (limited to 'tensorflow/contrib/lookup')
-rw-r--r--tensorflow/contrib/lookup/lookup_ops.py33
-rw-r--r--tensorflow/contrib/lookup/lookup_ops_test.py11
2 files changed, 15 insertions, 29 deletions
diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py
index 4942d94176..8c0bfefb30 100644
--- a/tensorflow/contrib/lookup/lookup_ops.py
+++ b/tensorflow/contrib/lookup/lookup_ops.py
@@ -20,7 +20,6 @@ from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import gen_lookup_ops
from tensorflow.python.ops import lookup_ops
# pylint: disable=unused-import
@@ -395,17 +394,12 @@ class MutableHashTable(LookupInterface):
Raises:
TypeError: when `keys` do not match the table data types.
"""
- if keys.dtype.base_dtype != self._key_dtype:
- raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
- (self._key_dtype, keys.dtype))
-
with ops.name_scope(name, "%s_lookup_table_find" % self._name,
(self._table_ref, keys, self._default_value)) as name:
+ keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys")
with ops.colocate_with(self._table_ref):
values = gen_lookup_ops.lookup_table_find_v2(
self._table_ref, keys, self._default_value, name=name)
-
- values.set_shape(keys.get_shape().concatenate(self._value_shape))
return values
def insert(self, keys, values, name=None):
@@ -451,9 +445,6 @@ class MutableHashTable(LookupInterface):
with ops.colocate_with(self._table_ref):
exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2(
self._table_ref, self._key_dtype, self._value_dtype, name=name)
-
- exported_values.set_shape(exported_keys.get_shape().concatenate(
- self._value_shape))
return exported_keys, exported_values
class _Saveable(BaseSaverBuilder.SaveableObject):
@@ -537,14 +528,15 @@ class MutableDenseHashTable(LookupInterface):
ValueError: If checkpoint is True and no name was specified.
"""
self._default_value = ops.convert_to_tensor(
- default_value, dtype=value_dtype)
+ default_value, dtype=value_dtype, name="default_value")
self._value_shape = self._default_value.get_shape()
# The table must be shared if checkpointing is requested for multi-worker
# training to work correctly. Use the node name if no shared_name has been
# explicitly specified.
use_node_name_sharing = checkpoint and shared_name is None
- empty_key = ops.convert_to_tensor(empty_key, dtype=key_dtype)
+ empty_key = ops.convert_to_tensor(
+ empty_key, dtype=key_dtype, name="empty_key")
self._table_ref = gen_lookup_ops.mutable_dense_hash_table_v2(
empty_key=empty_key,
shared_name=shared_name,
@@ -591,20 +583,13 @@ class MutableDenseHashTable(LookupInterface):
Raises:
TypeError: when `keys` do not match the table data types.
"""
- if keys.dtype.base_dtype != self._key_dtype:
- raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
- (self._key_dtype, keys.dtype))
-
with ops.name_scope(name, "%s_lookup_table_find" % self._name,
[self._table_ref, keys]) as name:
+ keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys")
with ops.colocate_with(self._table_ref):
values = gen_lookup_ops.lookup_table_find_v2(
self._table_ref, keys, self._default_value, name=name)
- if keys.get_shape().ndims is not None and keys.get_shape().ndims > 0:
- values.set_shape(
- tensor_shape.TensorShape([keys.get_shape().dims[0]]).concatenate(
- self._value_shape))
return values
def insert(self, keys, values, name=None):
@@ -624,11 +609,11 @@ class MutableDenseHashTable(LookupInterface):
TypeError: when `keys` or `values` doesn't match the table data
types.
"""
- # pylint: disable=protected-access
- lookup_ops._check_table_dtypes(self, keys.dtype, values.dtype)
- # pylint: enable=protected-access
with ops.name_scope(name, "%s_lookup_table_insert" % self._name,
[self._table_ref, keys, values]) as name:
+ keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys")
+ values = ops.convert_to_tensor(
+ values, dtype=self._value_dtype, name="values")
with ops.colocate_with(self._table_ref):
op = gen_lookup_ops.lookup_table_insert_v2(
self._table_ref, keys, values, name=name)
@@ -650,8 +635,6 @@ class MutableDenseHashTable(LookupInterface):
exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2(
self._table_ref, self._key_dtype, self._value_dtype, name=name)
- exported_values.set_shape(exported_keys.get_shape().concatenate(
- self._value_shape))
return exported_keys, exported_values
class _Saveable(BaseSaverBuilder.SaveableObject):
diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py
index 8d510ede58..6fb5244fc6 100644
--- a/tensorflow/contrib/lookup/lookup_ops_test.py
+++ b/tensorflow/contrib/lookup/lookup_ops_test.py
@@ -434,8 +434,10 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllEqual([[0, 1], [2, 3], [-1, -1]], result)
exported_keys, exported_values = table.export()
- self.assertAllEqual([None], exported_keys.get_shape().as_list())
- self.assertAllEqual([None, 2], exported_values.get_shape().as_list())
+ self.assertAllEqual([None], exported_keys.get_shape().as_list(),
+ msg="Saw shape %s" % exported_keys.shape)
+ self.assertAllEqual([None, 2], exported_values.get_shape().as_list(),
+ msg="Saw shape %s" % exported_values.shape)
# exported data is in the order of the internal map, i.e. undefined
sorted_keys = np.sort(exported_keys.eval())
sorted_values = np.sort(exported_values.eval())
@@ -669,7 +671,7 @@ class MutableHashTableOpTest(test.TestCase):
# lookup with keys of the wrong type
input_string = constant_op.constant([1, 2, 3], dtypes.int64)
- with self.assertRaises(TypeError):
+ with self.assertRaises(ValueError):
table.lookup(input_string).eval()
# default value of the wrong type
@@ -853,7 +855,8 @@ class MutableDenseHashTableOpTest(test.TestCase):
input_string = constant_op.constant([11, 12, 15], dtypes.int64)
output = table.lookup(input_string)
- self.assertAllEqual([3, 4], output.get_shape())
+ self.assertAllEqual(
+ [3, 4], output.shape, msg="Saw shape: %s" % output.shape)
result = output.eval()
self.assertAllEqual([[0, 1, 2, 3], [3, 4, 5, 6], [-1, -2, -3, -4]],