From fdc6752cda33e8d5879e4db68093eca7d7395988 Mon Sep 17 00:00:00 2001 From: Yutaka Leon Date: Tue, 15 Mar 2016 18:41:57 -0800 Subject: Add tf.contrib.lookup.string_to_index and tf.contrib.lookup.index_to_string to map strings to IDs and viceversa. Change: 117303981 --- tensorflow/contrib/lookup/__init__.py | 2 + tensorflow/contrib/lookup/lookup_ops.py | 115 +++++++++++++++++++++++++++ tensorflow/contrib/lookup/lookup_ops_test.py | 76 ++++++++++++++++++ 3 files changed, 193 insertions(+) diff --git a/tensorflow/contrib/lookup/__init__.py b/tensorflow/contrib/lookup/__init__.py index afb5cd0528..4a3550b310 100644 --- a/tensorflow/contrib/lookup/__init__.py +++ b/tensorflow/contrib/lookup/__init__.py @@ -14,6 +14,8 @@ # ============================================================================== """Ops for lookup operations. +@@string_to_index +@@index_to_string @@LookupInterface @@InitializableLookupTableBase @@HashTable diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py index 6c9075f615..acc3fc2686 100644 --- a/tensorflow/contrib/lookup/lookup_ops.py +++ b/tensorflow/contrib/lookup/lookup_ops.py @@ -21,7 +21,9 @@ from __future__ import print_function from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_data_flow_ops +from tensorflow.python.ops import math_ops class LookupInterface(object): @@ -301,3 +303,116 @@ class KeyValueTensorInitializer(TableInitializerBase): # pylint: enable=protected-access ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op) return init_op + + +def string_to_index(tensor, mapping, default_value=-1, name=None): + """Maps `tensor` of strings into `int64` indices based on `mapping`. + + This operation converts `tensor` of strings into `int64` indices. + The mapping is initialized from a string `mapping` tensor where each element + is a key and corresponding index within the tensor is the value. + + Any entry in the input which does not have a corresponding entry in 'mapping' + (an out-of-vocabulary entry) is assigned the `default_value` + + Elements in `mapping` cannot be duplicated, otherwise the initialization + will throw a FailedPreconditionError. + + The underlying table must be initialized by calling + `tf.initialize_all_tables.run()` once. + + For example: + + ```python + mapping_strings = t.constant(["emerson", "lake", "palmer") + feats = tf.constant(["emerson", "lake", "and", "palmer"]) + ids = tf.contrib.lookup.string_to_index( + feats, mapping=mapping_strings, default_value=-1) + ... + tf.initialize_all_tables().run() + + ids.eval() ==> [0, 1, -1, 2] + ``` + + Args: + tensor: A 1-D input `Tensor` with the strings to map to indices. + mapping: A 1-D string `Tensor` that specifies the mapping of strings to + indices. + default_value: The `int64` value to use for out-of-vocabulary strings. + Defaults to -1. + name: A name for this op (optional). + + Returns: + The mapped indices. It has the same shape and tensor type (dense or sparse) + as `tensor`. + """ + with ops.op_scope([tensor], name, "string_to_index") as scope: + shared_name = "" + keys = ops.convert_to_tensor(mapping, dtypes.string) + vocab_size = array_ops.size(keys) + values = math_ops.cast(math_ops.range(vocab_size), dtypes.int64) + init = KeyValueTensorInitializer(keys, + values, + dtypes.string, + dtypes.int64, + name="table_init") + + t = HashTable(init, + default_value, + shared_name=shared_name, + name="hash_table") + return t.lookup(tensor, name=scope) + + +def index_to_string(tensor, mapping, default_value="UNK", name=None): + """Maps `tensor` of indices into string values based on `mapping`. + + This operation converts `int64` indices into string values. The mapping is + initialized from a string `mapping` tensor where each element is a value and + the corresponding index within the tensor is the key. + + Any input which does not have a corresponding index in 'mapping' + (an out-of-vocabulary entry) is assigned the `default_value` + + The underlying table must be initialized by calling + `tf.initialize_all_tables.run()` once. + + For example: + + ```python + mapping_string = t.constant(["emerson", "lake", "palmer") + indices = tf.constant([1, 5], tf.int64) + values = tf.contrib.lookup.index_to_string( + indices, mapping=mapping_string, default_value="UNKNOWN") + ... + tf.initialize_all_tables().run() + + values.eval() ==> ["lake", "UNKNOWN"] + ``` + + Args: + indices: A `int64` `Tensor` with the indices to map to strings. + mapping: A 1-D string `Tensor` that specifies the strings to map from + indices. + default_value: The string value to use for out-of-vocabulary indices. + name: A name for this op (optional). + + Returns: + The strings values associated to the indices. The resultant dense + feature value tensor has the same shape as the corresponding `indices`. + """ + with ops.op_scope([tensor], name, "index_to_string") as scope: + shared_name = "" + values = ops.convert_to_tensor(mapping, dtypes.string) + vocab_size = array_ops.size(values) + keys = math_ops.cast(math_ops.range(vocab_size), dtypes.int64) + init = KeyValueTensorInitializer(keys, + values, + dtypes.int64, + dtypes.string, + name="table_init") + t = HashTable(init, + default_value, + shared_name=shared_name, + name="hash_table") + return t.lookup(tensor, name=scope) diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py index 88ff91c92c..645cd7a82b 100644 --- a/tensorflow/contrib/lookup/lookup_ops_test.py +++ b/tensorflow/contrib/lookup/lookup_ops_test.py @@ -235,5 +235,81 @@ class HashTableOpTest(tf.test.TestCase): values), default_val) +class StringToIndexTest(tf.test.TestCase): + + def test_string_to_index(self): + with self.test_session(): + mapping_strings = tf.constant(["brain", "salad", "surgery"]) + feats = tf.constant(["salad", "surgery", "tarkus"]) + indices = tf.contrib.lookup.string_to_index(feats, + mapping=mapping_strings) + + self.assertRaises(tf.OpError, indices.eval) + tf.initialize_all_tables().run() + + self.assertAllEqual((1, 2, -1), indices.eval()) + + def test_duplicate_entries(self): + with self.test_session(): + mapping_strings = tf.constant(["hello", "hello"]) + feats = tf.constant(["hello", "hola"]) + indices = tf.contrib.lookup.string_to_index(feats, + mapping=mapping_strings) + + self.assertRaises(tf.OpError, tf.initialize_all_tables().run) + + def test_string_to_index_with_default_value(self): + default_value = -42 + with self.test_session(): + mapping_strings = tf.constant(["brain", "salad", "surgery"]) + feats = tf.constant(["salad", "surgery", "tarkus"]) + indices = tf.contrib.lookup.string_to_index(feats, + mapping=mapping_strings, + default_value=default_value) + self.assertRaises(tf.OpError, indices.eval) + + tf.initialize_all_tables().run() + self.assertAllEqual((1, 2, default_value), indices.eval()) + + +class IndexToStringTest(tf.test.TestCase): + + def test_index_to_string(self): + with self.test_session(): + mapping_strings = tf.constant(["brain", "salad", "surgery"]) + indices = tf.constant([0, 1, 2, 3], tf.int64) + feats = tf.contrib.lookup.index_to_string(indices, + mapping=mapping_strings) + + self.assertRaises(tf.OpError, feats.eval) + tf.initialize_all_tables().run() + + self.assertAllEqual(("brain", "salad", "surgery", "UNK"), feats.eval()) + + def test_duplicate_entries(self): + with self.test_session(): + mapping_strings = tf.constant(["hello", "hello"]) + indices = tf.constant([0, 1, 4], tf.int64) + feats = tf.contrib.lookup.index_to_string(indices, + mapping=mapping_strings) + tf.initialize_all_tables().run() + self.assertAllEqual(("hello", "hello", "UNK"), feats.eval()) + + self.assertRaises(tf.OpError, tf.initialize_all_tables().run) + + def test_index_to_string_with_default_value(self): + default_value = "NONE" + with self.test_session(): + mapping_strings = tf.constant(["brain", "salad", "surgery"]) + indices = tf.constant([1, 2, 4], tf.int64) + feats = tf.contrib.lookup.index_to_string(indices, + mapping=mapping_strings, + default_value=default_value) + self.assertRaises(tf.OpError, feats.eval) + + tf.initialize_all_tables().run() + self.assertAllEqual(("salad", "surgery", default_value), feats.eval()) + + if __name__ == "__main__": tf.test.main() -- cgit v1.2.3