aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lookup
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-03-10 09:33:58 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-10 09:50:17 -0800
commitde5f2aa8194eef27c647c369178d4821574b2622 (patch)
treea01f776a5f67911874f0bdc5a2e716981ba2cc1a /tensorflow/contrib/lookup
parent3af39a00d151bd55b69c8b045a6e67284c22c9f5 (diff)
Support integer sparse feature values.
Add some additional values args to name_scope calls. Change: 149765769
Diffstat (limited to 'tensorflow/contrib/lookup')
-rw-r--r--tensorflow/contrib/lookup/__init__.py4
-rw-r--r--tensorflow/contrib/lookup/lookup_ops.py161
-rw-r--r--tensorflow/contrib/lookup/lookup_ops_test.py360
3 files changed, 434 insertions, 91 deletions
diff --git a/tensorflow/contrib/lookup/__init__.py b/tensorflow/contrib/lookup/__init__.py
index e743832e80..dbd64cf042 100644
--- a/tensorflow/contrib/lookup/__init__.py
+++ b/tensorflow/contrib/lookup/__init__.py
@@ -12,11 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
+# TODO(ptucker): deprecate string_to_index_table_from_file and
+# string_to_index_table_from_tensor 2017-04-10.
"""Ops for lookup operations.
@@string_to_index
@@string_to_index_table_from_file
@@string_to_index_table_from_tensor
+@@index_table_from_file
+@@index_table_from_tensor
@@index_to_string
@@index_to_string_table_from_file
@@index_to_string_table_from_tensor
diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py
index 5b4f5cee2d..6a20ee4440 100644
--- a/tensorflow/contrib/lookup/lookup_ops.py
+++ b/tensorflow/contrib/lookup/lookup_ops.py
@@ -178,8 +178,9 @@ class InitializableLookupTableBase(LookupInterface):
raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
(self._key_dtype, keys.dtype))
- with ops.name_scope(name, "%s_Lookup" % self._name,
- [self._table_ref]) as scope:
+ with ops.name_scope(
+ name, "%s_Lookup" % self._name,
+ (self._table_ref, key_tensor, self._default_value)) as scope:
# pylint: disable=protected-access
values = gen_data_flow_ops._lookup_table_find(
self._table_ref, key_tensor, self._default_value, name=scope)
@@ -215,7 +216,8 @@ class HashTable(InitializableLookupTableBase):
the table will be immutable.
Args:
- initializer: The table initializer to use.
+ initializer: The table initializer to use. See `HashTable` kernel for
+ supported key and value types.
default_value: The value to use if a key is missing in the table.
shared_name: If non-empty, this table will be shared under
the given name across multiple sessions.
@@ -224,7 +226,8 @@ class HashTable(InitializableLookupTableBase):
Returns:
A `HashTable` object.
"""
- with ops.name_scope(name, "hash_table", [initializer]) as scope:
+ with ops.name_scope(
+ name, "hash_table", (initializer, default_value)) as scope:
# pylint: disable=protected-access
table_ref = gen_data_flow_ops._hash_table(
shared_name=shared_name,
@@ -301,7 +304,9 @@ class KeyValueTensorInitializer(TableInitializerBase):
key and value data types.
"""
table.check_table_dtypes(self._keys.dtype, self._values.dtype)
- with ops.name_scope(self._name, values=[table]) as scope:
+ with ops.name_scope(
+ self._name,
+ values=(table.table_ref, self._keys, self._values)) as scope:
# pylint: disable=protected-access
init_op = gen_data_flow_ops._initialize_table(table.table_ref,
self._keys,
@@ -422,9 +427,11 @@ class TextFileInitializer(TableInitializerBase):
if key_index == TextFileIndex.LINE_NUMBER and key_dtype != dtypes.int64:
raise ValueError("Signature mismatch. Keys must be dtype %s, got %s." %
(dtypes.int64, key_dtype))
- if key_index == TextFileIndex.WHOLE_LINE and key_dtype != dtypes.string:
- raise ValueError("Signature mismatch. Keys must be dtype %s, got %s." %
- (dtypes.string, key_dtype))
+ if ((key_index == TextFileIndex.WHOLE_LINE) and
+ (not key_dtype.is_integer) and (key_dtype != dtypes.string)):
+ raise ValueError(
+ "Signature mismatch. Keys must be integer or string, got %s." %
+ key_dtype)
if value_index < -2:
raise ValueError("Invalid value index %s." % (value_index))
@@ -461,7 +468,8 @@ class TextFileInitializer(TableInitializerBase):
key and value data types.
"""
table.check_table_dtypes(self.key_dtype, self.value_dtype)
- with ops.name_scope(self._name, "text_file_init", [table]) as scope:
+ with ops.name_scope(
+ self._name, "text_file_init", (table.table_ref,)) as scope:
filename = ops.convert_to_tensor(self._filename,
dtypes.string,
name="asset_filepath")
@@ -539,7 +547,8 @@ class TextFileIdTableInitializer(TextFileInitializer):
value_column_index=TextFileIndex.LINE_NUMBER,
vocab_size=None,
delimiter="\t",
- name="text_file_id_table_init"):
+ name="text_file_id_table_init",
+ key_dtype=dtypes.string):
"""Constructs an initializer for a string-to-id table from a text file.
It populates a table that its key and value types are string and int64,
@@ -565,13 +574,14 @@ class TextFileIdTableInitializer(TextFileInitializer):
vocab_size: The number of elements in the file, if known.
delimiter: The delimiter to separate fields in a line.
name: Optional name for the op.
+ key_dtype: The `key` data type.
Raises:
TypeError: when the filename is empty, or when the table key and value
data types do not match the expected data types.
"""
super(TextFileIdTableInitializer, self).__init__(filename,
- dtypes.string,
+ key_dtype,
key_column_index,
dtypes.int64,
value_column_index,
@@ -621,6 +631,12 @@ class StrongHashSpec(HasherSpec):
return super(cls, StrongHashSpec).__new__(cls, "stronghash", key)
+def _as_string(tensor):
+ if dtypes.string == tensor.dtype.base_dtype:
+ return tensor
+ return string_ops.as_string(tensor)
+
+
class IdTableWithHashBuckets(LookupInterface):
"""String to Id table wrapper that assigns out-of-vocabulary keys to buckets.
@@ -663,15 +679,19 @@ class IdTableWithHashBuckets(LookupInterface):
table,
num_oov_buckets,
hasher_spec=FastHashSpec,
- name=None):
+ name=None,
+ key_dtype=None):
"""Construct a `IdTableWithHashBuckets` object.
Args:
- table: Table that maps string to ids.
+ table: Table that maps `tf.string` or `tf.int64` keys to `tf.int64` ids.
num_oov_buckets: Number of buckets to use for out-of-vocabulary keys.
hasher_spec: A `HasherSpec` to specify the hash function to use for
assignation of out-of-vocabulary buckets (optional).
name: A name for the operation (optional).
+ key_dtype: Data type of keys passed to `lookup`. Defaults to
+ `table.key_dtype` if `table` is specified, otherwise `tf.string`.
+ Must be string or integer, and must be castable to `table.key_dtype`.
Raises:
ValueError: when `table` is None and `num_oov_buckets` is not positive.
@@ -682,22 +702,37 @@ class IdTableWithHashBuckets(LookupInterface):
if name:
name = name.rstrip("/")
if table:
- table.check_table_dtypes(dtypes.string, dtypes.int64)
+ if key_dtype is None:
+ key_dtype = table.key_dtype
+ supported_table_key_dtypes = (dtypes.int64, dtypes.string)
+ if table.key_dtype not in supported_table_key_dtypes:
+ raise TypeError("Invalid key dtype, expected one of %s, but got %s." %
+ (supported_table_key_dtypes, key_dtype))
+ if table.key_dtype.is_integer != key_dtype.is_integer:
+ raise TypeError("Invalid key dtype, expected %s but got %s." %
+ ("integer" if key_dtype.is_integer else "non-integer",
+ table.key_dtype))
+ if table.value_dtype != dtypes.int64:
+ raise TypeError("Invalid value dtype, expected %s but got %s." %
+ (dtypes.int64, table.value_dtype))
self._table = table
name = name or self._table.name
else:
if num_oov_buckets <= 0:
raise ValueError("oov_buckets must be > 0 if no table is supplied.")
+ key_dtype = dtypes.string if key_dtype is None else key_dtype
self._table = None
name = name or "hash_bucket"
+ if (not key_dtype.is_integer) and (dtypes.string != key_dtype):
+ raise TypeError(
+ "Invalid key_dtype, expected integer or string, got %s." % key_dtype)
self._num_oov_buckets = num_oov_buckets
if not isinstance(hasher_spec, HasherSpec):
raise TypeError("hasher_spec must be of type HasherSpec, got %s" %
hasher_spec)
self._hasher_spec = hasher_spec
-
- super(IdTableWithHashBuckets, self).__init__(dtypes.string, dtypes.int64,
+ super(IdTableWithHashBuckets, self).__init__(key_dtype, dtypes.int64,
name.split("/")[-1])
@property
@@ -748,24 +783,25 @@ class IdTableWithHashBuckets(LookupInterface):
if keys.dtype != self._key_dtype:
raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
(self._key_dtype, keys.dtype))
-
- string_values = keys
+ values = keys
if isinstance(keys, sparse_tensor.SparseTensor):
- string_values = keys.values
+ values = keys.values
+ if self._table and (self._table.key_dtype.base_dtype == dtypes.int64):
+ values = math_ops.to_int64(values)
if self._num_oov_buckets == 0:
- ids = self._table.lookup(string_values, name=name)
+ ids = self._table.lookup(values, name=name)
else:
# TODO(yleon): Consider moving this functionality to its own kernel.
with ops.name_scope(name, "%s_Lookup" % self.name) as scope:
str_to_hash_bucket = self._get_string_to_hash_bucket_fn(
self._hasher_spec)
buckets = str_to_hash_bucket(
- string_values,
+ _as_string(values),
num_buckets=self._num_oov_buckets,
name="hash_bucket")
if self._table:
- ids = self._table.lookup(string_values)
+ ids = self._table.lookup(values)
buckets = math_ops.add(buckets, self._table.size())
is_id_non_default = math_ops.not_equal(ids, self._table.default_value)
ids = array_ops.where(is_id_non_default, ids, buckets, name=scope)
@@ -776,12 +812,25 @@ class IdTableWithHashBuckets(LookupInterface):
return ids
+@deprecated("2017-04-10", "Use `index_table_from_file`.")
def string_to_index_table_from_file(vocabulary_file=None,
num_oov_buckets=0,
vocab_size=None,
default_value=-1,
hasher_spec=FastHashSpec,
name=None):
+ return index_table_from_file(
+ vocabulary_file, num_oov_buckets, vocab_size, default_value, hasher_spec,
+ key_dtype=dtypes.string, name=name)
+
+
+def index_table_from_file(vocabulary_file=None,
+ num_oov_buckets=0,
+ vocab_size=None,
+ default_value=-1,
+ hasher_spec=FastHashSpec,
+ key_dtype=dtypes.string,
+ name=None):
"""Returns a lookup table that converts a string tensor into int64 IDs.
This operation constructs a lookup table to convert tensor of strings into
@@ -809,7 +858,7 @@ def string_to_index_table_from_file(vocabulary_file=None,
```python
features = tf.constant(["emerson", "lake", "and", "palmer"])
- table = tf.contrib.lookup.string_to_index_table_from_file(
+ table = tf.contrib.lookup.index_table_from_file(
vocabulary_file="test.txt", num_oov_buckets=1)
ids = table.lookup(features)
...
@@ -826,6 +875,7 @@ def string_to_index_table_from_file(vocabulary_file=None,
Defaults to -1.
hasher_spec: A `HasherSpec` to specify the hash function to use for
assignation of out-of-vocabulary buckets.
+ key_dtype: The `key` data type.
name: A name for this op (optional).
Returns:
@@ -843,6 +893,8 @@ def string_to_index_table_from_file(vocabulary_file=None,
% num_oov_buckets)
if vocab_size is not None and vocab_size < 1:
raise ValueError("vocab_size must be greater than 0, got %d." % vocab_size)
+ if (not key_dtype.is_integer) and (dtypes.string != key_dtype.base_dtype):
+ raise TypeError("Only integer and string keys are supported.")
with ops.name_scope(name, "string_to_index") as feat_to_id_scope:
table = None
@@ -861,7 +913,9 @@ def string_to_index_table_from_file(vocabulary_file=None,
TextFileIndex.WHOLE_LINE,
TextFileIndex.LINE_NUMBER)
init = TextFileIdTableInitializer(
- vocabulary_file, vocab_size=vocab_size, name="table_init")
+ vocabulary_file, vocab_size=vocab_size,
+ key_dtype=dtypes.int64 if key_dtype.is_integer else key_dtype,
+ name="table_init")
table = HashTable(
init, default_value, shared_name=shared_name, name=hash_table_scope)
@@ -870,16 +924,32 @@ def string_to_index_table_from_file(vocabulary_file=None,
table,
num_oov_buckets=num_oov_buckets,
hasher_spec=hasher_spec,
- name=feat_to_id_scope)
+ name=feat_to_id_scope,
+ key_dtype=key_dtype)
return table
+@deprecated("2017-04-10", "Use `index_table_from_tensor`.")
def string_to_index_table_from_tensor(mapping,
num_oov_buckets=0,
default_value=-1,
hasher_spec=FastHashSpec,
name=None):
+ with ops.name_scope(name, "string_to_index") as scope:
+ mapping = ops.convert_to_tensor(mapping)
+ if dtypes.string != mapping.dtype.base_dtype:
+ raise ValueError("string_to_index_table_from_tensor requires string.")
+ return index_table_from_tensor(
+ mapping, num_oov_buckets, default_value, hasher_spec, name=scope)
+
+
+def index_table_from_tensor(mapping,
+ num_oov_buckets=0,
+ default_value=-1,
+ hasher_spec=FastHashSpec,
+ dtype=dtypes.string,
+ name=None):
"""Returns a lookup table that converts a string tensor into int64 IDs.
This operation constructs a lookup table to convert tensor of strings into
@@ -902,7 +972,7 @@ def string_to_index_table_from_tensor(mapping,
```python
mapping_strings = tf.constant(["emerson", "lake", "palmer"])
- table = tf.contrib.lookup.string_to_index_table_from_tensor(
+ table = tf.contrib.lookup.index_table_from_tensor(
mapping=mapping_strings, num_oov_buckets=1, default_value=-1)
features = tf.constant(["emerson", "lake", "and", "palmer"])
ids = table.lookup(features)
@@ -913,20 +983,22 @@ def string_to_index_table_from_tensor(mapping,
```
Args:
- mapping: A 1-D string `Tensor` that specifies the mapping of strings to
- indices.
+ mapping: A 1-D `Tensor` that specifies the mapping of keys to indices. The
+ type of this object must be castable to `dtype`.
num_oov_buckets: The number of out-of-vocabulary buckets.
default_value: The value to use for out-of-vocabulary feature values.
Defaults to -1.
hasher_spec: A `HasherSpec` to specify the hash function to use for
- assignation of out-of-vocabulary buckets.
+ assignment of out-of-vocabulary buckets.
+ dtype: The type of values passed to `lookup`. Only string and integers are
+ supported.
name: A name for this op (optional).
Returns:
- The lookup table to map a string `Tensor` to index `int64` `Tensor`.
+ The lookup table to map an input `Tensor` to index `int64` `Tensor`.
Raises:
- ValueError: `mapping` is invalid.
+ ValueError: If `mapping` is invalid.
ValueError: If `num_oov_buckets` is negative.
"""
if mapping is None:
@@ -936,15 +1008,25 @@ def string_to_index_table_from_tensor(mapping,
raise ValueError("num_oov_buckets must be greater or equal than 0, got %d."
% num_oov_buckets)
+ if (not dtype.is_integer) and (dtypes.string != dtype.base_dtype):
+ raise TypeError("Only integer and string keys are supported.")
+
with ops.name_scope(name, "string_to_index") as feat_to_id_scope:
- keys = ops.convert_to_tensor(mapping, dtypes.string)
+ keys = ops.convert_to_tensor(mapping)
+ if keys.dtype.is_integer != dtype.is_integer:
+ raise ValueError("Expected %s, got %s." % (
+ "integer" if dtype.is_integer else "non-integer", keys.dtype))
+ if (not dtype.is_integer) and (keys.dtype.base_dtype != dtype):
+ raise ValueError("Expected %s, got %s." % (dtype, keys.dtype))
num_elements = array_ops.size(keys)
- values = math_ops.cast(math_ops.range(num_elements), dtypes.int64)
+ values = math_ops.to_int64(math_ops.range(num_elements))
shared_name = ""
with ops.name_scope(None, "hash_table") as hash_table_scope:
+ table_keys = math_ops.to_int64(keys) if keys.dtype.is_integer else keys
init = KeyValueTensorInitializer(
- keys, values, dtypes.string, dtypes.int64, name="table_init")
+ table_keys, values, table_keys.dtype.base_dtype, dtypes.int64,
+ name="table_init")
table = HashTable(
init, default_value, shared_name=shared_name, name=hash_table_scope)
if num_oov_buckets:
@@ -952,14 +1034,15 @@ def string_to_index_table_from_tensor(mapping,
table,
num_oov_buckets=num_oov_buckets,
hasher_spec=hasher_spec,
- name=feat_to_id_scope)
+ name=feat_to_id_scope,
+ key_dtype=dtype)
return table
@deprecated(
"2017-01-07", "This op will be removed after the deprecation date. "
- "Please switch to string_to_index_table_from_tensor and call the lookup "
+ "Please switch to index_table_from_tensor and call the lookup "
"method of the returned table.")
def string_to_index(tensor, mapping, default_value=-1, name=None):
"""Maps `tensor` of strings into `int64` indices based on `mapping`.
@@ -1002,7 +1085,7 @@ def string_to_index(tensor, mapping, default_value=-1, name=None):
The mapped indices. It has the same shape and tensor type (dense or sparse)
as `tensor`.
"""
- table = string_to_index_table_from_tensor(
+ table = index_table_from_tensor(
mapping=mapping, default_value=default_value, name=name)
return table.lookup(tensor)
@@ -1135,7 +1218,7 @@ def index_to_string_table_from_tensor(mapping, default_value="UNK", name=None):
with ops.name_scope(name, "index_to_string") as scope:
values = ops.convert_to_tensor(mapping, dtypes.string)
num_elements = array_ops.size(values)
- keys = math_ops.cast(math_ops.range(num_elements), dtypes.int64)
+ keys = math_ops.to_int64(math_ops.range(num_elements))
shared_name = ""
init = KeyValueTensorInitializer(
@@ -1306,7 +1389,7 @@ class MutableHashTable(LookupInterface):
(self._key_dtype, keys.dtype))
with ops.name_scope(name, "%s_lookup_table_find" % self._name,
- [self._table_ref, keys]) as name:
+ (self._table_ref, keys, self._default_value)) as name:
# pylint: disable=protected-access
values = gen_data_flow_ops._lookup_table_find(self._table_ref,
keys,
diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py
index 3b38ecab97..fe8fa71981 100644
--- a/tensorflow/contrib/lookup/lookup_ops_test.py
+++ b/tensorflow/contrib/lookup/lookup_ops_test.py
@@ -1150,18 +1150,18 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllEqual(0, table2.size().eval())
-class StringToIndexTableFromFile(test.TestCase):
+class IndexTableFromFile(test.TestCase):
- def _createVocabFile(self, basename):
+ def _createVocabFile(self, basename, values=("brain", "salad", "surgery")):
vocabulary_file = os.path.join(self.get_temp_dir(), basename)
with open(vocabulary_file, "w") as f:
- f.write("\n".join(["brain", "salad", "surgery"]) + "\n")
+ f.write("\n".join(values) + "\n")
return vocabulary_file
- def test_string_to_index_table_from_file(self):
+ def test_string_index_table_from_file(self):
vocabulary_file = self._createVocabFile("f2i_vocab1.txt")
with self.test_session():
- table = lookup.string_to_index_table_from_file(
+ table = lookup.index_table_from_file(
vocabulary_file=vocabulary_file, num_oov_buckets=1)
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
@@ -1169,11 +1169,39 @@ class StringToIndexTableFromFile(test.TestCase):
data_flow_ops.tables_initializer().run()
self.assertAllEqual((1, 2, 3), ids.eval())
- def test_string_to_index_table_from_file_with_default_value(self):
+ def test_int32_index_table_from_file(self):
+ vocabulary_file = self._createVocabFile(
+ "f2i_vocab2.txt", values=("42", "1", "-1000"))
+ with self.test_session():
+ table = lookup.index_table_from_file(
+ vocabulary_file=vocabulary_file, num_oov_buckets=1,
+ key_dtype=dtypes.int32)
+ ids = table.lookup(
+ constant_op.constant((1, -1000, 11), dtype=dtypes.int32))
+
+ self.assertRaises(errors_impl.OpError, ids.eval)
+ data_flow_ops.tables_initializer().run()
+ self.assertAllEqual((1, 2, 3), ids.eval())
+
+ def test_int64_index_table_from_file(self):
+ vocabulary_file = self._createVocabFile(
+ "f2i_vocab3.txt", values=("42", "1", "-1000"))
+ with self.test_session():
+ table = lookup.index_table_from_file(
+ vocabulary_file=vocabulary_file, num_oov_buckets=1,
+ key_dtype=dtypes.int64)
+ ids = table.lookup(
+ constant_op.constant((1, -1000, 11), dtype=dtypes.int64))
+
+ self.assertRaises(errors_impl.OpError, ids.eval)
+ data_flow_ops.tables_initializer().run()
+ self.assertAllEqual((1, 2, 3), ids.eval())
+
+ def test_index_table_from_file_with_default_value(self):
default_value = -42
- vocabulary_file = self._createVocabFile("f2i_vocab2.txt")
+ vocabulary_file = self._createVocabFile("f2i_vocab4.txt")
with self.test_session():
- table = lookup.string_to_index_table_from_file(
+ table = lookup.index_table_from_file(
vocabulary_file=vocabulary_file, default_value=default_value)
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
@@ -1181,10 +1209,10 @@ class StringToIndexTableFromFile(test.TestCase):
data_flow_ops.tables_initializer().run()
self.assertAllEqual((1, 2, default_value), ids.eval())
- def test_string_to_index_table_from_file_with_oov_buckets(self):
- vocabulary_file = self._createVocabFile("f2i_vocab3.txt")
+ def test_index_table_from_file_with_oov_buckets(self):
+ vocabulary_file = self._createVocabFile("f2i_vocab5.txt")
with self.test_session():
- table = lookup.string_to_index_table_from_file(
+ table = lookup.index_table_from_file(
vocabulary_file=vocabulary_file, num_oov_buckets=1000)
ids = table.lookup(
constant_op.constant(["salad", "surgery", "tarkus", "toccata"]))
@@ -1199,16 +1227,16 @@ class StringToIndexTableFromFile(test.TestCase):
860), # 3 + fingerprint("toccata") mod 300.
ids.eval())
- def test_string_to_index_table_from_file_with_only_oov_buckets(self):
+ def test_index_table_from_file_with_only_oov_buckets(self):
self.assertRaises(
ValueError,
- lookup.string_to_index_table_from_file,
+ lookup.index_table_from_file,
vocabulary_file=None)
- def test_string_to_index_table_from_file_with_vocab_size_too_small(self):
- vocabulary_file = self._createVocabFile("f2i_vocab5.txt")
+ def test_index_table_from_file_with_vocab_size_too_small(self):
+ vocabulary_file = self._createVocabFile("f2i_vocab6.txt")
with self.test_session():
- table = lookup.string_to_index_table_from_file(
+ table = lookup.index_table_from_file(
vocabulary_file=vocabulary_file, vocab_size=2)
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
@@ -1217,25 +1245,25 @@ class StringToIndexTableFromFile(test.TestCase):
self.assertAllEqual((1, -1, -1), ids.eval())
self.assertEqual(2, table.size().eval())
- def test_string_to_index_table_from_file_with_vocab_size_too_large(self):
- vocabulary_file = self._createVocabFile("f2i_vocab6.txt")
+ def test_index_table_from_file_with_vocab_size_too_large(self):
+ vocabulary_file = self._createVocabFile("f2i_vocab7.txt")
with self.test_session():
- table = lookup.string_to_index_table_from_file(
+ table = lookup.index_table_from_file(
vocabulary_file=vocabulary_file, vocab_size=4)
self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"Invalid vocab_size", table.init.run)
- def test_string_to_index_table_from_file_with_vocab_size(self):
- vocabulary_file = self._createVocabFile("f2i_vocab7.txt")
+ def test_index_table_from_file_with_vocab_size(self):
+ vocabulary_file = self._createVocabFile("f2i_vocab8.txt")
self.assertRaises(
ValueError,
- lookup.string_to_index_table_from_file,
+ lookup.index_table_from_file,
vocabulary_file=vocabulary_file,
vocab_size=0)
with self.test_session():
- table = lookup.string_to_index_table_from_file(
+ table = lookup.index_table_from_file(
vocabulary_file=vocabulary_file, vocab_size=3)
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
@@ -1244,17 +1272,17 @@ class StringToIndexTableFromFile(test.TestCase):
self.assertAllEqual((1, 2, -1), ids.eval())
self.assertEqual(3, table.size().eval())
- def test_string_to_index_table_from_file_with_invalid_hashers(self):
+ def test_index_table_from_file_with_invalid_hashers(self):
vocabulary_file = self._createVocabFile("invalid_hasher.txt")
with self.test_session():
with self.assertRaises(TypeError):
- lookup.string_to_index_table_from_file(
+ lookup.index_table_from_file(
vocabulary_file=vocabulary_file,
vocab_size=3,
num_oov_buckets=1,
hasher_spec=1)
- table = lookup.string_to_index_table_from_file(
+ table = lookup.index_table_from_file(
vocabulary_file=vocabulary_file,
vocab_size=3,
num_oov_buckets=1,
@@ -1264,22 +1292,70 @@ class StringToIndexTableFromFile(test.TestCase):
constant_op.constant(["salad", "surgery", "tarkus"]))
-class StringToIndexTableFromTensor(test.TestCase):
+class KeyValueTensorInitializerTest(test.TestCase):
- def test_string_to_index_table_from_tensor_with_tensor_init(self):
+ def test_string(self):
+ with ops.Graph().as_default(), self.test_session():
+ init = lookup.KeyValueTensorInitializer(
+ ("brain", "salad", "surgery"), (0, 1, 2), dtypes.string, dtypes.int64)
+ table = lookup.HashTable(init, default_value=-1)
+ table.init.run()
+
+ def test_int64(self):
+ with ops.Graph().as_default(), self.test_session():
+ init = lookup.KeyValueTensorInitializer(
+ (42, 1, -1000), (0, 1, 2), dtypes.int64, dtypes.int64)
+ table = lookup.HashTable(init, default_value=-1)
+ table.init.run()
+
+ def test_int32(self):
+ with ops.Graph().as_default(), self.test_session():
+ init = lookup.KeyValueTensorInitializer(
+ (42, 1, -1000), (0, 1, 2), dtypes.int32, dtypes.int64)
+ table = lookup.HashTable(init, default_value=-1)
+ with self.assertRaisesRegexp(
+ errors_impl.OpError, "No OpKernel was registered"):
+ table.init.run()
+
+
+class IndexTableFromTensor(test.TestCase):
+
+ def test_index_table_from_tensor_with_tensor_init(self):
with self.test_session():
- table = lookup.string_to_index_table_from_tensor(
- mapping=["brain", "salad", "surgery"], num_oov_buckets=1)
- ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
+ table = lookup.index_table_from_tensor(
+ mapping=("brain", "salad", "surgery"), num_oov_buckets=1)
+ ids = table.lookup(constant_op.constant(("salad", "surgery", "tarkus")))
+
+ self.assertRaises(errors_impl.OpError, ids.eval)
+ data_flow_ops.tables_initializer().run()
+ self.assertAllEqual((1, 2, 3), ids.eval())
+
+ def test_int32_index_table_from_tensor_with_tensor_init(self):
+ with self.test_session():
+ table = lookup.index_table_from_tensor(
+ mapping=(42, 1, -1000), num_oov_buckets=1, dtype=dtypes.int32)
+ ids = table.lookup(
+ constant_op.constant((1, -1000, 11), dtype=dtypes.int32))
+
+ self.assertRaises(errors_impl.OpError, ids.eval)
+ data_flow_ops.tables_initializer().run()
+ self.assertAllEqual((1, 2, 3), ids.eval())
+
+ def test_int64_index_table_from_tensor_with_tensor_init(self):
+ with self.test_session():
+ table = lookup.index_table_from_tensor(
+ mapping=(42, 1, -1000), num_oov_buckets=1, dtype=dtypes.int64)
+ ids = table.lookup(
+ constant_op.constant((1, -1000, 11), dtype=dtypes.int64))
self.assertRaises(errors_impl.OpError, ids.eval)
data_flow_ops.tables_initializer().run()
self.assertAllEqual((1, 2, 3), ids.eval())
- def test_string_to_index_table_from_tensor_with_default_value(self):
+ def test_index_table_from_tensor_with_default_value(self):
default_value = -42
with self.test_session():
- table = lookup.string_to_index_table_from_tensor(
+ table = lookup.index_table_from_tensor(
mapping=["brain", "salad", "surgery"], default_value=default_value)
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
@@ -1287,21 +1363,30 @@ class StringToIndexTableFromTensor(test.TestCase):
data_flow_ops.tables_initializer().run()
self.assertAllEqual((1, 2, default_value), ids.eval())
- def test_string_to_index_table_from_tensor_with_only_oov_buckets(self):
+ def test_index_table_from_tensor_missing_mapping(self):
with self.test_session():
- with self.assertRaises(ValueError):
- lookup.string_to_index_table_from_tensor(
- mapping=None, num_oov_buckets=1)
+ with self.assertRaisesRegexp(ValueError, "mapping must be specified"):
+ lookup.index_table_from_tensor(mapping=None, num_oov_buckets=1)
+
+ def test_index_table_from_tensor_empty_mapping(self):
+ with self.test_session():
+ table = lookup.index_table_from_tensor(
+ mapping=np.array([], dtype=np.str_), num_oov_buckets=1)
+ ids = table.lookup(constant_op.constant(["salad", "surgery", "brain"]))
+ self.assertRaises(errors_impl.OpError, ids.eval)
+ with self.assertRaisesRegexp(
+ errors_impl.OpError, "keys and values cannot be empty"):
+ data_flow_ops.tables_initializer().run()
- def test_string_to_index_table_from_tensor_with_invalid_hashers(self):
+ def test_index_table_from_tensor_with_invalid_hashers(self):
with self.test_session():
with self.assertRaises(TypeError):
- lookup.string_to_index_table_from_tensor(
+ lookup.index_table_from_tensor(
mapping=["brain", "salad", "surgery"],
num_oov_buckets=1,
hasher_spec=1)
- table = lookup.string_to_index_table_from_tensor(
+ table = lookup.index_table_from_tensor(
mapping=["brain", "salad", "surgery"],
num_oov_buckets=1,
hasher_spec=lookup.HasherSpec("my-awesome-hash", None))
@@ -1495,13 +1580,13 @@ class IndexToStringTest(test.TestCase):
class InitializeTableFromFileOpTest(test.TestCase):
- def _createVocabFile(self, basename):
+ def _createVocabFile(self, basename, values=("brain", "salad", "surgery")):
vocabulary_file = os.path.join(self.get_temp_dir(), basename)
with open(vocabulary_file, "w") as f:
- f.write("\n".join(["brain", "salad", "surgery"]) + "\n")
+ f.write("\n".join(values) + "\n")
return vocabulary_file
- def testInitializeTable(self):
+ def testInitializeStringTable(self):
vocabulary_file = self._createVocabFile("one_column_1.txt")
with self.test_session():
@@ -1514,8 +1599,27 @@ class InitializeTableFromFileOpTest(test.TestCase):
default_value)
table.init.run()
- input_string = constant_op.constant(["brain", "salad", "tank"])
- output = table.lookup(input_string)
+ output = table.lookup(constant_op.constant(["brain", "salad", "tank"]))
+
+ result = output.eval()
+ self.assertAllEqual([0, 1, -1], result)
+
+ def testInitializeInt64Table(self):
+ vocabulary_file = self._createVocabFile(
+ "one_column_int64.txt", values=("42", "1", "-1000"))
+
+ with self.test_session():
+ default_value = -1
+ table = lookup.HashTable(
+ lookup.TextFileInitializer(vocabulary_file, dtypes.int64,
+ lookup.TextFileIndex.WHOLE_LINE,
+ dtypes.int64,
+ lookup.TextFileIndex.LINE_NUMBER),
+ default_value)
+ table.init.run()
+
+ output = table.lookup(
+ constant_op.constant((42, 1, 11), dtype=dtypes.int64))
result = output.eval()
self.assertAllEqual([0, 1, -1], result)
@@ -1791,17 +1895,34 @@ class InitializeTableFromFileOpTest(test.TestCase):
self.assertAllEqual([0, 1, 2, -1], out.eval())
self.assertEquals(vocab_size, table.size().eval())
+ def testInt64ToIdTable(self):
+ vocab_file = self._createVocabFile(
+ "feat_to_id_3.txt", values=("42", "1", "-1000"))
+ with self.test_session():
+ default_value = -1
+ vocab_size = 3
+ table = lookup.HashTable(
+ lookup.TextFileIdTableInitializer(
+ vocab_file, vocab_size=vocab_size, key_dtype=dtypes.int64),
+ default_value)
+ table.init.run()
+
+ out = table.lookup(
+ constant_op.constant((42, 1, -1000, 11), dtype=dtypes.int64))
+ self.assertAllEqual((0, 1, 2, -1), out.eval())
+ self.assertEquals(vocab_size, table.size().eval())
+
class IdTableWithHashBucketsTest(test.TestCase):
- def _createVocabFile(self, basename):
+ def _createVocabFile(self, basename, values=("brain", "salad", "surgery")):
vocabulary_file = os.path.join(self.get_temp_dir(), basename)
with open(vocabulary_file, "w") as f:
- f.write("\n".join(["brain", "salad", "surgery"]) + "\n")
+ f.write("\n".join(values) + "\n")
return vocabulary_file
- def testIdTableWithHashBucketsInit(self):
- vocab_file = self._createVocabFile("feat_to_id_3.txt")
+ def testStringIdTableWithHashBuckets(self):
+ vocab_file = self._createVocabFile("feat_to_id_1.txt")
with self.test_session():
default_value = -1
vocab_size = 3
@@ -1821,7 +1942,50 @@ class IdTableWithHashBucketsTest(test.TestCase):
self.assertAllEqual([0, 1, 2, 3], out.eval())
self.assertEquals(vocab_size + oov_buckets, table.size().eval())
- def testIdTableWithOnlyHashBucket(self):
+ def testInt32IdTableWithHashBuckets(self):
+ vocab_file = self._createVocabFile("feat_to_id_2.txt", ("42", "1", "-1000"))
+ with self.test_session():
+ default_value = -1
+ vocab_size = 3
+ oov_buckets = 1
+ table = lookup.IdTableWithHashBuckets(
+ lookup.HashTable(
+ lookup.TextFileIdTableInitializer(
+ vocab_file, vocab_size=vocab_size, key_dtype=dtypes.int64),
+ default_value),
+ oov_buckets,
+ key_dtype=dtypes.int32)
+
+ table.init.run()
+
+ values = constant_op.constant((42, 1, -1000, 11), dtype=dtypes.int32)
+
+ out = table.lookup(values)
+ self.assertAllEqual([0, 1, 2, 3], out.eval())
+ self.assertEquals(vocab_size + oov_buckets, table.size().eval())
+
+ def testInt64IdTableWithHashBuckets(self):
+ vocab_file = self._createVocabFile("feat_to_id_3.txt", ("42", "1", "-1000"))
+ with self.test_session():
+ default_value = -1
+ vocab_size = 3
+ oov_buckets = 1
+ table = lookup.IdTableWithHashBuckets(
+ lookup.HashTable(
+ lookup.TextFileIdTableInitializer(
+ vocab_file, vocab_size=vocab_size, key_dtype=dtypes.int64),
+ default_value),
+ oov_buckets)
+
+ table.init.run()
+
+ values = constant_op.constant((42, 1, -1000, 11), dtype=dtypes.int64)
+
+ out = table.lookup(values)
+ self.assertAllEqual([0, 1, 2, 3], out.eval())
+ self.assertEquals(vocab_size + oov_buckets, table.size().eval())
+
+ def testStringIdTableWithOnlyHashBucket(self):
with self.test_session():
oov_buckets = 5
@@ -1830,9 +1994,9 @@ class IdTableWithHashBucketsTest(test.TestCase):
table = lookup.IdTableWithHashBuckets(None, oov_buckets)
table.init.run()
- input_string = constant_op.constant(["brain", "salad", "surgery"])
+ values = constant_op.constant(("brain", "salad", "surgery"))
- out = table.lookup(input_string)
+ out = table.lookup(values)
self.assertAllEqual(
[
3, # fingerprint("brain") mod 5.
@@ -1842,6 +2006,40 @@ class IdTableWithHashBucketsTest(test.TestCase):
out.eval())
self.assertEquals(oov_buckets, table.size().eval())
+ def testInt32IdTableWithOnlyHashBucket(self):
+ with self.test_session():
+ oov_buckets = 5
+
+ # Set a table that only uses hash buckets, for each input value returns
+ # an id calculated by fingerprint("input") mod oov_buckets.
+ table = lookup.IdTableWithHashBuckets(
+ None, oov_buckets, key_dtype=dtypes.int32)
+ table.init.run()
+
+ input_string = constant_op.constant([42, 1, -1000], dtype=dtypes.int32)
+
+ out = table.lookup(input_string)
+ self.assertAllEqual(
+ [
+ 1, # fingerprint("42") mod 5.
+ 4, # fingerprint("1") mod 5.
+ 2 # fingerprint("-1000") mod 5.
+ ],
+ out.eval())
+ self.assertEquals(oov_buckets, table.size().eval())
+
+ def testFloat64IdTableWithOnlyHashBucket(self):
+ with self.test_session():
+ with self.assertRaisesRegexp(TypeError, "Invalid key_dtype"):
+ lookup.IdTableWithHashBuckets(
+ None, num_oov_buckets=5, key_dtype=dtypes.float64)
+
+ def testBoolIdTableWithOnlyHashBucket(self):
+ with self.test_session():
+ with self.assertRaisesRegexp(TypeError, "Invalid key_dtype"):
+ lookup.IdTableWithHashBuckets(
+ None, num_oov_buckets=5, key_dtype=dtypes.bool)
+
def testIdTableWithHashBucketsWithMultipleInitializers(self):
vocab_file = self._createVocabFile("feat_to_id_4.txt")
with self.test_session() as sess:
@@ -1996,6 +2194,64 @@ class IdTableWithHashBucketsTest(test.TestCase):
self.assertAllEqual([0, 1, 0, 2, 3], sp_ids_val)
self.assertAllEqual(input_shape, sp_ids_shape)
+ def testInt32SparseTensor(self):
+ input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
+ input_shape = [4, 4]
+ with self.test_session() as sess:
+ sp_features = sparse_tensor.SparseTensor(
+ constant_op.constant(input_indices, dtypes.int64),
+ constant_op.constant([42, 1, 42, -1000, 11], dtypes.int32),
+ constant_op.constant(input_shape, dtypes.int64))
+
+ table = lookup.IdTableWithHashBuckets(
+ lookup.HashTable(
+ lookup.KeyValueTensorInitializer(
+ (42, 1, -1000), (0, 1, 2), dtypes.int64, dtypes.int64),
+ -1),
+ 1,
+ key_dtype=dtypes.int32)
+ table.init.run()
+
+ sp_ids = table.lookup(sp_features)
+
+ self.assertAllEqual([5], sp_ids.values._shape_as_list())
+
+ sp_ids_ind, sp_ids_val, sp_ids_shape = sess.run(
+ [sp_ids.indices, sp_ids.values, sp_ids.dense_shape])
+
+ self.assertAllEqual(input_indices, sp_ids_ind)
+ self.assertAllEqual([0, 1, 0, 2, 3], sp_ids_val)
+ self.assertAllEqual(input_shape, sp_ids_shape)
+
+ def testInt64SparseTensor(self):
+ input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
+ input_shape = [4, 4]
+ with self.test_session() as sess:
+ sp_features = sparse_tensor.SparseTensor(
+ constant_op.constant(input_indices, dtypes.int64),
+ constant_op.constant([42, 1, 42, -1000, 11], dtypes.int64),
+ constant_op.constant(input_shape, dtypes.int64))
+
+ table = lookup.IdTableWithHashBuckets(
+ lookup.HashTable(
+ lookup.KeyValueTensorInitializer(
+ (42, 1, -1000), (0, 1, 2), dtypes.int64, dtypes.int64),
+ -1),
+ 1,
+ key_dtype=dtypes.int64)
+ table.init.run()
+
+ sp_ids = table.lookup(sp_features)
+
+ self.assertAllEqual([5], sp_ids.values._shape_as_list())
+
+ sp_ids_ind, sp_ids_val, sp_ids_shape = sess.run(
+ [sp_ids.indices, sp_ids.values, sp_ids.dense_shape])
+
+ self.assertAllEqual(input_indices, sp_ids_ind)
+ self.assertAllEqual([0, 1, 0, 2, 3], sp_ids_val)
+ self.assertAllEqual(input_shape, sp_ids_shape)
+
def testIdTableWithHashBucketsWithInvalidHashers(self):
vocab_file = self._createVocabFile("feat_to_id_4.txt")
with self.test_session():