aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Yutaka Leon <yutaka.leon@gmail.com>2016-03-15 18:41:57 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-03-16 16:19:52 -0700
commitfdc6752cda33e8d5879e4db68093eca7d7395988 (patch)
tree4599185470b173e94ab49f64341ccddc460998ac
parente06a4a2e1dc0dc6f36f63f13bb048c8cbb2b2c9c (diff)
Add tf.contrib.lookup.string_to_index and tf.contrib.lookup.index_to_string to map strings to IDs and vice versa.
Change: 117303981
-rw-r--r--tensorflow/contrib/lookup/__init__.py2
-rw-r--r--tensorflow/contrib/lookup/lookup_ops.py115
-rw-r--r--tensorflow/contrib/lookup/lookup_ops_test.py76
3 files changed, 193 insertions, 0 deletions
diff --git a/tensorflow/contrib/lookup/__init__.py b/tensorflow/contrib/lookup/__init__.py
index afb5cd0528..4a3550b310 100644
--- a/tensorflow/contrib/lookup/__init__.py
+++ b/tensorflow/contrib/lookup/__init__.py
@@ -14,6 +14,8 @@
# ==============================================================================
"""Ops for lookup operations.
+@@string_to_index
+@@index_to_string
@@LookupInterface
@@InitializableLookupTableBase
@@HashTable
diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py
index 6c9075f615..acc3fc2686 100644
--- a/tensorflow/contrib/lookup/lookup_ops.py
+++ b/tensorflow/contrib/lookup/lookup_ops.py
@@ -21,7 +21,9 @@ from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_data_flow_ops
+from tensorflow.python.ops import math_ops
class LookupInterface(object):
@@ -301,3 +303,116 @@ class KeyValueTensorInitializer(TableInitializerBase):
# pylint: enable=protected-access
ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
return init_op
+
+
+def string_to_index(tensor, mapping, default_value=-1, name=None):
+ """Maps `tensor` of strings into `int64` indices based on `mapping`.
+
+ This operation converts `tensor` of strings into `int64` indices.
+ The mapping is initialized from a string `mapping` tensor where each element
+ is a key and corresponding index within the tensor is the value.
+
+ Any entry in the input which does not have a corresponding entry in `mapping`
+ (an out-of-vocabulary entry) is assigned the `default_value`.
+
+ Elements in `mapping` cannot be duplicated, otherwise the initialization
+ will throw a FailedPreconditionError.
+
+ The underlying table must be initialized by calling
+ `tf.initialize_all_tables().run()` once.
+
+ For example:
+
+ ```python
+ mapping_strings = tf.constant(["emerson", "lake", "palmer"])
+ feats = tf.constant(["emerson", "lake", "and", "palmer"])
+ ids = tf.contrib.lookup.string_to_index(
+ feats, mapping=mapping_strings, default_value=-1)
+ ...
+ tf.initialize_all_tables().run()
+
+ ids.eval() ==> [0, 1, -1, 2]
+ ```
+
+ Args:
+ tensor: A 1-D input `Tensor` with the strings to map to indices.
+ mapping: A 1-D string `Tensor` that specifies the mapping of strings to
+ indices.
+ default_value: The `int64` value to use for out-of-vocabulary strings.
+ Defaults to -1.
+ name: A name for this op (optional).
+
+ Returns:
+ The mapped indices. It has the same shape and tensor type (dense or sparse)
+ as `tensor`.
+ """
+ with ops.op_scope([tensor], name, "string_to_index") as scope:
+ shared_name = ""
+ keys = ops.convert_to_tensor(mapping, dtypes.string)
+ vocab_size = array_ops.size(keys)
+ values = math_ops.cast(math_ops.range(vocab_size), dtypes.int64)
+ init = KeyValueTensorInitializer(keys,
+ values,
+ dtypes.string,
+ dtypes.int64,
+ name="table_init")
+
+ t = HashTable(init,
+ default_value,
+ shared_name=shared_name,
+ name="hash_table")
+ return t.lookup(tensor, name=scope)
+
+
+def index_to_string(tensor, mapping, default_value="UNK", name=None):
+ """Maps `tensor` of indices into string values based on `mapping`.
+
+ This operation converts `int64` indices into string values. The mapping is
+ initialized from a string `mapping` tensor where each element is a value and
+ the corresponding index within the tensor is the key.
+
+ Any input which does not have a corresponding index in `mapping`
+ (an out-of-vocabulary entry) is assigned the `default_value`.
+
+ The underlying table must be initialized by calling
+ `tf.initialize_all_tables().run()` once.
+
+ For example:
+
+ ```python
+ mapping_string = tf.constant(["emerson", "lake", "palmer"])
+ indices = tf.constant([1, 5], tf.int64)
+ values = tf.contrib.lookup.index_to_string(
+ indices, mapping=mapping_string, default_value="UNKNOWN")
+ ...
+ tf.initialize_all_tables().run()
+
+ values.eval() ==> ["lake", "UNKNOWN"]
+ ```
+
+ Args:
+ tensor: An `int64` `Tensor` with the indices to map to strings.
+ mapping: A 1-D string `Tensor` that specifies the strings to map from
+ indices.
+ default_value: The string value to use for out-of-vocabulary indices.
+ name: A name for this op (optional).
+
+ Returns:
+ The string values associated with the indices. The resulting dense
+ tensor has the same shape as the input `tensor`.
+ """
+ with ops.op_scope([tensor], name, "index_to_string") as scope:
+ shared_name = ""
+ values = ops.convert_to_tensor(mapping, dtypes.string)
+ vocab_size = array_ops.size(values)
+ keys = math_ops.cast(math_ops.range(vocab_size), dtypes.int64)
+ init = KeyValueTensorInitializer(keys,
+ values,
+ dtypes.int64,
+ dtypes.string,
+ name="table_init")
+ t = HashTable(init,
+ default_value,
+ shared_name=shared_name,
+ name="hash_table")
+ return t.lookup(tensor, name=scope)
diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py
index 88ff91c92c..645cd7a82b 100644
--- a/tensorflow/contrib/lookup/lookup_ops_test.py
+++ b/tensorflow/contrib/lookup/lookup_ops_test.py
@@ -235,5 +235,81 @@ class HashTableOpTest(tf.test.TestCase):
values), default_val)
+class StringToIndexTest(tf.test.TestCase):
+
+ def test_string_to_index(self):
+ with self.test_session():
+ mapping_strings = tf.constant(["brain", "salad", "surgery"])
+ feats = tf.constant(["salad", "surgery", "tarkus"])
+ indices = tf.contrib.lookup.string_to_index(feats,
+ mapping=mapping_strings)
+
+ self.assertRaises(tf.OpError, indices.eval)
+ tf.initialize_all_tables().run()
+
+ self.assertAllEqual((1, 2, -1), indices.eval())
+
+ def test_duplicate_entries(self):
+ with self.test_session():
+ mapping_strings = tf.constant(["hello", "hello"])
+ feats = tf.constant(["hello", "hola"])
+ indices = tf.contrib.lookup.string_to_index(feats,
+ mapping=mapping_strings)
+
+ self.assertRaises(tf.OpError, tf.initialize_all_tables().run)
+
+ def test_string_to_index_with_default_value(self):
+ default_value = -42
+ with self.test_session():
+ mapping_strings = tf.constant(["brain", "salad", "surgery"])
+ feats = tf.constant(["salad", "surgery", "tarkus"])
+ indices = tf.contrib.lookup.string_to_index(feats,
+ mapping=mapping_strings,
+ default_value=default_value)
+ self.assertRaises(tf.OpError, indices.eval)
+
+ tf.initialize_all_tables().run()
+ self.assertAllEqual((1, 2, default_value), indices.eval())
+
+
+class IndexToStringTest(tf.test.TestCase):
+
+ def test_index_to_string(self):
+ with self.test_session():
+ mapping_strings = tf.constant(["brain", "salad", "surgery"])
+ indices = tf.constant([0, 1, 2, 3], tf.int64)
+ feats = tf.contrib.lookup.index_to_string(indices,
+ mapping=mapping_strings)
+
+ self.assertRaises(tf.OpError, feats.eval)
+ tf.initialize_all_tables().run()
+
+ self.assertAllEqual(("brain", "salad", "surgery", "UNK"), feats.eval())
+
+ def test_duplicate_entries(self):
+ with self.test_session():
+ mapping_strings = tf.constant(["hello", "hello"])
+ indices = tf.constant([0, 1, 4], tf.int64)
+ feats = tf.contrib.lookup.index_to_string(indices,
+ mapping=mapping_strings)
+ tf.initialize_all_tables().run()
+ self.assertAllEqual(("hello", "hello", "UNK"), feats.eval())
+
+ self.assertRaises(tf.OpError, tf.initialize_all_tables().run)
+
+ def test_index_to_string_with_default_value(self):
+ default_value = "NONE"
+ with self.test_session():
+ mapping_strings = tf.constant(["brain", "salad", "surgery"])
+ indices = tf.constant([1, 2, 4], tf.int64)
+ feats = tf.contrib.lookup.index_to_string(indices,
+ mapping=mapping_strings,
+ default_value=default_value)
+ self.assertRaises(tf.OpError, feats.eval)
+
+ tf.initialize_all_tables().run()
+ self.assertAllEqual(("salad", "surgery", default_value), feats.eval())
+
+
if __name__ == "__main__":
tf.test.main()