aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Yutaka Leon <yleon@google.com>2016-12-06 15:13:30 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-06 16:12:00 -0800
commita122c5146250fda93db5b28baf28c88ee56f8355 (patch)
tree05c7e0f37c893124d9437a21c0d58d359fdcf4a2
parent89d31de81ccdf1c70e9bfb7fb9881db2d378e609 (diff)
Minor updates in lookup table names scopes.
Change: 141233623
-rw-r--r--tensorflow/contrib/lookup/lookup_ops.py29
1 files changed, 13 insertions, 16 deletions
diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py
index 5047d8d87b..2e449afcfa 100644
--- a/tensorflow/contrib/lookup/lookup_ops.py
+++ b/tensorflow/contrib/lookup/lookup_ops.py
@@ -141,11 +141,11 @@ class InitializableLookupTableBase(LookupInterface):
Returns:
A scalar tensor containing the number of elements in this table.
"""
- if name is None:
- name = "%s_Size" % self._name
- # pylint: disable=protected-access
- return gen_data_flow_ops._lookup_table_size(self._table_ref, name=name)
- # pylint: enable=protected-access
+ with ops.name_scope(name, "%s_Size" % self._name,
+ [self._table_ref]) as scope:
+ # pylint: disable=protected-access
+ return gen_data_flow_ops._lookup_table_size(self._table_ref, name=scope)
+ # pylint: enable=protected-access
def lookup(self, keys, name=None):
"""Looks up `keys` in a table, outputs the corresponding values.
@@ -163,9 +163,6 @@ class InitializableLookupTableBase(LookupInterface):
TypeError: when `keys` or `default_value` doesn't match the table data
types.
"""
- if name is None:
- name = "%s_lookup_table_find" % self._name
-
key_tensor = keys
if isinstance(keys, sparse_tensor.SparseTensor):
key_tensor = keys.values
@@ -174,12 +171,12 @@ class InitializableLookupTableBase(LookupInterface):
raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
(self._key_dtype, keys.dtype))
- # pylint: disable=protected-access
- values = gen_data_flow_ops._lookup_table_find(self._table_ref,
- key_tensor,
- self._default_value,
- name=name)
- # pylint: enable=protected-access
+ with ops.name_scope(name, "%s_Lookup" % self._name,
+ [self._table_ref]) as scope:
+ # pylint: disable=protected-access
+ values = gen_data_flow_ops._lookup_table_find(
+ self._table_ref, key_tensor, self._default_value, name=scope)
+ # pylint: enable=protected-access
values.set_shape(key_tensor.get_shape())
if isinstance(keys, sparse_tensor.SparseTensor):
@@ -220,13 +217,13 @@ class HashTable(InitializableLookupTableBase):
Returns:
A `HashTable` object.
"""
- with ops.name_scope(name, "hash_table", [initializer]):
+ with ops.name_scope(name, "hash_table", [initializer]) as scope:
# pylint: disable=protected-access
table_ref = gen_data_flow_ops._hash_table(
shared_name=shared_name,
key_dtype=initializer.key_dtype,
value_dtype=initializer.value_dtype,
- name=name)
+ name=scope)
# pylint: enable=protected-access
super(HashTable, self).__init__(table_ref, default_value, initializer)