aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lookup/lookup_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lookup/lookup_ops.py')
-rw-r--r--tensorflow/contrib/lookup/lookup_ops.py33
1 files changed, 8 insertions, 25 deletions
diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py
index 4942d94176..8c0bfefb30 100644
--- a/tensorflow/contrib/lookup/lookup_ops.py
+++ b/tensorflow/contrib/lookup/lookup_ops.py
@@ -20,7 +20,6 @@ from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import gen_lookup_ops
from tensorflow.python.ops import lookup_ops
# pylint: disable=unused-import
@@ -395,17 +394,12 @@ class MutableHashTable(LookupInterface):
Raises:
TypeError: when `keys` do not match the table data types.
"""
- if keys.dtype.base_dtype != self._key_dtype:
- raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
- (self._key_dtype, keys.dtype))
-
with ops.name_scope(name, "%s_lookup_table_find" % self._name,
(self._table_ref, keys, self._default_value)) as name:
+ keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys")
with ops.colocate_with(self._table_ref):
values = gen_lookup_ops.lookup_table_find_v2(
self._table_ref, keys, self._default_value, name=name)
-
- values.set_shape(keys.get_shape().concatenate(self._value_shape))
return values
def insert(self, keys, values, name=None):
@@ -451,9 +445,6 @@ class MutableHashTable(LookupInterface):
with ops.colocate_with(self._table_ref):
exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2(
self._table_ref, self._key_dtype, self._value_dtype, name=name)
-
- exported_values.set_shape(exported_keys.get_shape().concatenate(
- self._value_shape))
return exported_keys, exported_values
class _Saveable(BaseSaverBuilder.SaveableObject):
@@ -537,14 +528,15 @@ class MutableDenseHashTable(LookupInterface):
ValueError: If checkpoint is True and no name was specified.
"""
self._default_value = ops.convert_to_tensor(
- default_value, dtype=value_dtype)
+ default_value, dtype=value_dtype, name="default_value")
self._value_shape = self._default_value.get_shape()
# The table must be shared if checkpointing is requested for multi-worker
# training to work correctly. Use the node name if no shared_name has been
# explicitly specified.
use_node_name_sharing = checkpoint and shared_name is None
- empty_key = ops.convert_to_tensor(empty_key, dtype=key_dtype)
+ empty_key = ops.convert_to_tensor(
+ empty_key, dtype=key_dtype, name="empty_key")
self._table_ref = gen_lookup_ops.mutable_dense_hash_table_v2(
empty_key=empty_key,
shared_name=shared_name,
@@ -591,20 +583,13 @@ class MutableDenseHashTable(LookupInterface):
Raises:
TypeError: when `keys` do not match the table data types.
"""
- if keys.dtype.base_dtype != self._key_dtype:
- raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
- (self._key_dtype, keys.dtype))
-
with ops.name_scope(name, "%s_lookup_table_find" % self._name,
[self._table_ref, keys]) as name:
+ keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys")
with ops.colocate_with(self._table_ref):
values = gen_lookup_ops.lookup_table_find_v2(
self._table_ref, keys, self._default_value, name=name)
- if keys.get_shape().ndims is not None and keys.get_shape().ndims > 0:
- values.set_shape(
- tensor_shape.TensorShape([keys.get_shape().dims[0]]).concatenate(
- self._value_shape))
return values
def insert(self, keys, values, name=None):
@@ -624,11 +609,11 @@ class MutableDenseHashTable(LookupInterface):
TypeError: when `keys` or `values` doesn't match the table data
types.
"""
- # pylint: disable=protected-access
- lookup_ops._check_table_dtypes(self, keys.dtype, values.dtype)
- # pylint: enable=protected-access
with ops.name_scope(name, "%s_lookup_table_insert" % self._name,
[self._table_ref, keys, values]) as name:
+ keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys")
+ values = ops.convert_to_tensor(
+ values, dtype=self._value_dtype, name="values")
with ops.colocate_with(self._table_ref):
op = gen_lookup_ops.lookup_table_insert_v2(
self._table_ref, keys, values, name=name)
@@ -650,8 +635,6 @@ class MutableDenseHashTable(LookupInterface):
exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2(
self._table_ref, self._key_dtype, self._value_dtype, name=name)
- exported_values.set_shape(exported_keys.get_shape().concatenate(
- self._value_shape))
return exported_keys, exported_values
class _Saveable(BaseSaverBuilder.SaveableObject):