diff options
author | Yutaka Leon <yleon@google.com> | 2016-12-06 15:13:30 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-12-06 16:12:00 -0800 |
commit | a122c5146250fda93db5b28baf28c88ee56f8355 (patch) | |
tree | 05c7e0f37c893124d9437a21c0d58d359fdcf4a2 | |
parent | 89d31de81ccdf1c70e9bfb7fb9881db2d378e609 (diff) |
Minor updates in lookup table names scopes.
Change: 141233623
-rw-r--r-- | tensorflow/contrib/lookup/lookup_ops.py | 29 |
1 files changed, 13 insertions, 16 deletions
diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py index 5047d8d87b..2e449afcfa 100644 --- a/tensorflow/contrib/lookup/lookup_ops.py +++ b/tensorflow/contrib/lookup/lookup_ops.py @@ -141,11 +141,11 @@ class InitializableLookupTableBase(LookupInterface): Returns: A scalar tensor containing the number of elements in this table. """ - if name is None: - name = "%s_Size" % self._name - # pylint: disable=protected-access - return gen_data_flow_ops._lookup_table_size(self._table_ref, name=name) - # pylint: enable=protected-access + with ops.name_scope(name, "%s_Size" % self._name, + [self._table_ref]) as scope: + # pylint: disable=protected-access + return gen_data_flow_ops._lookup_table_size(self._table_ref, name=scope) + # pylint: enable=protected-access def lookup(self, keys, name=None): """Looks up `keys` in a table, outputs the corresponding values. @@ -163,9 +163,6 @@ class InitializableLookupTableBase(LookupInterface): TypeError: when `keys` or `default_value` doesn't match the table data types. """ - if name is None: - name = "%s_lookup_table_find" % self._name - key_tensor = keys if isinstance(keys, sparse_tensor.SparseTensor): key_tensor = keys.values @@ -174,12 +171,12 @@ class InitializableLookupTableBase(LookupInterface): raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." % (self._key_dtype, keys.dtype)) - # pylint: disable=protected-access - values = gen_data_flow_ops._lookup_table_find(self._table_ref, - key_tensor, - self._default_value, - name=name) - # pylint: enable=protected-access + with ops.name_scope(name, "%s_Lookup" % self._name, + [self._table_ref]) as scope: + # pylint: disable=protected-access + values = gen_data_flow_ops._lookup_table_find( + self._table_ref, key_tensor, self._default_value, name=scope) + # pylint: enable=protected-access values.set_shape(key_tensor.get_shape()) if isinstance(keys, sparse_tensor.SparseTensor): @@ -220,13 +217,13 @@ class HashTable(InitializableLookupTableBase): Returns: A `HashTable` object. """ - with ops.name_scope(name, "hash_table", [initializer]): + with ops.name_scope(name, "hash_table", [initializer]) as scope: # pylint: disable=protected-access table_ref = gen_data_flow_ops._hash_table( shared_name=shared_name, key_dtype=initializer.key_dtype, value_dtype=initializer.value_dtype, - name=name) + name=scope) # pylint: enable=protected-access super(HashTable, self).__init__(table_ref, default_value, initializer) |