aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Yutaka Leon <yleon@google.com>2016-12-12 13:51:14 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-12 14:04:51 -0800
commit55735379ccda8a64e49717e95e9e0915e7b8dc8e (patch)
treef436bcceccac86c68885087329b7821c9bb2bc88
parent61eff533884e52641515232ef812ec33e8bd58ea (diff)
Add string_to_index_table, which returns a lookup table that matches strings to indices.
Indices are defined by a vocabulary file or a mapping tensor. It also supports the assignation of hash buckets to out-of-vocabulary terms. Deprecate string_to_index in favor of string_to_index_table. Change: 141806125
-rw-r--r--tensorflow/contrib/lookup/__init__.py3
-rw-r--r--tensorflow/contrib/lookup/lookup_ops.py205
-rw-r--r--tensorflow/contrib/lookup/lookup_ops_test.py159
3 files changed, 351 insertions, 16 deletions
diff --git a/tensorflow/contrib/lookup/__init__.py b/tensorflow/contrib/lookup/__init__.py
index 8717c3fa2b..99bf5fafb8 100644
--- a/tensorflow/contrib/lookup/__init__.py
+++ b/tensorflow/contrib/lookup/__init__.py
@@ -15,9 +15,12 @@
"""Ops for lookup operations.
@@string_to_index
+@@string_to_index_table_from_file
+@@string_to_index_table_from_tensor
@@index_to_string
@@LookupInterface
@@InitializableLookupTableBase
+@@IdTableWithHashBuckets
@@HashTable
@@MutableHashTable
@@TableInitializerBase
diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py
index 20d2627a67..48570dfeb9 100644
--- a/tensorflow/contrib/lookup/lookup_ops.py
+++ b/tensorflow/contrib/lookup/lookup_ops.py
@@ -32,6 +32,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.training.saver import BaseSaverBuilder
from tensorflow.python.util import compat
+from tensorflow.python.util.deprecation import deprecated
class LookupInterface(object):
@@ -775,6 +776,191 @@ class IdTableWithHashBuckets(LookupInterface):
return ids
+def string_to_index_table_from_file(vocabulary_file=None,
+ num_oov_buckets=0,
+ vocab_size=None,
+ default_value=-1,
+ hasher_spec=FastHashSpec,
+ name=None):
+ """Returns a lookup table that converts a string tensor into int64 IDs.
+
+ This operation constructs a lookup table to convert a tensor of strings into
+ int64 IDs. The mapping can be initialized from a vocabulary file specified in
+ `vocabulary_file`, where the whole line is the key and the zero-based line
+ number is the ID.
+
+ Any lookup of an out-of-vocabulary token will return a bucket ID based on its
+ hash if `num_oov_buckets` is greater than zero. Otherwise it is assigned the
+ `default_value`.
+ The bucket ID range is `[vocabulary size, vocabulary size + num_oov_buckets - 1]`.
+
+ The underlying table must be initialized by calling
+ `tf.initialize_all_tables().run()` or `table.init.run()` once.
+
+ Sample Usages:
+
+ If we have a vocabulary file "test.txt" with the following content:
+
+ ```
+ emerson
+ lake
+ palmer
+ ```
+
+ ```python
+ features = tf.constant(["emerson", "lake", "and", "palmer"])
+ table = tf.contrib.lookup.string_to_index_table_from_file(
+ vocabulary_file="test.txt", num_oov_buckets=1)
+ ids = table.lookup(features)
+ ...
+ tf.initialize_all_tables().run()
+
+ ids.eval() ==> [0, 1, 3, 2] # where 3 is the out-of-vocabulary bucket
+ ```
+
+ Args:
+ vocabulary_file: The vocabulary filename.
+ num_oov_buckets: The number of out-of-vocabulary buckets.
+ vocab_size: Number of the elements in the vocabulary, if known.
+ default_value: The value to use for out-of-vocabulary feature values.
+ Defaults to -1.
+ hasher_spec: A `HasherSpec` to specify the hash function to use for
+ assignment of out-of-vocabulary buckets.
+ name: A name for this op (optional).
+
+ Returns:
+ The lookup table to map a string `Tensor` to index `int64` `Tensor`.
+
+ Raises:
+ ValueError: If `vocabulary_file` is not set.
+ ValueError: If `num_oov_buckets` is negative or `vocab_size` is not greater
+ than zero.
+ """
+ if not vocabulary_file:
+ raise ValueError("vocabulary_file must be specified.")
+ if num_oov_buckets < 0:
+ raise ValueError("num_oov_buckets must be greater or equal than 0, got %d."
+ % num_oov_buckets)
+ if vocab_size is not None and vocab_size < 1:
+ raise ValueError("vocab_size must be greater than 0, got %d." % vocab_size)
+
+ with ops.name_scope(name, "string_to_index") as feat_to_id_scope:
+ table = None
+ shared_name = ""
+ with ops.name_scope(None, "hash_table") as hash_table_scope:
+ if vocab_size:
+ # Keep the shared_name:
+ # <table_type>_<filename>_<vocab_size>_<key_index>_<value_index>
+ shared_name = "hash_table_%s_%d_%s_%s" % (vocabulary_file, vocab_size,
+ TextFileIndex.WHOLE_LINE,
+ TextFileIndex.LINE_NUMBER)
+ else:
+ # Keep the shared_name
+ # <table_type>_<filename>_<key_index>_<value_index>
+ shared_name = "hash_table_%s_%s_%s" % (vocabulary_file,
+ TextFileIndex.WHOLE_LINE,
+ TextFileIndex.LINE_NUMBER)
+ init = TextFileIdTableInitializer(
+ vocabulary_file, vocab_size=vocab_size, name="table_init")
+
+ table = HashTable(
+ init, default_value, shared_name=shared_name, name=hash_table_scope)
+ if num_oov_buckets:
+ table = IdTableWithHashBuckets(
+ table,
+ num_oov_buckets=num_oov_buckets,
+ hasher_spec=hasher_spec,
+ name=feat_to_id_scope)
+
+ return table
+
+
+def string_to_index_table_from_tensor(mapping,
+ num_oov_buckets=0,
+ default_value=-1,
+ hasher_spec=FastHashSpec,
+ name=None):
+ """Returns a lookup table that converts a string tensor into int64 IDs.
+
+ This operation constructs a lookup table to convert a tensor of strings into
+ int64 IDs. The mapping can be initialized from a string `mapping` 1-D tensor
+ where each element is a key and corresponding index within the tensor is the
+ value.
+
+ Any lookup of an out-of-vocabulary token will return a bucket ID based on its
+ hash if `num_oov_buckets` is greater than zero. Otherwise it is assigned the
+ `default_value`.
+ The bucket ID range is `[mapping size, mapping size + num_oov_buckets - 1]`.
+
+ The underlying table must be initialized by calling
+ `tf.initialize_all_tables().run()` or `table.init.run()` once.
+
+ Elements in `mapping` cannot have duplicates, otherwise when executing the
+ table initializer op, it will throw a `FailedPreconditionError`.
+
+ Sample Usages:
+
+ ```python
+ mapping_strings = tf.constant(["emerson", "lake", "palmer"])
+ table = tf.contrib.lookup.string_to_index_table_from_tensor(
+ mapping=mapping_strings, num_oov_buckets=1, default_value=-1)
+ features = tf.constant(["emerson", "lake", "and", "palmer"])
+ ids = table.lookup(features)
+ ...
+ tf.initialize_all_tables().run()
+
+ ids.eval() ==> [0, 1, 3, 2]  # where 3 is the out-of-vocabulary bucket
+ ```
+
+ Args:
+ mapping: A 1-D string `Tensor` that specifies the mapping of strings to
+ indices.
+ num_oov_buckets: The number of out-of-vocabulary buckets.
+ default_value: The value to use for out-of-vocabulary feature values.
+ Defaults to -1.
+ hasher_spec: A `HasherSpec` to specify the hash function to use for
+ assignment of out-of-vocabulary buckets.
+ name: A name for this op (optional).
+
+ Returns:
+ The lookup table to map a string `Tensor` to index `int64` `Tensor`.
+
+ Raises:
+ ValueError: `mapping` is invalid.
+ ValueError: If `num_oov_buckets` is negative.
+ """
+ if mapping is None:
+ raise ValueError("mapping must be specified.")
+
+ if num_oov_buckets < 0:
+ raise ValueError("num_oov_buckets must be greater or equal than 0, got %d."
+ % num_oov_buckets)
+
+ with ops.name_scope(name, "string_to_index") as feat_to_id_scope:
+ keys = ops.convert_to_tensor(mapping, dtypes.string)
+ num_elements = array_ops.size(keys)
+ values = math_ops.cast(math_ops.range(num_elements), dtypes.int64)
+
+ shared_name = ""
+ with ops.name_scope(None, "hash_table") as hash_table_scope:
+ init = KeyValueTensorInitializer(
+ keys, values, dtypes.string, dtypes.int64, name="table_init")
+ table = HashTable(
+ init, default_value, shared_name=shared_name, name=hash_table_scope)
+ if num_oov_buckets:
+ table = IdTableWithHashBuckets(
+ table,
+ num_oov_buckets=num_oov_buckets,
+ hasher_spec=hasher_spec,
+ name=feat_to_id_scope)
+
+ return table
+
+
+@deprecated(
+ "2017-01-07", "This op will be removed after the deprecation date. "
+ "Please switch to string_to_index_table_from_tensor and call the lookup "
+ "method of the returned table.")
def string_to_index(tensor, mapping, default_value=-1, name=None):
"""Maps `tensor` of strings into `int64` indices based on `mapping`.
@@ -816,22 +1002,9 @@ def string_to_index(tensor, mapping, default_value=-1, name=None):
The mapped indices. It has the same shape and tensor type (dense or sparse)
as `tensor`.
"""
- with ops.name_scope(name, "string_to_index", [tensor]) as scope:
- shared_name = ""
- keys = ops.convert_to_tensor(mapping, dtypes.string)
- vocab_size = array_ops.size(keys)
- values = math_ops.cast(math_ops.range(vocab_size), dtypes.int64)
- init = KeyValueTensorInitializer(keys,
- values,
- dtypes.string,
- dtypes.int64,
- name="table_init")
-
- t = HashTable(init,
- default_value,
- shared_name=shared_name,
- name="hash_table")
- return t.lookup(tensor, name=scope)
+ table = string_to_index_table_from_tensor(
+ mapping=mapping, default_value=default_value, name=name)
+ return table.lookup(tensor)
def index_to_string(tensor, mapping, default_value="UNK", name=None):
diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py
index 940b624161..40a6d4d70e 100644
--- a/tensorflow/contrib/lookup/lookup_ops_test.py
+++ b/tensorflow/contrib/lookup/lookup_ops_test.py
@@ -1123,6 +1123,165 @@ class MutableDenseHashTableOpTest(tf.test.TestCase):
self.assertAllEqual(0, table2.size().eval())
+class StringToIndexTableFromFile(tf.test.TestCase):
+
+ def _createVocabFile(self, basename):
+ vocabulary_file = os.path.join(self.get_temp_dir(), basename)
+ with open(vocabulary_file, "w") as f:
+ f.write("\n".join(["brain", "salad", "surgery"]) + "\n")
+ return vocabulary_file
+
+ def test_string_to_index_table_from_file(self):
+ vocabulary_file = self._createVocabFile("f2i_vocab1.txt")
+ with self.test_session():
+ table = tf.contrib.lookup.string_to_index_table_from_file(
+ vocabulary_file=vocabulary_file, num_oov_buckets=1)
+ ids = table.lookup(tf.constant(["salad", "surgery", "tarkus"]))
+
+ self.assertRaises(tf.OpError, ids.eval)
+ tf.initialize_all_tables().run()
+ self.assertAllEqual((1, 2, 3), ids.eval())
+
+ def test_string_to_index_table_from_file_with_default_value(self):
+ default_value = -42
+ vocabulary_file = self._createVocabFile("f2i_vocab2.txt")
+ with self.test_session():
+ table = tf.contrib.lookup.string_to_index_table_from_file(
+ vocabulary_file=vocabulary_file, default_value=default_value)
+ ids = table.lookup(tf.constant(["salad", "surgery", "tarkus"]))
+
+ self.assertRaises(tf.OpError, ids.eval)
+ tf.initialize_all_tables().run()
+ self.assertAllEqual((1, 2, default_value), ids.eval())
+
+ def test_string_to_index_table_from_file_with_oov_buckets(self):
+ vocabulary_file = self._createVocabFile("f2i_vocab3.txt")
+ with self.test_session():
+ table = tf.contrib.lookup.string_to_index_table_from_file(
+ vocabulary_file=vocabulary_file, num_oov_buckets=1000)
+ ids = table.lookup(tf.constant(["salad", "surgery", "tarkus", "toccata"]))
+
+ self.assertRaises(tf.OpError, ids.eval)
+ tf.initialize_all_tables().run()
+ self.assertAllEqual(
+ (
+ 1, # From vocabulary file.
+ 2, # From vocabulary file.
+ 867, # 3 + fingerprint("tarkus") mod 1000.
+ 860), # 3 + fingerprint("toccata") mod 1000.
+ ids.eval())
+
+ def test_string_to_index_table_from_file_with_only_oov_buckets(self):
+ self.assertRaises(
+ ValueError,
+ tf.contrib.lookup.string_to_index_table_from_file,
+ vocabulary_file=None)
+
+ def test_string_to_index_table_from_file_with_vocab_size_too_small(self):
+ vocabulary_file = self._createVocabFile("f2i_vocab5.txt")
+ with self.test_session():
+ table = tf.contrib.lookup.string_to_index_table_from_file(
+ vocabulary_file=vocabulary_file, vocab_size=2)
+ ids = table.lookup(tf.constant(["salad", "surgery", "tarkus"]))
+
+ self.assertRaises(tf.OpError, ids.eval)
+ tf.initialize_all_tables().run()
+ self.assertAllEqual((1, -1, -1), ids.eval())
+ self.assertEqual(2, table.size().eval())
+
+ def test_string_to_index_table_from_file_with_vocab_size_too_large(self):
+ vocabulary_file = self._createVocabFile("f2i_vocab6.txt")
+ with self.test_session():
+ table = tf.contrib.lookup.string_to_index_table_from_file(
+ vocabulary_file=vocabulary_file, vocab_size=4)
+ self.assertRaisesRegexp(tf.errors.InvalidArgumentError,
+ "Invalid vocab_size", table.init.run)
+
+ def test_string_to_index_table_from_file_with_vocab_size(self):
+ vocabulary_file = self._createVocabFile("f2i_vocab7.txt")
+
+ self.assertRaises(
+ ValueError,
+ tf.contrib.lookup.string_to_index_table_from_file,
+ vocabulary_file=vocabulary_file,
+ vocab_size=0)
+
+ with self.test_session():
+ table = tf.contrib.lookup.string_to_index_table_from_file(
+ vocabulary_file=vocabulary_file, vocab_size=3)
+ ids = table.lookup(tf.constant(["salad", "surgery", "tarkus"]))
+
+ self.assertRaises(tf.OpError, ids.eval)
+ tf.initialize_all_tables().run()
+ self.assertAllEqual((1, 2, -1), ids.eval())
+ self.assertEqual(3, table.size().eval())
+
+ def test_string_to_index_table_from_file_with_invalid_hashers(self):
+ vocabulary_file = self._createVocabFile("invalid_hasher.txt")
+ with self.test_session():
+ with self.assertRaises(TypeError):
+ tf.contrib.lookup.string_to_index_table_from_file(
+ vocabulary_file=vocabulary_file,
+ vocab_size=3,
+ num_oov_buckets=1,
+ hasher_spec=1)
+
+ table = tf.contrib.lookup.string_to_index_table_from_file(
+ vocabulary_file=vocabulary_file,
+ vocab_size=3,
+ num_oov_buckets=1,
+ hasher_spec=tf.contrib.lookup.HasherSpec("my-awesome-hash", None))
+
+ self.assertRaises(ValueError, table.lookup,
+ tf.constant(["salad", "surgery", "tarkus"]))
+
+
+class StringToIndexTableFromTensor(tf.test.TestCase):
+
+ def test_string_to_index_table_from_tensor_with_tensor_init(self):
+ with self.test_session():
+ table = tf.contrib.lookup.string_to_index_table_from_tensor(
+ mapping=["brain", "salad", "surgery"], num_oov_buckets=1)
+ ids = table.lookup(tf.constant(["salad", "surgery", "tarkus"]))
+
+ self.assertRaises(tf.OpError, ids.eval)
+ tf.initialize_all_tables().run()
+ self.assertAllEqual((1, 2, 3), ids.eval())
+
+ def test_string_to_index_table_from_tensor_with_default_value(self):
+ default_value = -42
+ with self.test_session():
+ table = tf.contrib.lookup.string_to_index_table_from_tensor(
+ mapping=["brain", "salad", "surgery"], default_value=default_value)
+ ids = table.lookup(tf.constant(["salad", "surgery", "tarkus"]))
+
+ self.assertRaises(tf.OpError, ids.eval)
+ tf.initialize_all_tables().run()
+ self.assertAllEqual((1, 2, default_value), ids.eval())
+
+ def test_string_to_index_table_from_tensor_with_only_oov_buckets(self):
+ with self.test_session():
+ with self.assertRaises(ValueError):
+ tf.contrib.lookup.string_to_index_table_from_tensor(
+ mapping=None, num_oov_buckets=1)
+
+ def test_string_to_index_table_from_tensor_with_invalid_hashers(self):
+ with self.test_session():
+ with self.assertRaises(TypeError):
+ tf.contrib.lookup.string_to_index_table_from_tensor(
+ mapping=["brain", "salad", "surgery"],
+ num_oov_buckets=1,
+ hasher_spec=1)
+
+ table = tf.contrib.lookup.string_to_index_table_from_tensor(
+ mapping=["brain", "salad", "surgery"],
+ num_oov_buckets=1,
+ hasher_spec=tf.contrib.lookup.HasherSpec("my-awesome-hash", None))
+
+ self.assertRaises(ValueError, table.lookup,
+ tf.constant(["salad", "surgery", "tarkus"]))
+
+
class StringToIndexTest(tf.test.TestCase):
def test_string_to_index(self):