diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-03-10 09:33:58 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-03-10 09:50:17 -0800 |
commit | de5f2aa8194eef27c647c369178d4821574b2622 (patch) | |
tree | a01f776a5f67911874f0bdc5a2e716981ba2cc1a /tensorflow/contrib/lookup | |
parent | 3af39a00d151bd55b69c8b045a6e67284c22c9f5 (diff) |
Support integer sparse feature values.
Add some additional values args to name_scope calls.
Change: 149765769
Diffstat (limited to 'tensorflow/contrib/lookup')
-rw-r--r-- | tensorflow/contrib/lookup/__init__.py | 4 | ||||
-rw-r--r-- | tensorflow/contrib/lookup/lookup_ops.py | 161 | ||||
-rw-r--r-- | tensorflow/contrib/lookup/lookup_ops_test.py | 360 |
3 files changed, 434 insertions, 91 deletions
diff --git a/tensorflow/contrib/lookup/__init__.py b/tensorflow/contrib/lookup/__init__.py index e743832e80..dbd64cf042 100644 --- a/tensorflow/contrib/lookup/__init__.py +++ b/tensorflow/contrib/lookup/__init__.py @@ -12,11 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +# TODO(ptucker): deprecate string_to_index_table_from_file and +# string_to_index_table_from_tensor 2017-04-10. """Ops for lookup operations. @@string_to_index @@string_to_index_table_from_file @@string_to_index_table_from_tensor +@@index_table_from_file +@@index_table_from_tensor @@index_to_string @@index_to_string_table_from_file @@index_to_string_table_from_tensor diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py index 5b4f5cee2d..6a20ee4440 100644 --- a/tensorflow/contrib/lookup/lookup_ops.py +++ b/tensorflow/contrib/lookup/lookup_ops.py @@ -178,8 +178,9 @@ class InitializableLookupTableBase(LookupInterface): raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." % (self._key_dtype, keys.dtype)) - with ops.name_scope(name, "%s_Lookup" % self._name, - [self._table_ref]) as scope: + with ops.name_scope( + name, "%s_Lookup" % self._name, + (self._table_ref, key_tensor, self._default_value)) as scope: # pylint: disable=protected-access values = gen_data_flow_ops._lookup_table_find( self._table_ref, key_tensor, self._default_value, name=scope) @@ -215,7 +216,8 @@ class HashTable(InitializableLookupTableBase): the table will be immutable. Args: - initializer: The table initializer to use. + initializer: The table initializer to use. See `HashTable` kernel for + supported key and value types. default_value: The value to use if a key is missing in the table. shared_name: If non-empty, this table will be shared under the given name across multiple sessions. @@ -224,7 +226,8 @@ class HashTable(InitializableLookupTableBase): Returns: A `HashTable` object. """ - with ops.name_scope(name, "hash_table", [initializer]) as scope: + with ops.name_scope( + name, "hash_table", (initializer, default_value)) as scope: # pylint: disable=protected-access table_ref = gen_data_flow_ops._hash_table( shared_name=shared_name, @@ -301,7 +304,9 @@ class KeyValueTensorInitializer(TableInitializerBase): key and value data types. """ table.check_table_dtypes(self._keys.dtype, self._values.dtype) - with ops.name_scope(self._name, values=[table]) as scope: + with ops.name_scope( + self._name, + values=(table.table_ref, self._keys, self._values)) as scope: # pylint: disable=protected-access init_op = gen_data_flow_ops._initialize_table(table.table_ref, self._keys, @@ -422,9 +427,11 @@ class TextFileInitializer(TableInitializerBase): if key_index == TextFileIndex.LINE_NUMBER and key_dtype != dtypes.int64: raise ValueError("Signature mismatch. Keys must be dtype %s, got %s." % (dtypes.int64, key_dtype)) - if key_index == TextFileIndex.WHOLE_LINE and key_dtype != dtypes.string: - raise ValueError("Signature mismatch. Keys must be dtype %s, got %s." % - (dtypes.string, key_dtype)) + if ((key_index == TextFileIndex.WHOLE_LINE) and + (not key_dtype.is_integer) and (key_dtype != dtypes.string)): + raise ValueError( + "Signature mismatch. Keys must be integer or string, got %s." % + key_dtype) if value_index < -2: raise ValueError("Invalid value index %s." % (value_index)) @@ -461,7 +468,8 @@ class TextFileInitializer(TableInitializerBase): key and value data types. """ table.check_table_dtypes(self.key_dtype, self.value_dtype) - with ops.name_scope(self._name, "text_file_init", [table]) as scope: + with ops.name_scope( + self._name, "text_file_init", (table.table_ref,)) as scope: filename = ops.convert_to_tensor(self._filename, dtypes.string, name="asset_filepath") @@ -539,7 +547,8 @@ class TextFileIdTableInitializer(TextFileInitializer): value_column_index=TextFileIndex.LINE_NUMBER, vocab_size=None, delimiter="\t", - name="text_file_id_table_init"): + name="text_file_id_table_init", + key_dtype=dtypes.string): """Constructs an initializer for an string-to-id table from a text file. It populates a table that its key and value types are string and int64, @@ -565,13 +574,14 @@ class TextFileIdTableInitializer(TextFileInitializer): vocab_size: The number of elements in the file, if known. delimiter: The delimiter to separate fields in a line. name: Optional name for the op. + key_dtype: The `key` data type. Raises: TypeError: when the filename is empty, or when the table key and value data types do not match the expected data types. """ super(TextFileIdTableInitializer, self).__init__(filename, - dtypes.string, + key_dtype, key_column_index, dtypes.int64, value_column_index, @@ -621,6 +631,12 @@ class StrongHashSpec(HasherSpec): return super(cls, StrongHashSpec).__new__(cls, "stronghash", key) +def _as_string(tensor): + if dtypes.string == tensor.dtype.base_dtype: + return tensor + return string_ops.as_string(tensor) + + class IdTableWithHashBuckets(LookupInterface): """String to Id table wrapper that assigns out-of-vocabulary keys to buckets. @@ -663,15 +679,19 @@ class IdTableWithHashBuckets(LookupInterface): table, num_oov_buckets, hasher_spec=FastHashSpec, - name=None): + name=None, + key_dtype=None): """Construct a `IdTableWithHashBuckets` object. Args: - table: Table that maps string to ids. + table: Table that maps `tf.string` or `tf.int64` keys to `tf.int64` ids. num_oov_buckets: Number of buckets to use for out-of-vocabulary keys. hasher_spec: A `HasherSpec` to specify the hash function to use for assignation of out-of-vocabulary buckets (optional). name: A name for the operation (optional). + key_dtype: Data type of keys passed to `lookup`. Defaults to + `table.key_dtype` if `table` is specified, otherwise `tf.string`. + Must be string or integer, and must be castable to `table.key_dtype`. Raises: ValueError: when `table` in None and `num_oov_buckets` is not positive. @@ -682,22 +702,37 @@ class IdTableWithHashBuckets(LookupInterface): if name: name = name.rstrip("/") if table: - table.check_table_dtypes(dtypes.string, dtypes.int64) + if key_dtype is None: + key_dtype = table.key_dtype + supported_table_key_dtypes = (dtypes.int64, dtypes.string) + if table.key_dtype not in supported_table_key_dtypes: + raise TypeError("Invalid key dtype, expected one of %s, but got %s." % + (supported_table_key_dtypes, key_dtype)) + if table.key_dtype.is_integer != key_dtype.is_integer: + raise TypeError("Invalid key dtype, expected %s but got %s." % + ("integer" if key_dtype.is_integer else "non-integer", + table.key_dtype)) + if table.value_dtype != dtypes.int64: + raise TypeError("Invalid value dtype, expected %s but got %s." % + (dtypes.int64, table.value_dtype)) self._table = table name = name or self._table.name else: if num_oov_buckets <= 0: raise ValueError("oov_buckets must be > 0 if no table is supplied.") + key_dtype = dtypes.string if key_dtype is None else key_dtype self._table = None name = name or "hash_bucket" + if (not key_dtype.is_integer) and (dtypes.string != key_dtype): + raise TypeError( + "Invalid key_dtype, expected integer or string, got %s." % key_dtype) self._num_oov_buckets = num_oov_buckets if not isinstance(hasher_spec, HasherSpec): raise TypeError("hasher_spec must be of type HasherSpec, got %s" % hasher_spec) self._hasher_spec = hasher_spec - - super(IdTableWithHashBuckets, self).__init__(dtypes.string, dtypes.int64, + super(IdTableWithHashBuckets, self).__init__(key_dtype, dtypes.int64, name.split("/")[-1]) @property @@ -748,24 +783,25 @@ class IdTableWithHashBuckets(LookupInterface): if keys.dtype != self._key_dtype: raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." % (self._key_dtype, keys.dtype)) - - string_values = keys + values = keys if isinstance(keys, sparse_tensor.SparseTensor): - string_values = keys.values + values = keys.values + if self._table and (self._table.key_dtype.base_dtype == dtypes.int64): + values = math_ops.to_int64(values) if self._num_oov_buckets == 0: - ids = self._table.lookup(string_values, name=name) + ids = self._table.lookup(values, name=name) else: # TODO(yleon): Consider moving this functionality to its own kernel. with ops.name_scope(name, "%s_Lookup" % self.name) as scope: str_to_hash_bucket = self._get_string_to_hash_bucket_fn( self._hasher_spec) buckets = str_to_hash_bucket( - string_values, + _as_string(values), num_buckets=self._num_oov_buckets, name="hash_bucket") if self._table: - ids = self._table.lookup(string_values) + ids = self._table.lookup(values) buckets = math_ops.add(buckets, self._table.size()) is_id_non_default = math_ops.not_equal(ids, self._table.default_value) ids = array_ops.where(is_id_non_default, ids, buckets, name=scope) @@ -776,12 +812,25 @@ class IdTableWithHashBuckets(LookupInterface): return ids +@deprecated("2017-04-10", "Use `index_table_from_file`.") def string_to_index_table_from_file(vocabulary_file=None, num_oov_buckets=0, vocab_size=None, default_value=-1, hasher_spec=FastHashSpec, name=None): + return index_table_from_file( + vocabulary_file, num_oov_buckets, vocab_size, default_value, hasher_spec, + key_dtype=dtypes.string, name=name) + + +def index_table_from_file(vocabulary_file=None, + num_oov_buckets=0, + vocab_size=None, + default_value=-1, + hasher_spec=FastHashSpec, + key_dtype=dtypes.string, + name=None): """Returns a lookup table that converts a string tensor into int64 IDs. This operation constructs a lookup table to convert tensor of strings into @@ -809,7 +858,7 @@ def string_to_index_table_from_file(vocabulary_file=None, ```python features = tf.constant(["emerson", "lake", "and", "palmer"]) - table = tf.contrib.lookup.string_to_index_table_from_file( + table = tf.contrib.lookup.index_table_from_file( vocabulary_file="test.txt", num_oov_buckets=1) ids = table.lookup(features) ... @@ -826,6 +875,7 @@ def string_to_index_table_from_file(vocabulary_file=None, Defaults to -1. hasher_spec: A `HasherSpec` to specify the hash function to use for assignation of out-of-vocabulary buckets. + key_dtype: The `key` data type. name: A name for this op (optional). Returns: @@ -843,6 +893,8 @@ def string_to_index_table_from_file(vocabulary_file=None, % num_oov_buckets) if vocab_size is not None and vocab_size < 1: raise ValueError("vocab_size must be greater than 0, got %d." % vocab_size) + if (not key_dtype.is_integer) and (dtypes.string != key_dtype.base_dtype): + raise TypeError("Only integer and string keys are supported.") with ops.name_scope(name, "string_to_index") as feat_to_id_scope: table = None @@ -861,7 +913,9 @@ def string_to_index_table_from_file(vocabulary_file=None, TextFileIndex.WHOLE_LINE, TextFileIndex.LINE_NUMBER) init = TextFileIdTableInitializer( - vocabulary_file, vocab_size=vocab_size, name="table_init") + vocabulary_file, vocab_size=vocab_size, + key_dtype=dtypes.int64 if key_dtype.is_integer else key_dtype, + name="table_init") table = HashTable( init, default_value, shared_name=shared_name, name=hash_table_scope) @@ -870,16 +924,32 @@ def string_to_index_table_from_file(vocabulary_file=None, table, num_oov_buckets=num_oov_buckets, hasher_spec=hasher_spec, - name=feat_to_id_scope) + name=feat_to_id_scope, + key_dtype=key_dtype) return table +@deprecated("2017-04-10", "Use `index_table_from_tensor`.") def string_to_index_table_from_tensor(mapping, num_oov_buckets=0, default_value=-1, hasher_spec=FastHashSpec, name=None): + with ops.name_scope(name, "string_to_index") as scope: + mapping = ops.convert_to_tensor(mapping) + if dtypes.string != mapping.dtype.base_dtype: + raise ValueError("string_to_index_table_from_tensor requires string.") + return index_table_from_tensor( + mapping, num_oov_buckets, default_value, hasher_spec, name=scope) + + +def index_table_from_tensor(mapping, + num_oov_buckets=0, + default_value=-1, + hasher_spec=FastHashSpec, + dtype=dtypes.string, + name=None): """Returns a lookup table that converts a string tensor into int64 IDs. This operation constructs a lookup table to convert tensor of strings into @@ -902,7 +972,7 @@ def string_to_index_table_from_tensor(mapping, ```python mapping_strings = t.constant(["emerson", "lake", "palmer") - table = tf.contrib.lookup.string_to_index_table_from_tensor( + table = tf.contrib.lookup.index_table_from_tensor( mapping=mapping_strings, num_oov_buckets=1, default_value=-1) features = tf.constant(["emerson", "lake", "and", "palmer"]) ids = table.lookup(features) @@ -913,20 +983,22 @@ def string_to_index_table_from_tensor(mapping, ``` Args: - mapping: A 1-D string `Tensor` that specifies the mapping of strings to - indices. + mapping: A 1-D `Tensor` that specifies the mapping of keys to indices. The + type of this object must be castable to `dtype`. num_oov_buckets: The number of out-of-vocabulary buckets. default_value: The value to use for out-of-vocabulary feature values. Defaults to -1. hasher_spec: A `HasherSpec` to specify the hash function to use for - assignation of out-of-vocabulary buckets. + assignment of out-of-vocabulary buckets. + dtype: The type of values passed to `lookup`. Only string and integers are + supported. name: A name for this op (optional). Returns: - The lookup table to map a string `Tensor` to index `int64` `Tensor`. + The lookup table to map an input `Tensor` to index `int64` `Tensor`. Raises: - ValueError: `mapping` is invalid. + ValueError: If `mapping` is invalid. ValueError: If `num_oov_buckets` is negative. """ if mapping is None: @@ -936,15 +1008,25 @@ def string_to_index_table_from_tensor(mapping, raise ValueError("num_oov_buckets must be greater or equal than 0, got %d." % num_oov_buckets) + if (not dtype.is_integer) and (dtypes.string != dtype.base_dtype): + raise TypeError("Only integer and string keys are supported.") + with ops.name_scope(name, "string_to_index") as feat_to_id_scope: - keys = ops.convert_to_tensor(mapping, dtypes.string) + keys = ops.convert_to_tensor(mapping) + if keys.dtype.is_integer != dtype.is_integer: + raise ValueError("Expected %s, got %s." % ( + "integer" if dtype.is_integer else "non-integer", keys.dtype)) + if (not dtype.is_integer) and (keys.dtype.base_dtype != dtype): + raise ValueError("Expected %s, got %s." % (dtype, keys.dtype)) num_elements = array_ops.size(keys) - values = math_ops.cast(math_ops.range(num_elements), dtypes.int64) + values = math_ops.to_int64(math_ops.range(num_elements)) shared_name = "" with ops.name_scope(None, "hash_table") as hash_table_scope: + table_keys = math_ops.to_int64(keys) if keys.dtype.is_integer else keys init = KeyValueTensorInitializer( - keys, values, dtypes.string, dtypes.int64, name="table_init") + table_keys, values, table_keys.dtype.base_dtype, dtypes.int64, + name="table_init") table = HashTable( init, default_value, shared_name=shared_name, name=hash_table_scope) if num_oov_buckets: @@ -952,14 +1034,15 @@ def string_to_index_table_from_tensor(mapping, table, num_oov_buckets=num_oov_buckets, hasher_spec=hasher_spec, - name=feat_to_id_scope) + name=feat_to_id_scope, + key_dtype=dtype) return table @deprecated( "2017-01-07", "This op will be removed after the deprecation date. " - "Please switch to string_to_index_table_from_tensor and call the lookup " + "Please switch to index_table_from_tensor and call the lookup " "method of the returned table.") def string_to_index(tensor, mapping, default_value=-1, name=None): """Maps `tensor` of strings into `int64` indices based on `mapping`. @@ -1002,7 +1085,7 @@ def string_to_index(tensor, mapping, default_value=-1, name=None): The mapped indices. It has the same shape and tensor type (dense or sparse) as `tensor`. """ - table = string_to_index_table_from_tensor( + table = index_table_from_tensor( mapping=mapping, default_value=default_value, name=name) return table.lookup(tensor) @@ -1135,7 +1218,7 @@ def index_to_string_table_from_tensor(mapping, default_value="UNK", name=None): with ops.name_scope(name, "index_to_string") as scope: values = ops.convert_to_tensor(mapping, dtypes.string) num_elements = array_ops.size(values) - keys = math_ops.cast(math_ops.range(num_elements), dtypes.int64) + keys = math_ops.to_int64(math_ops.range(num_elements)) shared_name = "" init = KeyValueTensorInitializer( @@ -1306,7 +1389,7 @@ class MutableHashTable(LookupInterface): (self._key_dtype, keys.dtype)) with ops.name_scope(name, "%s_lookup_table_find" % self._name, - [self._table_ref, keys]) as name: + (self._table_ref, keys, self._default_value)) as name: # pylint: disable=protected-access values = gen_data_flow_ops._lookup_table_find(self._table_ref, keys, diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py index 3b38ecab97..fe8fa71981 100644 --- a/tensorflow/contrib/lookup/lookup_ops_test.py +++ b/tensorflow/contrib/lookup/lookup_ops_test.py @@ -1150,18 +1150,18 @@ class MutableDenseHashTableOpTest(test.TestCase): self.assertAllEqual(0, table2.size().eval()) -class StringToIndexTableFromFile(test.TestCase): +class IndexTableFromFile(test.TestCase): - def _createVocabFile(self, basename): + def _createVocabFile(self, basename, values=("brain", "salad", "surgery")): vocabulary_file = os.path.join(self.get_temp_dir(), basename) with open(vocabulary_file, "w") as f: - f.write("\n".join(["brain", "salad", "surgery"]) + "\n") + f.write("\n".join(values) + "\n") return vocabulary_file - def test_string_to_index_table_from_file(self): + def test_string_index_table_from_file(self): vocabulary_file = self._createVocabFile("f2i_vocab1.txt") with self.test_session(): - table = lookup.string_to_index_table_from_file( + table = lookup.index_table_from_file( vocabulary_file=vocabulary_file, num_oov_buckets=1) ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) @@ -1169,11 +1169,39 @@ class StringToIndexTableFromFile(test.TestCase): data_flow_ops.tables_initializer().run() self.assertAllEqual((1, 2, 3), ids.eval()) - def test_string_to_index_table_from_file_with_default_value(self): + def test_int32_index_table_from_file(self): + vocabulary_file = self._createVocabFile( + "f2i_vocab2.txt", values=("42", "1", "-1000")) + with self.test_session(): + table = lookup.index_table_from_file( + vocabulary_file=vocabulary_file, num_oov_buckets=1, + key_dtype=dtypes.int32) + ids = table.lookup( + constant_op.constant((1, -1000, 11), dtype=dtypes.int32)) + + self.assertRaises(errors_impl.OpError, ids.eval) + data_flow_ops.tables_initializer().run() + self.assertAllEqual((1, 2, 3), ids.eval()) + + def test_int64_index_table_from_file(self): + vocabulary_file = self._createVocabFile( + "f2i_vocab3.txt", values=("42", "1", "-1000")) + with self.test_session(): + table = lookup.index_table_from_file( + vocabulary_file=vocabulary_file, num_oov_buckets=1, + key_dtype=dtypes.int64) + ids = table.lookup( + constant_op.constant((1, -1000, 11), dtype=dtypes.int64)) + + self.assertRaises(errors_impl.OpError, ids.eval) + data_flow_ops.tables_initializer().run() + self.assertAllEqual((1, 2, 3), ids.eval()) + + def test_index_table_from_file_with_default_value(self): default_value = -42 - vocabulary_file = self._createVocabFile("f2i_vocab2.txt") + vocabulary_file = self._createVocabFile("f2i_vocab4.txt") with self.test_session(): - table = lookup.string_to_index_table_from_file( + table = lookup.index_table_from_file( vocabulary_file=vocabulary_file, default_value=default_value) ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) @@ -1181,10 +1209,10 @@ class StringToIndexTableFromFile(test.TestCase): data_flow_ops.tables_initializer().run() self.assertAllEqual((1, 2, default_value), ids.eval()) - def test_string_to_index_table_from_file_with_oov_buckets(self): - vocabulary_file = self._createVocabFile("f2i_vocab3.txt") + def test_index_table_from_file_with_oov_buckets(self): + vocabulary_file = self._createVocabFile("f2i_vocab5.txt") with self.test_session(): - table = lookup.string_to_index_table_from_file( + table = lookup.index_table_from_file( vocabulary_file=vocabulary_file, num_oov_buckets=1000) ids = table.lookup( constant_op.constant(["salad", "surgery", "tarkus", "toccata"])) @@ -1199,16 +1227,16 @@ class StringToIndexTableFromFile(test.TestCase): 860), # 3 + fingerprint("toccata") mod 300. ids.eval()) - def test_string_to_index_table_from_file_with_only_oov_buckets(self): + def test_index_table_from_file_with_only_oov_buckets(self): self.assertRaises( ValueError, - lookup.string_to_index_table_from_file, + lookup.index_table_from_file, vocabulary_file=None) - def test_string_to_index_table_from_file_with_vocab_size_too_small(self): - vocabulary_file = self._createVocabFile("f2i_vocab5.txt") + def test_index_table_from_file_with_vocab_size_too_small(self): + vocabulary_file = self._createVocabFile("f2i_vocab6.txt") with self.test_session(): - table = lookup.string_to_index_table_from_file( + table = lookup.index_table_from_file( vocabulary_file=vocabulary_file, vocab_size=2) ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) @@ -1217,25 +1245,25 @@ class StringToIndexTableFromFile(test.TestCase): self.assertAllEqual((1, -1, -1), ids.eval()) self.assertEqual(2, table.size().eval()) - def test_string_to_index_table_from_file_with_vocab_size_too_large(self): - vocabulary_file = self._createVocabFile("f2i_vocab6.txt") + def test_index_table_from_file_with_vocab_size_too_large(self): + vocabulary_file = self._createVocabFile("f2i_vocab7.txt") with self.test_session(): - table = lookup.string_to_index_table_from_file( + table = lookup.index_table_from_file( vocabulary_file=vocabulary_file, vocab_size=4) self.assertRaisesRegexp(errors_impl.InvalidArgumentError, "Invalid vocab_size", table.init.run) - def test_string_to_index_table_from_file_with_vocab_size(self): - vocabulary_file = self._createVocabFile("f2i_vocab7.txt") + def test_index_table_from_file_with_vocab_size(self): + vocabulary_file = self._createVocabFile("f2i_vocab8.txt") self.assertRaises( ValueError, - lookup.string_to_index_table_from_file, + lookup.index_table_from_file, vocabulary_file=vocabulary_file, vocab_size=0) with self.test_session(): - table = lookup.string_to_index_table_from_file( + table = lookup.index_table_from_file( vocabulary_file=vocabulary_file, vocab_size=3) ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) @@ -1244,17 +1272,17 @@ class StringToIndexTableFromFile(test.TestCase): self.assertAllEqual((1, 2, -1), ids.eval()) self.assertEqual(3, table.size().eval()) - def test_string_to_index_table_from_file_with_invalid_hashers(self): + def test_index_table_from_file_with_invalid_hashers(self): vocabulary_file = self._createVocabFile("invalid_hasher.txt") with self.test_session(): with self.assertRaises(TypeError): - lookup.string_to_index_table_from_file( + lookup.index_table_from_file( vocabulary_file=vocabulary_file, vocab_size=3, num_oov_buckets=1, hasher_spec=1) - table = lookup.string_to_index_table_from_file( + table = lookup.index_table_from_file( vocabulary_file=vocabulary_file, vocab_size=3, num_oov_buckets=1, @@ -1264,22 +1292,70 @@ class StringToIndexTableFromFile(test.TestCase): constant_op.constant(["salad", "surgery", "tarkus"])) -class StringToIndexTableFromTensor(test.TestCase): +class KeyValueTensorInitializerTest(test.TestCase): - def test_string_to_index_table_from_tensor_with_tensor_init(self): + def test_string(self): + with ops.Graph().as_default(), self.test_session(): + init = lookup.KeyValueTensorInitializer( + ("brain", "salad", "surgery"), (0, 1, 2), dtypes.string, dtypes.int64) + table = lookup.HashTable(init, default_value=-1) + table.init.run() + + def test_int64(self): + with ops.Graph().as_default(), self.test_session(): + init = lookup.KeyValueTensorInitializer( + (42, 1, -1000), (0, 1, 2), dtypes.int64, dtypes.int64) + table = lookup.HashTable(init, default_value=-1) + table.init.run() + + def test_int32(self): + with ops.Graph().as_default(), self.test_session(): + init = lookup.KeyValueTensorInitializer( + (42, 1, -1000), (0, 1, 2), dtypes.int32, dtypes.int64) + table = lookup.HashTable(init, default_value=-1) + with self.assertRaisesRegexp( + errors_impl.OpError, "No OpKernel was registered"): + table.init.run() + + +class IndexTableFromTensor(test.TestCase): + + def test_index_table_from_tensor_with_tensor_init(self): with self.test_session(): - table = lookup.string_to_index_table_from_tensor( - mapping=["brain", "salad", "surgery"], num_oov_buckets=1) - ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) + table = lookup.index_table_from_tensor( + mapping=("brain", "salad", "surgery"), num_oov_buckets=1) + ids = table.lookup(constant_op.constant(("salad", "surgery", "tarkus"))) + + self.assertRaises(errors_impl.OpError, ids.eval) + data_flow_ops.tables_initializer().run() + self.assertAllEqual((1, 2, 3), ids.eval()) + + def test_int32_index_table_from_tensor_with_tensor_init(self): + with self.test_session(): + table = lookup.index_table_from_tensor( + mapping=(42, 1, -1000), num_oov_buckets=1, dtype=dtypes.int32) + ids = table.lookup( + constant_op.constant((1, -1000, 11), dtype=dtypes.int32)) + + self.assertRaises(errors_impl.OpError, ids.eval) + data_flow_ops.tables_initializer().run() + self.assertAllEqual((1, 2, 3), ids.eval()) + + def test_int64_index_table_from_tensor_with_tensor_init(self): + with self.test_session(): + table = lookup.index_table_from_tensor( + mapping=(42, 1, -1000), num_oov_buckets=1, dtype=dtypes.int64) + ids = table.lookup( + constant_op.constant((1, -1000, 11), dtype=dtypes.int64)) self.assertRaises(errors_impl.OpError, ids.eval) data_flow_ops.tables_initializer().run() self.assertAllEqual((1, 2, 3), ids.eval()) - def test_string_to_index_table_from_tensor_with_default_value(self): + def test_index_table_from_tensor_with_default_value(self): default_value = -42 with self.test_session(): - table = lookup.string_to_index_table_from_tensor( + table = lookup.index_table_from_tensor( mapping=["brain", "salad", "surgery"], default_value=default_value) ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) @@ -1287,21 +1363,30 @@ class StringToIndexTableFromTensor(test.TestCase): data_flow_ops.tables_initializer().run() self.assertAllEqual((1, 2, default_value), ids.eval()) - def test_string_to_index_table_from_tensor_with_only_oov_buckets(self): + def test_index_table_from_tensor_missing_mapping(self): with self.test_session(): - with self.assertRaises(ValueError): - lookup.string_to_index_table_from_tensor( - mapping=None, num_oov_buckets=1) + with self.assertRaisesRegexp(ValueError, "mapping must be specified"): + lookup.index_table_from_tensor(mapping=None, num_oov_buckets=1) + + def test_index_table_from_tensor_empty_mapping(self): + with self.test_session(): + table = lookup.index_table_from_tensor( + mapping=np.array([], dtype=np.str_), num_oov_buckets=1) + ids = table.lookup(constant_op.constant(["salad", "surgery", "brain"])) + self.assertRaises(errors_impl.OpError, ids.eval) + with self.assertRaisesRegexp( + errors_impl.OpError, "keys and values cannot be empty"): + data_flow_ops.tables_initializer().run() - def test_string_to_index_table_from_tensor_with_invalid_hashers(self): + def test_index_table_from_tensor_with_invalid_hashers(self): with self.test_session(): with self.assertRaises(TypeError): - lookup.string_to_index_table_from_tensor( + lookup.index_table_from_tensor( mapping=["brain", "salad", "surgery"], num_oov_buckets=1, hasher_spec=1) - table = lookup.string_to_index_table_from_tensor( + table = lookup.index_table_from_tensor( mapping=["brain", "salad", "surgery"], num_oov_buckets=1, hasher_spec=lookup.HasherSpec("my-awesome-hash", None)) @@ -1495,13 +1580,13 @@ class IndexToStringTest(test.TestCase): class InitializeTableFromFileOpTest(test.TestCase): - def _createVocabFile(self, basename): + def _createVocabFile(self, basename, values=("brain", "salad", "surgery")): vocabulary_file = os.path.join(self.get_temp_dir(), basename) with open(vocabulary_file, "w") as f: - f.write("\n".join(["brain", "salad", "surgery"]) + "\n") + f.write("\n".join(values) + "\n") return vocabulary_file - def testInitializeTable(self): + def testInitializeStringTable(self): vocabulary_file = self._createVocabFile("one_column_1.txt") with self.test_session(): @@ -1514,8 +1599,27 @@ class InitializeTableFromFileOpTest(test.TestCase): default_value) table.init.run() - input_string = constant_op.constant(["brain", "salad", "tank"]) - output = table.lookup(input_string) + output = table.lookup(constant_op.constant(["brain", "salad", "tank"])) + + result = output.eval() + self.assertAllEqual([0, 1, -1], result) + + def testInitializeInt64Table(self): + vocabulary_file = self._createVocabFile( + "one_column_int64.txt", values=("42", "1", "-1000")) + + with self.test_session(): + default_value = -1 + table = lookup.HashTable( + lookup.TextFileInitializer(vocabulary_file, dtypes.int64, + lookup.TextFileIndex.WHOLE_LINE, + dtypes.int64, + lookup.TextFileIndex.LINE_NUMBER), + default_value) + table.init.run() + + output = table.lookup( + constant_op.constant((42, 1, 11), dtype=dtypes.int64)) result = output.eval() self.assertAllEqual([0, 1, -1], result) @@ -1791,17 +1895,34 @@ class InitializeTableFromFileOpTest(test.TestCase): self.assertAllEqual([0, 1, 2, -1], out.eval()) self.assertEquals(vocab_size, table.size().eval()) + def testInt64ToIdTable(self): + vocab_file = self._createVocabFile( + "feat_to_id_3.txt", values=("42", "1", "-1000")) + with self.test_session(): + default_value = -1 + vocab_size = 3 + table = lookup.HashTable( + lookup.TextFileIdTableInitializer( + vocab_file, vocab_size=vocab_size, key_dtype=dtypes.int64), + default_value) + table.init.run() + + out = table.lookup( + constant_op.constant((42, 1, -1000, 11), dtype=dtypes.int64)) + self.assertAllEqual((0, 1, 2, -1), out.eval()) + self.assertEquals(vocab_size, table.size().eval()) + class IdTableWithHashBucketsTest(test.TestCase): - def _createVocabFile(self, basename): + def _createVocabFile(self, basename, values=("brain", "salad", "surgery")): vocabulary_file = os.path.join(self.get_temp_dir(), basename) with open(vocabulary_file, "w") as f: - f.write("\n".join(["brain", "salad", "surgery"]) + "\n") + f.write("\n".join(values) + "\n") return vocabulary_file - def testIdTableWithHashBucketsInit(self): - vocab_file = self._createVocabFile("feat_to_id_3.txt") + def testStringIdTableWithHashBuckets(self): + vocab_file = self._createVocabFile("feat_to_id_1.txt") with self.test_session(): default_value = -1 vocab_size = 3 @@ -1821,7 +1942,50 @@ class IdTableWithHashBucketsTest(test.TestCase): self.assertAllEqual([0, 1, 2, 3], out.eval()) self.assertEquals(vocab_size + oov_buckets, table.size().eval()) - def testIdTableWithOnlyHashBucket(self): + def testInt32IdTableWithHashBuckets(self): + vocab_file = self._createVocabFile("feat_to_id_2.txt", ("42", "1", "-1000")) + with self.test_session(): + default_value = -1 + vocab_size = 3 + oov_buckets = 1 + table = lookup.IdTableWithHashBuckets( + lookup.HashTable( + lookup.TextFileIdTableInitializer( + vocab_file, vocab_size=vocab_size, key_dtype=dtypes.int64), + default_value), + oov_buckets, + key_dtype=dtypes.int32) + + table.init.run() + + values = constant_op.constant((42, 1, -1000, 11), dtype=dtypes.int32) + + out = table.lookup(values) + self.assertAllEqual([0, 1, 2, 3], out.eval()) + self.assertEquals(vocab_size + oov_buckets, table.size().eval()) + + def testInt64IdTableWithHashBuckets(self): + vocab_file = self._createVocabFile("feat_to_id_3.txt", ("42", "1", "-1000")) + with self.test_session(): + default_value = -1 + vocab_size = 3 + oov_buckets = 1 + table = lookup.IdTableWithHashBuckets( + lookup.HashTable( + lookup.TextFileIdTableInitializer( + vocab_file, vocab_size=vocab_size, key_dtype=dtypes.int64), + default_value), + oov_buckets) + + table.init.run() + + values = constant_op.constant((42, 1, -1000, 11), dtype=dtypes.int64) + + out = table.lookup(values) + self.assertAllEqual([0, 1, 2, 3], out.eval()) + self.assertEquals(vocab_size + oov_buckets, table.size().eval()) + + def testStringIdTableWithOnlyHashBucket(self): with self.test_session(): oov_buckets = 5 @@ -1830,9 +1994,9 @@ class IdTableWithHashBucketsTest(test.TestCase): table = lookup.IdTableWithHashBuckets(None, oov_buckets) table.init.run() - input_string = constant_op.constant(["brain", "salad", "surgery"]) + values = constant_op.constant(("brain", "salad", "surgery")) - out = table.lookup(input_string) + out = table.lookup(values) self.assertAllEqual( [ 3, # fingerprint("brain") mod 5. @@ -1842,6 +2006,40 @@ class IdTableWithHashBucketsTest(test.TestCase): out.eval()) self.assertEquals(oov_buckets, table.size().eval()) + def testInt32IdTableWithOnlyHashBucket(self): + with self.test_session(): + oov_buckets = 5 + + # Set a table that only uses hash buckets, for each input value returns + # an id calculated by fingerprint("input") mod oov_buckets. + table = lookup.IdTableWithHashBuckets( + None, oov_buckets, key_dtype=dtypes.int32) + table.init.run() + + input_string = constant_op.constant([42, 1, -1000], dtype=dtypes.int32) + + out = table.lookup(input_string) + self.assertAllEqual( + [ + 1, # fingerprint("42") mod 5. + 4, # fingerprint("1") mod 5. + 2 # fingerprint("-1000") mod 5 + ], + out.eval()) + self.assertEquals(oov_buckets, table.size().eval()) + + def testFloat64IdTableWithOnlyHashBucket(self): + with self.test_session(): + with self.assertRaisesRegexp(TypeError, "Invalid key_dtype"): + lookup.IdTableWithHashBuckets( + None, num_oov_buckets=5, key_dtype=dtypes.float64) + + def testBoolIdTableWithOnlyHashBucket(self): + with self.test_session(): + with self.assertRaisesRegexp(TypeError, "Invalid key_dtype"): + lookup.IdTableWithHashBuckets( + None, num_oov_buckets=5, key_dtype=dtypes.bool) + def testIdTableWithHashBucketsWithMultipleInitializers(self): vocab_file = self._createVocabFile("feat_to_id_4.txt") with self.test_session() as sess: @@ -1996,6 +2194,64 @@ class IdTableWithHashBucketsTest(test.TestCase): self.assertAllEqual([0, 1, 0, 2, 3], sp_ids_val) self.assertAllEqual(input_shape, sp_ids_shape) + def testInt32SparseTensor(self): + input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]] + input_shape = [4, 4] + with self.test_session() as sess: + sp_features = sparse_tensor.SparseTensor( + constant_op.constant(input_indices, dtypes.int64), + constant_op.constant([42, 1, 42, -1000, 11], dtypes.int32), + constant_op.constant(input_shape, dtypes.int64)) + + table = lookup.IdTableWithHashBuckets( + lookup.HashTable( + lookup.KeyValueTensorInitializer( + (42, 1, -1000), (0, 1, 2), dtypes.int64, dtypes.int64), + -1), + 1, + key_dtype=dtypes.int32) + table.init.run() + + sp_ids = table.lookup(sp_features) + + self.assertAllEqual([5], sp_ids.values._shape_as_list()) + + sp_ids_ind, sp_ids_val, sp_ids_shape = sess.run( + [sp_ids.indices, sp_ids.values, sp_ids.dense_shape]) + + self.assertAllEqual(input_indices, sp_ids_ind) + self.assertAllEqual([0, 1, 0, 2, 3], sp_ids_val) + self.assertAllEqual(input_shape, sp_ids_shape) + + def testInt64SparseTensor(self): + input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]] + input_shape = [4, 4] + with self.test_session() as sess: + sp_features = sparse_tensor.SparseTensor( + constant_op.constant(input_indices, dtypes.int64), + constant_op.constant([42, 1, 42, -1000, 11], dtypes.int64), + constant_op.constant(input_shape, dtypes.int64)) + + table = lookup.IdTableWithHashBuckets( + lookup.HashTable( + lookup.KeyValueTensorInitializer( + (42, 1, -1000), (0, 1, 2), dtypes.int64, dtypes.int64), + -1), + 1, + key_dtype=dtypes.int64) + table.init.run() + + sp_ids = table.lookup(sp_features) + + self.assertAllEqual([5], sp_ids.values._shape_as_list()) + + sp_ids_ind, sp_ids_val, sp_ids_shape = sess.run( + [sp_ids.indices, sp_ids.values, sp_ids.dense_shape]) + + self.assertAllEqual(input_indices, sp_ids_ind) + self.assertAllEqual([0, 1, 0, 2, 3], sp_ids_val) + self.assertAllEqual(input_shape, sp_ids_shape) + def testIdTableWithHashBucketsWithInvalidHashers(self): vocab_file = self._createVocabFile("feat_to_id_4.txt") with self.test_session(): |