aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lookup
diff options
context:
space:
mode:
authorGravatar Yutaka Leon <yleon@google.com>2017-05-11 11:56:51 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-11 12:01:19 -0700
commitf2bbf4b9e7d2559ea9667ec791843a19f0f776fc (patch)
tree611ff30ce13f447fae0bd5d03a10df020098cb2f /tensorflow/contrib/lookup
parentbe15e9eb12a2c61f9b0fef98e94967e64af1f6a1 (diff)
Move back lookup_ops from feature_column/ to contrib/lookup/ since they are using V1 kernels and we want core use the V2 kernels.
PiperOrigin-RevId: 155777403
Diffstat (limited to 'tensorflow/contrib/lookup')
-rw-r--r--tensorflow/contrib/lookup/BUILD12
-rw-r--r--tensorflow/contrib/lookup/__init__.py2
-rw-r--r--tensorflow/contrib/lookup/lookup_ops.py1658
3 files changed, 1670 insertions, 2 deletions
diff --git a/tensorflow/contrib/lookup/BUILD b/tensorflow/contrib/lookup/BUILD
index bbbd340352..b0475c41c9 100644
--- a/tensorflow/contrib/lookup/BUILD
+++ b/tensorflow/contrib/lookup/BUILD
@@ -9,14 +9,24 @@ package(default_visibility = ["//tensorflow:internal"])
load("//tensorflow:tensorflow.bzl", "py_test")
+# TODO(yleon): Refactor after one we switching to the V2 kernels.
py_library(
name = "lookup_py",
srcs = [
"__init__.py",
+ "lookup_ops.py",
],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python/feature_column:lookup_ops",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:lookup_ops_gen",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:string_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python:util",
],
)
diff --git a/tensorflow/contrib/lookup/__init__.py b/tensorflow/contrib/lookup/__init__.py
index a5fcdc7b42..dbd64cf042 100644
--- a/tensorflow/contrib/lookup/__init__.py
+++ b/tensorflow/contrib/lookup/__init__.py
@@ -47,7 +47,7 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import,wildcard-import
-from tensorflow.python.feature_column.lookup_ops import *
+from tensorflow.contrib.lookup.lookup_ops import *
# pylint: enable=unused-import,wildcard-import
from tensorflow.python.util.all_util import remove_undocumented
diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py
new file mode 100644
index 0000000000..b415235b99
--- /dev/null
+++ b/tensorflow/contrib/lookup/lookup_ops.py
@@ -0,0 +1,1658 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Lookup table operations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import functools
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_lookup_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import string_ops
+from tensorflow.python.training.saver import BaseSaverBuilder
+from tensorflow.python.util import compat
+from tensorflow.python.util.deprecation import deprecated
+
+
+class LookupInterface(object):
+ """Represent a lookup table that persists across different steps."""
+
+ def __init__(self, key_dtype, value_dtype, name):
+ """Construct a lookup table interface.
+
+ Args:
+ key_dtype: The table key type.
+ value_dtype: The table value type.
+ name: A name for the operation (optional).
+ """
+ self._key_dtype = dtypes.as_dtype(key_dtype)
+ self._value_dtype = dtypes.as_dtype(value_dtype)
+ self._name = name
+
+ @property
+ def key_dtype(self):
+ """The table key dtype."""
+ return self._key_dtype
+
+ @property
+ def value_dtype(self):
+ """The table value dtype."""
+ return self._value_dtype
+
+ @property
+ def name(self):
+ """The name of the table."""
+ return self._name
+
+ @property
+ def init(self):
+ """The table initialization op."""
+ raise NotImplementedError
+
+ def size(self, name=None):
+ """Compute the number of elements in this table."""
+ raise NotImplementedError
+
+ def lookup(self, keys, name=None):
+ """Looks up `keys` in a table, outputs the corresponding values."""
+ raise NotImplementedError
+
+ def check_table_dtypes(self, key_dtype, value_dtype):
+ """Check that the given key_dtype and value_dtype matches the table dtypes.
+
+ Args:
+ key_dtype: The key data type to check.
+ value_dtype: The value data type to check.
+
+ Raises:
+ TypeError: when 'key_dtype' or 'value_dtype' doesn't match the table data
+ types.
+ """
+ if key_dtype != self.key_dtype:
+ raise TypeError("Invalid key dtype, expected %s but got %s." %
+ (self.key_dtype, key_dtype))
+ if value_dtype != self.value_dtype:
+ raise TypeError("Invalid value dtype, expected %s but got %s." %
+ (self.value_dtype, value_dtype))
+
+
+class InitializableLookupTableBase(LookupInterface):
+ """Initializable lookup table interface.
+
+ An initializable lookup tables persist across different steps.
+ """
+
+ def __init__(self, table_ref, default_value, initializer):
+ """Construct a table object from a table reference.
+
+ If requires a table initializer object (subclass of `TableInitializerBase`).
+ It provides the table key and value types, as well as the op to initialize
+ the table. The caller is responsible to execute the initialization op.
+
+ Args:
+ table_ref: The table reference, i.e. the output of the lookup table ops.
+ default_value: The value to use if a key is missing in the table.
+ initializer: The table initializer to use.
+ """
+ super(InitializableLookupTableBase, self).__init__(
+ initializer.key_dtype, initializer.value_dtype,
+ table_ref.op.name.split("/")[-1])
+ self._table_ref = table_ref
+ self._default_value = ops.convert_to_tensor(default_value,
+ dtype=self._value_dtype)
+ self._default_value.get_shape().merge_with(tensor_shape.scalar())
+ self._init = initializer.initialize(self)
+
+ @property
+ def table_ref(self):
+ """Get the underlying table reference."""
+ return self._table_ref
+
+ @property
+ def default_value(self):
+ """The default value of the table."""
+ return self._default_value
+
+ @property
+ def init(self):
+ """The table initialization op."""
+ return self._init
+
+ def size(self, name=None):
+ """Compute the number of elements in this table.
+
+ Args:
+ name: A name for the operation (optional).
+
+ Returns:
+ A scalar tensor containing the number of elements in this table.
+ """
+ with ops.name_scope(name, "%s_Size" % self._name,
+ [self._table_ref]) as scope:
+ # pylint: disable=protected-access
+ return gen_lookup_ops._lookup_table_size(self._table_ref, name=scope)
+ # pylint: enable=protected-access
+
+ def lookup(self, keys, name=None):
+ """Looks up `keys` in a table, outputs the corresponding values.
+
+ The `default_value` is used for keys not present in the table.
+
+ Args:
+ keys: Keys to look up. May be either a `SparseTensor` or dense `Tensor`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `SparseTensor` if keys are sparse, otherwise a dense `Tensor`.
+
+ Raises:
+ TypeError: when `keys` or `default_value` doesn't match the table data
+ types.
+ """
+ key_tensor = keys
+ if isinstance(keys, sparse_tensor.SparseTensor):
+ key_tensor = keys.values
+
+ if keys.dtype != self._key_dtype:
+ raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
+ (self._key_dtype, keys.dtype))
+
+ with ops.name_scope(
+ name, "%s_Lookup" % self._name,
+ (self._table_ref, key_tensor, self._default_value)) as scope:
+ # pylint: disable=protected-access
+ values = gen_lookup_ops._lookup_table_find(
+ self._table_ref, key_tensor, self._default_value, name=scope)
+ # pylint: enable=protected-access
+
+ values.set_shape(key_tensor.get_shape())
+ if isinstance(keys, sparse_tensor.SparseTensor):
+ return sparse_tensor.SparseTensor(keys.indices, values, keys.dense_shape)
+ else:
+ return values
+
+
+class HashTable(InitializableLookupTableBase):
+ """A generic hash table implementation.
+
+ Example usage:
+
+ ```python
+ table = tf.contrib.lookup.HashTable(
+ tf.contrib.lookup.KeyValueTensorInitializer(keys, values), -1)
+ out = table.lookup(input_tensor).
+ table.init.run()
+ print out.eval()
+ ```
+ """
+
+ def __init__(self, initializer, default_value, shared_name=None, name=None):
+ """Creates a non-initialized `HashTable` object.
+
+ Creates a table, the type of its keys and values are specified by the
+ initializer.
+ Before using the table you will have to initialize it. After initialization
+ the table will be immutable.
+
+ Args:
+ initializer: The table initializer to use. See `HashTable` kernel for
+ supported key and value types.
+ default_value: The value to use if a key is missing in the table.
+ shared_name: If non-empty, this table will be shared under
+ the given name across multiple sessions.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `HashTable` object.
+ """
+ with ops.name_scope(
+ name, "hash_table", (initializer, default_value)) as scope:
+ # pylint: disable=protected-access
+ table_ref = gen_lookup_ops._hash_table(
+ shared_name=shared_name,
+ key_dtype=initializer.key_dtype,
+ value_dtype=initializer.value_dtype,
+ name=scope)
+ # pylint: enable=protected-access
+
+ super(HashTable, self).__init__(table_ref, default_value, initializer)
+
+
+class TableInitializerBase(object):
+ """Base class for lookup table initializers."""
+
+ def __init__(self, key_dtype, value_dtype):
+ """Construct a table initializer object.
+
+ Args:
+ key_dtype: Type of the table keys.
+ value_dtype: Type of the table values.
+ """
+ self._key_dtype = dtypes.as_dtype(key_dtype)
+ self._value_dtype = dtypes.as_dtype(value_dtype)
+
+ @property
+ def key_dtype(self):
+ """The expected table key dtype."""
+ return self._key_dtype
+
+ @property
+ def value_dtype(self):
+ """The expected table value dtype."""
+ return self._value_dtype
+
+ def initialize(self, table):
+ """Returns the table initialization op."""
+ raise NotImplementedError
+
+
+class KeyValueTensorInitializer(TableInitializerBase):
+ """Table initializers given `keys` and `values` tensors."""
+
+ def __init__(self, keys, values, key_dtype=None, value_dtype=None, name=None):
+ """Constructs a table initializer object based on keys and values tensors.
+
+ Args:
+ keys: The tensor for the keys.
+ values: The tensor for the values.
+ key_dtype: The `keys` data type. Used when `keys` is a python array.
+ value_dtype: The `values` data type. Used when `values` is a python array.
+ name: A name for the operation (optional).
+ """
+ with ops.name_scope(name, "key_value_init", [keys, values]) as scope:
+ self._keys = ops.convert_to_tensor(keys, dtype=key_dtype, name="keys")
+ self._values = ops.convert_to_tensor(values,
+ dtype=value_dtype,
+ name="values")
+ self._name = scope
+
+ super(KeyValueTensorInitializer, self).__init__(self._keys.dtype,
+ self._values.dtype)
+
+ def initialize(self, table):
+ """Initializes the given `table` with `keys` and `values` tensors.
+
+ Args:
+ table: The table to initialize.
+
+ Returns:
+ The operation that initializes the table.
+
+ Raises:
+ TypeError: when the keys and values data types do not match the table
+ key and value data types.
+ """
+ table.check_table_dtypes(self._keys.dtype, self._values.dtype)
+ with ops.name_scope(
+ self._name,
+ values=(table.table_ref, self._keys, self._values)) as scope:
+ # pylint: disable=protected-access
+ init_op = gen_lookup_ops._initialize_table(
+ table.table_ref, self._keys, self._values, name=scope)
+ # pylint: enable=protected-access
+ ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
+ return init_op
+
+
+class TextFileIndex(object):
+ WHOLE_LINE = -2
+ LINE_NUMBER = -1
+
+
+class TextFileInitializer(TableInitializerBase):
+ """Table initializers from a text file.
+
+ This initializer assigns one entry in the table for each line in the file.
+
+ The key and value type of the table to initialize is given by `key_dtype` and
+ `value_dtype`.
+
+ The key and value content to get from each line is specified by
+ the `key_index` and `value_index`.
+
+ * `TextFileIndex.LINE_NUMBER` means use the line number starting from zero,
+ expects data type int64.
+ * `TextFileIndex.WHOLE_LINE` means use the whole line content, expects data
+ type string.
+ * A value `>=0` means use the index (starting at zero) of the split line based
+ on `delimiter`.
+
+ For example if we have a file with the following content:
+
+ ```
+ emerson 10
+ lake 20
+ palmer 30
+ ```
+
+ The following snippet initializes a table with the first column as keys and
+ second column as values:
+
+ * `emerson -> 10`
+ * `lake -> 20`
+ * `palmer -> 30`
+
+ ```python
+ table = tf.contrib.lookup.HashTable(tf.contrib.lookup.TextFileInitializer(
+ "test.txt", tf.string, 0, tf.int64, 1, delimiter=" "), -1)
+ ...
+ table.init.run()
+ ```
+
+ Similarly to initialize the whole line as keys and the line number as values.
+
+ * `emerson 10 -> 0`
+ * `lake 20 -> 1`
+ * `palmer 30 -> 2`
+
+ ```python
+ table = tf.contrib.lookup.HashTable(tf.contrib.lookup.TextFileInitializer(
+ "test.txt", tf.string, tf.contrib.lookup.TextFileIndex.WHOLE_LINE,
+ tf.int64, tf.contrib.lookup.TextFileIndex.LINE_NUMBER, delimiter=" "), -1)
+ ...
+ table.init.run()
+ ```
+ """
+
+ def __init__(self,
+ filename,
+ key_dtype,
+ key_index,
+ value_dtype,
+ value_index,
+ vocab_size=None,
+ delimiter="\t",
+ name=None):
+ """Constructs a table initializer object to populate from a text file.
+
+ It generates one key-value pair per line. The type of table key and
+ value are specified by `key_dtype` and `value_dtype`, respectively.
+ Similarly the content of the key and value are specified by the key_index
+ and value_index.
+
+ - TextFileIndex.LINE_NUMBER means use the line number starting from zero,
+ expects data type int64.
+ - TextFileIndex.WHOLE_LINE means use the whole line content, expects data
+ type string.
+ - A value >=0 means use the index (starting at zero) of the split line based
+ on `delimiter`.
+
+ Args:
+ filename: The filename of the text file to be used for initialization.
+ The path must be accessible from wherever the graph is initialized
+ (eg. trainer or eval workers). The filename may be a scalar `Tensor`.
+ key_dtype: The `key` data type.
+ key_index: the index that represents information of a line to get the
+ table 'key' values from.
+ value_dtype: The `value` data type.
+ value_index: the index that represents information of a line to get the
+ table 'value' values from.'
+ vocab_size: The number of elements in the file, if known.
+ delimiter: The delimiter to separate fields in a line.
+ name: A name for the operation (optional).
+
+ Raises:
+ ValueError: when the filename is empty, or when the table key and value
+ data types do not match the expected data types.
+ """
+ if not isinstance(filename, ops.Tensor) and not filename:
+ raise ValueError("Filename required for %s." % name)
+
+ key_dtype = dtypes.as_dtype(key_dtype)
+ value_dtype = dtypes.as_dtype(value_dtype)
+
+ if key_index < -2:
+ raise ValueError("Invalid key index %s." % (key_index))
+
+ if key_index == TextFileIndex.LINE_NUMBER and key_dtype != dtypes.int64:
+ raise ValueError("Signature mismatch. Keys must be dtype %s, got %s." %
+ (dtypes.int64, key_dtype))
+ if ((key_index == TextFileIndex.WHOLE_LINE) and
+ (not key_dtype.is_integer) and (key_dtype != dtypes.string)):
+ raise ValueError(
+ "Signature mismatch. Keys must be integer or string, got %s." %
+ key_dtype)
+ if value_index < -2:
+ raise ValueError("Invalid value index %s." % (value_index))
+
+ if value_index == TextFileIndex.LINE_NUMBER and value_dtype != dtypes.int64:
+ raise ValueError("Signature mismatch. Values must be dtype %s, got %s." %
+ (dtypes.int64, value_dtype))
+ if value_index == TextFileIndex.WHOLE_LINE and value_dtype != dtypes.string:
+ raise ValueError("Signature mismatch. Values must be dtype %s, got %s." %
+ (dtypes.string, value_dtype))
+
+ if (vocab_size is not None) and (vocab_size <= 0):
+ raise ValueError("Invalid vocab_size %s." % vocab_size)
+
+ self._filename = filename
+ self._key_index = key_index
+ self._value_index = value_index
+ self._vocab_size = vocab_size
+ self._delimiter = delimiter
+ self._name = name
+
+ super(TextFileInitializer, self).__init__(key_dtype, value_dtype)
+
+ def initialize(self, table):
+ """Initializes the table from a text file.
+
+ Args:
+ table: The table to be initialized.
+
+ Returns:
+ The operation that initializes the table.
+
+ Raises:
+ TypeError: when the keys and values data types do not match the table
+ key and value data types.
+ """
+ table.check_table_dtypes(self.key_dtype, self.value_dtype)
+ with ops.name_scope(
+ self._name, "text_file_init", (table.table_ref,)) as scope:
+ filename = ops.convert_to_tensor(self._filename,
+ dtypes.string,
+ name="asset_filepath")
+ # pylint: disable=protected-access
+ init_op = gen_lookup_ops._initialize_table_from_text_file(
+ table.table_ref,
+ filename,
+ self._key_index,
+ self._value_index,
+ -1 if self._vocab_size is None else self._vocab_size,
+ self._delimiter,
+ name=scope)
+ # pylint: enable=protected-access
+ ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
+ ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, filename)
+ return init_op
+
+
+class TextFileStringTableInitializer(TextFileInitializer):
+ """Table initializer for `int64` IDs to string tables from a text file."""
+
+ def __init__(self,
+ filename,
+ key_column_index=TextFileIndex.LINE_NUMBER,
+ value_column_index=TextFileIndex.WHOLE_LINE,
+ vocab_size=None,
+ delimiter="\t",
+ name="text_file_string_table_init"):
+ """Constructs an initializer for an id-to-string table from a text file.
+
+ It populates a table that its key and value types are int64 and string,
+ respectively. It generates one key-value pair per line.
+ The content of the key and value are specified by `key_column_index`
+ and `value_column_index`.
+
+ - TextFileIndex.LINE_NUMBER means use the line number starting from zero,
+ expects data type int64.
+ - TextFileIndex.WHOLE_LINE means use the whole line content, expects data
+ type string.
+ - A value >=0 means use the index (starting at zero) of the split line based
+ on `delimiter`.
+
+ Args:
+ filename: The filename of the text file to be used for initialization.
+ The path must be accessible from wherever the graph is initialized
+ (eg. trainer or eval workers). The filename may be a scalar `Tensor`.
+ key_column_index: The column index from the text file to get the keys
+ from. The default is 0 that represents the whole line content.
+ value_column_index: The column index from the text file to get the
+ values from. The default is to use the line number, starting from zero.
+ vocab_size: The number of elements in the file, if known.
+ delimiter: The delimiter to separate fields in a line.
+ name: Optional name for the op.
+
+ Raises:
+ TypeError: when the filename is empty, or when the table key and value
+ data types do not match the expected data types.
+ """
+ super(TextFileStringTableInitializer, self).__init__(filename,
+ dtypes.int64,
+ key_column_index,
+ dtypes.string,
+ value_column_index,
+ vocab_size=vocab_size,
+ delimiter=delimiter,
+ name=name)
+
+
+class TextFileIdTableInitializer(TextFileInitializer):
+ """Table initializer for string to `int64` IDs tables from a text file."""
+
+ def __init__(self,
+ filename,
+ key_column_index=TextFileIndex.WHOLE_LINE,
+ value_column_index=TextFileIndex.LINE_NUMBER,
+ vocab_size=None,
+ delimiter="\t",
+ name="text_file_id_table_init",
+ key_dtype=dtypes.string):
+ """Constructs an initializer for an string-to-id table from a text file.
+
+ It populates a table that its key and value types are string and int64,
+ respectively. It generates one key-value pair per line.
+ The content of the key and value are specified by the key_index
+ and value_index.
+
+ - TextFileIndex.LINE_NUMBER means use the line number starting from zero,
+ expects data type int64.
+ - TextFileIndex.WHOLE_LINE means use the whole line content, expects data
+ type string.
+ - A value >=0 means use the index (starting at zero) of the split line based
+ on `delimiter`.
+
+ Args:
+ filename: The filename of the text file to be used for initialization.
+ The path must be accessible from wherever the graph is initialized
+ (eg. trainer or eval workers). The filename may be a scalar `Tensor`.
+ key_column_index: The column index from the text file to get the `key`
+ values from. The default is to use the line number, starting from zero.
+ value_column_index: The column index from the text file ro get the `value`
+ values from. The default is 0 that represents the whole line content.
+ vocab_size: The number of elements in the file, if known.
+ delimiter: The delimiter to separate fields in a line.
+ name: Optional name for the op.
+ key_dtype: The `key` data type.
+
+ Raises:
+ TypeError: when the filename is empty, or when the table key and value
+ data types do not match the expected data types.
+ """
+ super(TextFileIdTableInitializer, self).__init__(filename,
+ key_dtype,
+ key_column_index,
+ dtypes.int64,
+ value_column_index,
+ vocab_size=vocab_size,
+ delimiter=delimiter,
+ name=name)
+
+
+class HasherSpec(collections.namedtuple("HasherSpec", ["hasher", "key"])):
+ """A structure for the spec of the hashing function to use for hash buckets.
+
+ `hasher` is the name of the hashing function to use (eg. "fasthash",
+ "stronghash").
+ `key` is optional and specify the key to use for the hash function if
+ supported, currently only used by a strong hash.
+
+ Fields:
+ hasher: The hasher name to use.
+ key: The key to be used by the hashing function, if required.
+ """
+ __slots__ = ()
+
+
+FastHashSpec = HasherSpec("fasthash", None) # pylint: disable=invalid-name
+
+
+class StrongHashSpec(HasherSpec):
+ """A structure to specify a key of the strong keyed hash spec.
+
+ The strong hash requires a `key`, which is a list of 2 unsigned integer
+ numbers. These should be non-zero; random numbers generated from random.org
+ would be a fine choice.
+
+ Fields:
+ key: The key to be used by the keyed hashing function.
+ """
+ __slots__ = ()
+
+ def __new__(cls, key):
+ if len(key) != 2:
+ raise ValueError("key must have size 2, got %s." % len(key))
+
+ if not isinstance(key[0], compat.integral_types) or not isinstance(
+ key[1], compat.integral_types):
+ raise TypeError("Invalid key %s. Must be unsigned integer values." % key)
+
+ return super(cls, StrongHashSpec).__new__(cls, "stronghash", key)
+
+
+def _as_string(tensor):
+ if dtypes.string == tensor.dtype.base_dtype:
+ return tensor
+ return string_ops.as_string(tensor)
+
+
+class IdTableWithHashBuckets(LookupInterface):
+ """String to Id table wrapper that assigns out-of-vocabulary keys to buckets.
+
+ For example, if an instance of `IdTableWithHashBuckets` is initialized with a
+ string-to-id table that maps:
+ - emerson -> 0
+ - lake -> 1
+ - palmer -> 2
+
+ The `IdTableWithHashBuckets` object will performs the following mapping:
+ - emerson -> 0
+ - lake -> 1
+ - palmer -> 2
+ - <other term> -> bucket id between 3 and 3 + num_oov_buckets, calculated by:
+ hash(<term>) % num_oov_buckets + vocab_size
+
+ If input_tensor is ["emerson", "lake", "palmer", "king", "crimson"],
+ the lookup result is [0, 1, 2, 4, 7]
+
+ If `table` is None, only out-of-vocabulary buckets are used.
+
+ Example usage:
+
+ ```python
+ num_oov_buckets = 3
+ input_tensor = tf.constant(["emerson", "lake", "palmer", "king", "crimnson"])
+ table = tf.IdTableWithHashBuckets(
+ tf.HashTable(tf.TextFileIdTableInitializer(filename), default_value),
+ num_oov_buckets)
+ out = table.lookup(input_tensor).
+ table.init.run()
+ print out.eval()
+ ```
+
+ The hash function used for generating out-of-vocabulary buckets ID is handled
+ by `hasher_spec`.
+ """
+
+ def __init__(self,
+ table,
+ num_oov_buckets,
+ hasher_spec=FastHashSpec,
+ name=None,
+ key_dtype=None):
+ """Construct a `IdTableWithHashBuckets` object.
+
+ Args:
+ table: Table that maps `tf.string` or `tf.int64` keys to `tf.int64` ids.
+ num_oov_buckets: Number of buckets to use for out-of-vocabulary keys.
+ hasher_spec: A `HasherSpec` to specify the hash function to use for
+ assignation of out-of-vocabulary buckets (optional).
+ name: A name for the operation (optional).
+ key_dtype: Data type of keys passed to `lookup`. Defaults to
+ `table.key_dtype` if `table` is specified, otherwise `tf.string`.
+ Must be string or integer, and must be castable to `table.key_dtype`.
+
+ Raises:
+ ValueError: when `table` in None and `num_oov_buckets` is not positive.
+ TypeError: when `hasher_spec` is invalid.
+ """
+ # If a name ends with a '/' it is a "name scope", remove all trailing '/'
+ # characters to use as table name.
+ if name:
+ name = name.rstrip("/")
+ if table:
+ if key_dtype is None:
+ key_dtype = table.key_dtype
+ supported_table_key_dtypes = (dtypes.int64, dtypes.string)
+ if table.key_dtype not in supported_table_key_dtypes:
+ raise TypeError("Invalid key dtype, expected one of %s, but got %s." %
+ (supported_table_key_dtypes, key_dtype))
+ if table.key_dtype.is_integer != key_dtype.is_integer:
+ raise TypeError("Invalid key dtype, expected %s but got %s." %
+ ("integer" if key_dtype.is_integer else "non-integer",
+ table.key_dtype))
+ if table.value_dtype != dtypes.int64:
+ raise TypeError("Invalid value dtype, expected %s but got %s." %
+ (dtypes.int64, table.value_dtype))
+ self._table = table
+ name = name or self._table.name
+ else:
+ if num_oov_buckets <= 0:
+ raise ValueError("oov_buckets must be > 0 if no table is supplied.")
+ key_dtype = dtypes.string if key_dtype is None else key_dtype
+ self._table = None
+ name = name or "hash_bucket"
+ if (not key_dtype.is_integer) and (dtypes.string != key_dtype):
+ raise TypeError(
+ "Invalid key_dtype, expected integer or string, got %s." % key_dtype)
+ self._num_oov_buckets = num_oov_buckets
+
+ if not isinstance(hasher_spec, HasherSpec):
+ raise TypeError("hasher_spec must be of type HasherSpec, got %s" %
+ hasher_spec)
+ self._hasher_spec = hasher_spec
+ super(IdTableWithHashBuckets, self).__init__(key_dtype, dtypes.int64,
+ name.split("/")[-1])
+
+ @property
+ def init(self):
+ """The table initialization op."""
+ if self._table:
+ return self._table.init
+ with ops.name_scope(None, "init"):
+ return control_flow_ops.no_op()
+
+ def size(self, name=None):
+ """Compute the number of elements in this table."""
+ with ops.name_scope(name, "%s_Size" % self.name) as scope:
+ if self._table:
+ tsize = self._table.size(scope)
+ else:
+ tsize = ops.convert_to_tensor(0, dtype=dtypes.int64)
+ return tsize + self._num_oov_buckets
+
+ def _get_string_to_hash_bucket_fn(self, hasher_spec):
+ """Returns the string_to_hash_bucket op to use based on `hasher_spec`."""
+ if not isinstance(hasher_spec, HasherSpec):
+ raise TypeError("hasher_spec must be of type HasherSpec %s" % hasher_spec)
+ if hasher_spec.hasher == "fasthash":
+ return string_ops.string_to_hash_bucket_fast
+ if hasher_spec.hasher == "legacy":
+ return string_ops.string_to_hash_bucket
+ if hasher_spec.hasher == "stronghash":
+ return functools.partial(
+ string_ops.string_to_hash_bucket_strong, key=hasher_spec.key)
+ raise ValueError("Unknown hasher %s" % hasher_spec.hasher)
+
+ def lookup(self, keys, name=None):
+ """Looks up `keys` in the table, outputs the corresponding values.
+
+ It assigns out-of-vocabulary keys to buckets based in their hashes.
+
+ Args:
+ keys: Keys to look up. May be either a `SparseTensor` or dense `Tensor`.
+ name: Optional name for the op.
+
+ Returns:
+ A `SparseTensor` if keys are sparse, otherwise a dense `Tensor`.
+
+ Raises:
+ TypeError: when `keys` doesn't match the table key data type.
+ """
+ if keys.dtype != self._key_dtype:
+ raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
+ (self._key_dtype, keys.dtype))
+ values = keys
+ if isinstance(keys, sparse_tensor.SparseTensor):
+ values = keys.values
+ if self._table and (self._table.key_dtype.base_dtype == dtypes.int64):
+ values = math_ops.to_int64(values)
+
+ if self._num_oov_buckets == 0:
+ ids = self._table.lookup(values, name=name)
+ else:
+ # TODO(yleon): Consider moving this functionality to its own kernel.
+ with ops.name_scope(name, "%s_Lookup" % self.name) as scope:
+ str_to_hash_bucket = self._get_string_to_hash_bucket_fn(
+ self._hasher_spec)
+ buckets = str_to_hash_bucket(
+ _as_string(values),
+ num_buckets=self._num_oov_buckets,
+ name="hash_bucket")
+ if self._table:
+ ids = self._table.lookup(values)
+ buckets = math_ops.add(buckets, self._table.size())
+ is_id_non_default = math_ops.not_equal(ids, self._table.default_value)
+ ids = array_ops.where(is_id_non_default, ids, buckets, name=scope)
+ else:
+ ids = buckets
+ if isinstance(keys, sparse_tensor.SparseTensor):
+ return sparse_tensor.SparseTensor(keys.indices, ids, keys.dense_shape)
+ return ids
+
+
+@deprecated("2017-04-10", "Use `index_table_from_file`.")
+def string_to_index_table_from_file(vocabulary_file=None,
+ num_oov_buckets=0,
+ vocab_size=None,
+ default_value=-1,
+ hasher_spec=FastHashSpec,
+ name=None):
+ return index_table_from_file(
+ vocabulary_file, num_oov_buckets, vocab_size, default_value, hasher_spec,
+ key_dtype=dtypes.string, name=name)
+
+
+def index_table_from_file(vocabulary_file=None,
+ num_oov_buckets=0,
+ vocab_size=None,
+ default_value=-1,
+ hasher_spec=FastHashSpec,
+ key_dtype=dtypes.string,
+ name=None):
+ """Returns a lookup table that converts a string tensor into int64 IDs.
+
+ This operation constructs a lookup table to convert tensor of strings into
+ int64 IDs. The mapping can be initialized from a vocabulary file specified in
+ `vocabulary_file`, where the whole line is the key and the zero-based line
+ number is the ID.
+
+ Any lookup of an out-of-vocabulary token will return a bucket ID based on its
+ hash if `num_oov_buckets` is greater than zero. Otherwise it is assigned the
+ `default_value`.
+ The bucket ID range is `[vocabulary size, vocabulary size + num_oov_buckets]`.
+
+ The underlying table must be initialized by calling
+ `tf.tables_initializer.run()` or `table.init.run()` once.
+
+ Sample Usages:
+
+ If we have a vocabulary file "test.txt" with the following content:
+
+ ```
+ emerson
+ lake
+ palmer
+ ```
+
+ ```python
+ features = tf.constant(["emerson", "lake", "and", "palmer"])
+ table = tf.contrib.lookup.index_table_from_file(
+ vocabulary_file="test.txt", num_oov_buckets=1)
+ ids = table.lookup(features)
+ ...
+ tf.tables_initializer().run()
+
+ ids.eval() ==> [0, 1, 3, 2] # where 3 is the out-of-vocabulary bucket
+ ```
+
+ Args:
+ vocabulary_file: The vocabulary filename.
+ num_oov_buckets: The number of out-of-vocabulary buckets.
+ vocab_size: Number of the elements in the vocabulary, if known.
+ default_value: The value to use for out-of-vocabulary feature values.
+ Defaults to -1.
+ hasher_spec: A `HasherSpec` to specify the hash function to use for
+ assignation of out-of-vocabulary buckets.
+ key_dtype: The `key` data type.
+ name: A name for this op (optional).
+
+ Returns:
+ The lookup table to map a `key_dtype` `Tensor` to index `int64` `Tensor`.
+
+ Raises:
+ ValueError: If `vocabulary_file` is not set.
+ ValueError: If `num_oov_buckets` is negative or `vocab_size` is not greater
+ than zero.
+ """
+ if not vocabulary_file:
+ raise ValueError("vocabulary_file must be specified.")
+ if num_oov_buckets < 0:
+ raise ValueError("num_oov_buckets must be greater or equal than 0, got %d."
+ % num_oov_buckets)
+ if vocab_size is not None and vocab_size < 1:
+ raise ValueError("vocab_size must be greater than 0, got %d." % vocab_size)
+ if (not key_dtype.is_integer) and (dtypes.string != key_dtype.base_dtype):
+ raise TypeError("Only integer and string keys are supported.")
+
+ with ops.name_scope(name, "string_to_index") as feat_to_id_scope:
+ table = None
+ shared_name = ""
+ with ops.name_scope(None, "hash_table") as hash_table_scope:
+ if vocab_size:
+ # Keep the shared_name:
+ # <table_type>_<filename>_<vocab_size>_<key_index>_<value_index>
+ shared_name = "hash_table_%s_%d_%s_%s" % (vocabulary_file, vocab_size,
+ TextFileIndex.WHOLE_LINE,
+ TextFileIndex.LINE_NUMBER)
+ else:
+ # Keep the shared_name
+ # <table_type>_<filename>_<key_index>_<value_index>
+ shared_name = "hash_table_%s_%s_%s" % (vocabulary_file,
+ TextFileIndex.WHOLE_LINE,
+ TextFileIndex.LINE_NUMBER)
+ init = TextFileIdTableInitializer(
+ vocabulary_file, vocab_size=vocab_size,
+ key_dtype=dtypes.int64 if key_dtype.is_integer else key_dtype,
+ name="table_init")
+
+ table = HashTable(
+ init, default_value, shared_name=shared_name, name=hash_table_scope)
+ if num_oov_buckets:
+ table = IdTableWithHashBuckets(
+ table,
+ num_oov_buckets=num_oov_buckets,
+ hasher_spec=hasher_spec,
+ name=feat_to_id_scope,
+ key_dtype=key_dtype)
+
+ return table
+
+
+@deprecated("2017-04-10", "Use `index_table_from_tensor`.")
+def string_to_index_table_from_tensor(mapping,
+ num_oov_buckets=0,
+ default_value=-1,
+ hasher_spec=FastHashSpec,
+ name=None):
+ with ops.name_scope(name, "string_to_index") as scope:
+ mapping = ops.convert_to_tensor(mapping)
+ if dtypes.string != mapping.dtype.base_dtype:
+ raise ValueError("string_to_index_table_from_tensor requires string.")
+ return index_table_from_tensor(
+ mapping, num_oov_buckets, default_value, hasher_spec, name=scope)
+
+
+def index_table_from_tensor(mapping,
+ num_oov_buckets=0,
+ default_value=-1,
+ hasher_spec=FastHashSpec,
+ dtype=dtypes.string,
+ name=None):
+ """Returns a lookup table that converts a string tensor into int64 IDs.
+
+ This operation constructs a lookup table to convert tensor of strings into
+ int64 IDs. The mapping can be initialized from a string `mapping` 1-D tensor
+ where each element is a key and corresponding index within the tensor is the
+ value.
+
+ Any lookup of an out-of-vocabulary token will return a bucket ID based on its
+ hash if `num_oov_buckets` is greater than zero. Otherwise it is assigned the
+ `default_value`.
+ The bucket ID range is `[mapping size, mapping size + num_oov_buckets]`.
+
+ The underlying table must be initialized by calling
+ `tf.tables_initializer.run()` or `table.init.run()` once.
+
+ Elements in `mapping` cannot have duplicates, otherwise when executing the
+ table initializer op, it will throw a `FailedPreconditionError`.
+
+ Sample Usages:
+
+ ```python
+ mapping_strings = t.constant(["emerson", "lake", "palmer"])
+ table = tf.contrib.lookup.index_table_from_tensor(
+ mapping=mapping_strings, num_oov_buckets=1, default_value=-1)
+ features = tf.constant(["emerson", "lake", "and", "palmer"])
+ ids = table.lookup(features)
+ ...
+ tf.tables_initializer().run()
+
+ ids.eval() ==> [0, 1, 4, 2]
+ ```
+
+ Args:
+ mapping: A 1-D `Tensor` that specifies the mapping of keys to indices. The
+ type of this object must be castable to `dtype`.
+ num_oov_buckets: The number of out-of-vocabulary buckets.
+ default_value: The value to use for out-of-vocabulary feature values.
+ Defaults to -1.
+ hasher_spec: A `HasherSpec` to specify the hash function to use for
+ assignment of out-of-vocabulary buckets.
+ dtype: The type of values passed to `lookup`. Only string and integers are
+ supported.
+ name: A name for this op (optional).
+
+ Returns:
+ The lookup table to map an input `Tensor` to index `int64` `Tensor`.
+
+ Raises:
+ ValueError: If `mapping` is invalid.
+ ValueError: If `num_oov_buckets` is negative.
+ """
+ if mapping is None:
+ raise ValueError("mapping must be specified.")
+
+ if num_oov_buckets < 0:
+ raise ValueError("num_oov_buckets must be greater or equal than 0, got %d."
+ % num_oov_buckets)
+
+ if (not dtype.is_integer) and (dtypes.string != dtype.base_dtype):
+ raise TypeError("Only integer and string keys are supported.")
+
+ with ops.name_scope(name, "string_to_index") as feat_to_id_scope:
+ keys = ops.convert_to_tensor(mapping)
+ if keys.dtype.is_integer != dtype.is_integer:
+ raise ValueError("Expected %s, got %s." % (
+ "integer" if dtype.is_integer else "non-integer", keys.dtype))
+ if (not dtype.is_integer) and (keys.dtype.base_dtype != dtype):
+ raise ValueError("Expected %s, got %s." % (dtype, keys.dtype))
+ num_elements = array_ops.size(keys)
+ values = math_ops.to_int64(math_ops.range(num_elements))
+
+ shared_name = ""
+ with ops.name_scope(None, "hash_table") as hash_table_scope:
+ table_keys = math_ops.to_int64(keys) if keys.dtype.is_integer else keys
+ init = KeyValueTensorInitializer(
+ table_keys, values, table_keys.dtype.base_dtype, dtypes.int64,
+ name="table_init")
+ table = HashTable(
+ init, default_value, shared_name=shared_name, name=hash_table_scope)
+ if num_oov_buckets:
+ table = IdTableWithHashBuckets(
+ table,
+ num_oov_buckets=num_oov_buckets,
+ hasher_spec=hasher_spec,
+ name=feat_to_id_scope,
+ key_dtype=dtype)
+
+ return table
+
+
+@deprecated(
+ "2017-01-07", "This op will be removed after the deprecation date. "
+ "Please switch to index_table_from_tensor and call the lookup "
+ "method of the returned table.")
+def string_to_index(tensor, mapping, default_value=-1, name=None):
+ """Maps `tensor` of strings into `int64` indices based on `mapping`.
+
+ This operation converts `tensor` of strings into `int64` indices.
+ The mapping is initialized from a string `mapping` tensor where each element
+ is a key and corresponding index within the tensor is the value.
+
+ Any entry in the input which does not have a corresponding entry in 'mapping'
+ (an out-of-vocabulary entry) is assigned the `default_value`
+
+ Elements in `mapping` cannot be duplicated, otherwise the initialization
+ will throw a FailedPreconditionError.
+
+ The underlying table must be initialized by calling
+ `tf.tables_initializer.run()` once.
+
+ For example:
+
+ ```python
+ mapping_strings = tf.constant(["emerson", "lake", "palmer"])
+ feats = tf.constant(["emerson", "lake", "and", "palmer"])
+ ids = tf.contrib.lookup.string_to_index(
+ feats, mapping=mapping_strings, default_value=-1)
+ ...
+ tf.tables_initializer().run()
+
+ ids.eval() ==> [0, 1, -1, 2]
+ ```
+
+ Args:
+ tensor: A 1-D input `Tensor` with the strings to map to indices.
+ mapping: A 1-D string `Tensor` that specifies the mapping of strings to
+ indices.
+ default_value: The `int64` value to use for out-of-vocabulary strings.
+ Defaults to -1.
+ name: A name for this op (optional).
+
+ Returns:
+ The mapped indices. It has the same shape and tensor type (dense or sparse)
+ as `tensor`.
+ """
+ table = index_table_from_tensor(
+ mapping=mapping, default_value=default_value, name=name)
+ return table.lookup(tensor)
+
+
+def index_to_string_table_from_file(vocabulary_file,
+ vocab_size=None,
+ default_value="UNK",
+ name=None):
+ """Returns a lookup table that maps a `Tensor` of indices into strings.
+
+ This operation constructs a lookup table to map int64 indices into string
+ values. The table is initialized from a vocabulary file specified in
+ `vocabulary_file`, where the whole line is the value and the
+ zero-based line number is the index.
+
+ Any input which does not have a corresponding index in the vocabulary file
+ (an out-of-vocabulary entry) is assigned the `default_value`
+
+ The underlying table must be initialized by calling
+ `tf.tables_initializer.run()` or `table.init.run()` once.
+
+ Sample Usages:
+
+ If we have a vocabulary file "test.txt" with the following content:
+
+ ```
+ emerson
+ lake
+ palmer
+ ```
+
+ ```python
+ indices = tf.constant([1, 5], tf.int64)
+ table = tf.contrib.lookup.index_to_string_table_from_file(
+ vocabulary_file="test.txt", default_value="UNKNOWN")
+ values = table.lookup(indices)
+ ...
+ tf.tables_initializer().run()
+
+ values.eval() ==> ["lake", "UNKNOWN"]
+ ```
+
+ Args:
+ vocabulary_file: The vocabulary filename.
+ vocab_size: Number of the elements in the vocabulary, if known.
+ default_value: The value to use for out-of-vocabulary indices.
+ name: A name for this op (optional).
+
+ Returns:
+ The lookup table to map a string values associated to a given index `int64`
+ `Tensors`.
+
+ Raises:
+ ValueError: when `vocabulary_file` is empty.
+ ValueError: when `vocab_size` is invalid.
+ """
+ if not vocabulary_file:
+ raise ValueError("vocabulary_file must be specified.")
+ if vocab_size is not None and vocab_size < 1:
+ raise ValueError("vocab_size must be greater than 0, got %d." % vocab_size)
+
+ with ops.name_scope(name, "index_to_string") as scope:
+ shared_name = ""
+ if vocab_size:
+ # Keep a shared_name
+ # <table_type>_<filename>_<vocab_size>_<key_index>_<value_index>
+ shared_name = "hash_table_%s_%d_%s_%s" % (vocabulary_file, vocab_size,
+ TextFileIndex.LINE_NUMBER,
+ TextFileIndex.WHOLE_LINE)
+ else:
+ # Keep a shared_name <table_type>_<filename>_<key_index>_<value_index>
+ shared_name = "hash_table_%s_%s_%s" % (vocabulary_file,
+ TextFileIndex.LINE_NUMBER,
+ TextFileIndex.WHOLE_LINE)
+ init = TextFileStringTableInitializer(
+ vocabulary_file, vocab_size=vocab_size, name="table_init")
+
+ # TODO(yleon): Use a more effienct structure.
+ return HashTable(init, default_value, shared_name=shared_name, name=scope)
+
+
+def index_to_string_table_from_tensor(mapping, default_value="UNK", name=None):
+ """Returns a lookup table that maps a `Tensor` of indices into strings.
+
+ This operation constructs a lookup table to map int64 indices into string
+ values. The mapping is initialized from a string `mapping` 1-D `Tensor` where
+ each element is a value and the corresponding index within the tensor is the
+ key.
+
+ Any input which does not have a corresponding index in 'mapping'
+ (an out-of-vocabulary entry) is assigned the `default_value`
+
+ The underlying table must be initialized by calling
+ `tf.tables_initializer.run()` or `table.init.run()` once.
+
+ Elements in `mapping` cannot have duplicates, otherwise when executing the
+ table initializer op, it will throw a `FailedPreconditionError`.
+
+ Sample Usages:
+
+ ```python
+ mapping_string = t.constant(["emerson", "lake", "palmer")
+ indices = tf.constant([1, 5], tf.int64)
+ table = tf.contrib.lookup.index_to_string_table_from_tensor(
+ mapping_string, default_value="UNKNOWN")
+ values = table.lookup(indices)
+ ...
+ tf.tables_initializer().run()
+
+ values.eval() ==> ["lake", "UNKNOWN"]
+ ```
+
+ Args:
+ mapping: A 1-D string `Tensor` that specifies the strings to map from
+ indices.
+ default_value: The value to use for out-of-vocabulary indices.
+ name: A name for this op (optional).
+
+ Returns:
+ The lookup table to map a string values associated to a given index `int64`
+ `Tensors`.
+
+ Raises:
+ ValueError: when `mapping` is not set.
+ """
+
+ if mapping is None:
+ raise ValueError("mapping must be specified.")
+
+ with ops.name_scope(name, "index_to_string") as scope:
+ values = ops.convert_to_tensor(mapping, dtypes.string)
+ num_elements = array_ops.size(values)
+ keys = math_ops.to_int64(math_ops.range(num_elements))
+
+ shared_name = ""
+ init = KeyValueTensorInitializer(
+ keys, values, dtypes.int64, dtypes.string, name="table_init")
+ # TODO(yleon): Use a more effienct structure.
+ return HashTable(init, default_value, shared_name=shared_name, name=scope)
+
+
+@deprecated(
+ "2017-01-07", "This op will be removed after the deprecation date. "
+ "Please switch to index_to_string_table_from_tensor and call the lookup "
+ "method of the returned table.")
+def index_to_string(tensor, mapping, default_value="UNK", name=None):
+ """Maps `tensor` of indices into string values based on `mapping`.
+
+ This operation converts `int64` indices into string values. The mapping is
+ initialized from a string `mapping` tensor where each element is a value and
+ the corresponding index within the tensor is the key.
+
+ Any input which does not have a corresponding index in 'mapping'
+ (an out-of-vocabulary entry) is assigned the `default_value`
+
+ The underlying table must be initialized by calling
+ `tf.tables_initializer.run()` once.
+
+ For example:
+
+ ```python
+ mapping_string = t.constant(["emerson", "lake", "palmer")
+ indices = tf.constant([1, 5], tf.int64)
+ values = tf.contrib.lookup.index_to_string(
+ indices, mapping=mapping_string, default_value="UNKNOWN")
+ ...
+ tf.tables_initializer().run()
+
+ values.eval() ==> ["lake", "UNKNOWN"]
+ ```
+
+ Args:
+ tensor: A `int64` `Tensor` with the indices to map to strings.
+ mapping: A 1-D string `Tensor` that specifies the strings to map from
+ indices.
+ default_value: The string value to use for out-of-vocabulary indices.
+ name: A name for this op (optional).
+
+ Returns:
+ The strings values associated to the indices. The resultant dense
+ feature value tensor has the same shape as the corresponding `indices`.
+ """
+ table = index_to_string_table_from_tensor(
+ mapping=mapping, default_value=default_value, name=name)
+ return table.lookup(tensor)
+
+
+class MutableHashTable(LookupInterface):
+ """A generic mutable hash table implementation.
+
+ Data can be inserted by calling the insert method. It does not support
+ initialization via the init method.
+
+ Example usage:
+
+ ```python
+ table = tf.contrib.lookup.MutableHashTable(key_dtype=tf.string,
+ value_dtype=tf.int64,
+ default_value=-1)
+ table.insert(keys, values)
+ out = table.lookup(query_keys)
+ print out.eval()
+ ```
+ """
+
+ def __init__(self,
+ key_dtype,
+ value_dtype,
+ default_value,
+ shared_name=None,
+ name="MutableHashTable",
+ checkpoint=True):
+ """Creates an empty `MutableHashTable` object.
+
+ Creates a table, the type of its keys and values are specified by key_dtype
+ and value_dtype, respectively.
+
+ Args:
+ key_dtype: the type of the key tensors.
+ value_dtype: the type of the value tensors.
+ default_value: The value to use if a key is missing in the table.
+ shared_name: If non-empty, this table will be shared under
+ the given name across multiple sessions.
+ name: A name for the operation (optional).
+ checkpoint: if True, the contents of the table are saved to and restored
+ from checkpoints. If `shared_name` is empty for a checkpointed table, it
+ is shared using the table node name.
+
+ Returns:
+ A `MutableHashTable` object.
+
+ Raises:
+ ValueError: If checkpoint is True and no name was specified.
+ """
+ self._default_value = ops.convert_to_tensor(default_value,
+ dtype=value_dtype)
+ self._value_shape = self._default_value.get_shape()
+
+ # The table must be shared if checkpointing is requested for multi-worker
+ # training to work correctly. Use the node name if no shared_name has been
+ # explicitly specified.
+ use_node_name_sharing = checkpoint and shared_name is None
+ # pylint: disable=protected-access
+ if self._default_value.get_shape().ndims == 0:
+ self._table_ref = gen_lookup_ops._mutable_hash_table(
+ shared_name=shared_name,
+ use_node_name_sharing=use_node_name_sharing,
+ key_dtype=key_dtype,
+ value_dtype=value_dtype,
+ name=name)
+ else:
+ self._table_ref = gen_lookup_ops._mutable_hash_table_of_tensors(
+ shared_name=shared_name,
+ use_node_name_sharing=use_node_name_sharing,
+ key_dtype=key_dtype,
+ value_dtype=value_dtype,
+ value_shape=self._default_value.get_shape(),
+ name=name)
+ # pylint: enable=protected-access
+ super(MutableHashTable, self).__init__(key_dtype, value_dtype,
+ self._table_ref.op.name.split(
+ "/")[-1])
+
+ if checkpoint:
+ saveable = MutableHashTable._Saveable(self, name)
+ ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
+
+ def size(self, name=None):
+ """Compute the number of elements in this table.
+
+ Args:
+ name: A name for the operation (optional).
+
+ Returns:
+ A scalar tensor containing the number of elements in this table.
+ """
+ with ops.name_scope(name, "%s_Size" % self._name,
+ [self._table_ref]) as name:
+ # pylint: disable=protected-access
+ return gen_lookup_ops._lookup_table_size(self._table_ref, name=name)
+
+ def lookup(self, keys, name=None):
+ """Looks up `keys` in a table, outputs the corresponding values.
+
+ The `default_value` is used for keys not present in the table.
+
+ Args:
+ keys: Keys to look up. Can be a tensor of any shape. Must match the
+ table's key_dtype.
+ name: A name for the operation (optional).
+
+ Returns:
+ A tensor containing the values in the same shape as `keys` using the
+ table's value type.
+
+ Raises:
+ TypeError: when `keys` do not match the table data types.
+ """
+ if keys.dtype != self._key_dtype:
+ raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
+ (self._key_dtype, keys.dtype))
+
+ with ops.name_scope(name, "%s_lookup_table_find" % self._name,
+ (self._table_ref, keys, self._default_value)) as name:
+ # pylint: disable=protected-access
+ values = gen_lookup_ops._lookup_table_find(
+ self._table_ref, keys, self._default_value, name=name)
+
+ values.set_shape(keys.get_shape().concatenate(self._value_shape))
+ return values
+
+ def insert(self, keys, values, name=None):
+ """Associates `keys` with `values`.
+
+ Args:
+ keys: Keys to insert. Can be a tensor of any shape. Must match the
+ table's key type.
+ values: Values to be associated with keys. Must be a tensor of the same
+ shape as `keys` and match the table's value type.
+ name: A name for the operation (optional).
+
+ Returns:
+ The created Operation.
+
+ Raises:
+ TypeError: when `keys` or `values` doesn't match the table data
+ types.
+ """
+ self.check_table_dtypes(keys.dtype, values.dtype)
+ with ops.name_scope(name, "%s_lookup_table_insert" % self._name,
+ [self._table_ref, keys, values]) as name:
+ # pylint: disable=protected-access
+ op = gen_lookup_ops._lookup_table_insert(
+ self._table_ref, keys, values, name=name)
+ return op
+
+ def export(self, name=None):
+ """Returns tensors of all keys and values in the table.
+
+ Args:
+ name: A name for the operation (optional).
+
+ Returns:
+ A pair of tensors with the first tensor containing all keys and the
+ second tensors containing all values in the table.
+ """
+ with ops.name_scope(name, "%s_lookup_table_export_values" % self._name,
+ [self._table_ref]) as name:
+ # pylint: disable=protected-access
+ exported_keys, exported_values = gen_lookup_ops._lookup_table_export(
+ self._table_ref, self._key_dtype, self._value_dtype, name=name)
+
+ exported_values.set_shape(exported_keys.get_shape().concatenate(
+ self._value_shape))
+ return exported_keys, exported_values
+
+ class _Saveable(BaseSaverBuilder.SaveableObject):
+ """SaveableObject implementation for MutableHashTable."""
+
+ def __init__(self, table, name):
+ tensors = table.export()
+ specs = [
+ BaseSaverBuilder.SaveSpec(tensors[0], "", name + "-keys"),
+ BaseSaverBuilder.SaveSpec(tensors[1], "", name + "-values")
+ ]
+ # pylint: disable=protected-access
+ super(MutableHashTable._Saveable, self).__init__(table, specs, name)
+
+ def restore(self, restored_tensors, unused_restored_shapes):
+ # pylint: disable=protected-access
+ return gen_lookup_ops._lookup_table_import(
+ self.op._table_ref, restored_tensors[0], restored_tensors[1])
+
+
+class MutableDenseHashTable(LookupInterface):
+ """A generic mutable hash table implementation using tensors as backing store.
+
+ Data can be inserted by calling the insert method. It does not support
+ initialization via the init method.
+
+ It uses "open addressing" with quadratic reprobing to resolve collisions.
+ Compared to `MutableHashTable` the insert and lookup operations in a
+ `MutableDenseHashTable` are typically faster, but memory usage can be higher.
+ However, `MutableDenseHashTable` does not require additional memory for
+ temporary tensors created during checkpointing and restore operations.
+
+ Example usage:
+
+ ```python
+ table = tf.contrib.lookup.MutableDenseHashTable(key_dtype=tf.int64,
+ value_dtype=tf.int64,
+ default_value=-1,
+ empty_key=0)
+ table.insert(keys, values)
+ out = table.lookup(query_keys)
+ print out.eval()
+ ```
+ """
+
+ # TODO(andreasst): consider extracting common code with MutableHashTable into
+ # a common superclass.
+ def __init__(self,
+ key_dtype,
+ value_dtype,
+ default_value,
+ empty_key,
+ initial_num_buckets=None,
+ shared_name=None,
+ name="MutableDenseHashTable",
+ checkpoint=True):
+ """Creates an empty `MutableDenseHashTable` object.
+
+ Creates a table, the type of its keys and values are specified by key_dtype
+ and value_dtype, respectively.
+
+ Args:
+ key_dtype: the type of the key tensors.
+ value_dtype: the type of the value tensors.
+ default_value: The value to use if a key is missing in the table.
+ empty_key: the key to use to represent empty buckets internally. Must not
+ be used in insert or lookup operations.
+ initial_num_buckets: the initial number of buckets.
+ shared_name: If non-empty, this table will be shared under
+ the given name across multiple sessions.
+ name: A name for the operation (optional).
+ checkpoint: if True, the contents of the table are saved to and restored
+ from checkpoints. If `shared_name` is empty for a checkpointed table, it
+ is shared using the table node name.
+
+ Returns:
+ A `MutableHashTable` object.
+
+ Raises:
+ ValueError: If checkpoint is True and no name was specified.
+ """
+ self._default_value = ops.convert_to_tensor(
+ default_value, dtype=value_dtype)
+ self._value_shape = self._default_value.get_shape()
+
+ # The table must be shared if checkpointing is requested for multi-worker
+ # training to work correctly. Use the node name if no shared_name has been
+ # explicitly specified.
+ use_node_name_sharing = checkpoint and shared_name is None
+ empty_key = ops.convert_to_tensor(empty_key, dtype=key_dtype)
+ # pylint: disable=protected-access
+ self._table_ref = gen_lookup_ops._mutable_dense_hash_table(
+ empty_key=empty_key,
+ shared_name=shared_name,
+ use_node_name_sharing=use_node_name_sharing,
+ value_dtype=value_dtype,
+ value_shape=self._value_shape,
+ initial_num_buckets=initial_num_buckets,
+ name=name)
+ # pylint: enable=protected-access
+ super(MutableDenseHashTable, self).__init__(
+ key_dtype, value_dtype, self._table_ref.op.name.split("/")[-1])
+
+ if checkpoint:
+ saveable = MutableDenseHashTable._Saveable(self, name)
+ ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
+
+ def size(self, name=None):
+ """Compute the number of elements in this table.
+
+ Args:
+ name: A name for the operation (optional).
+
+ Returns:
+ A scalar tensor containing the number of elements in this table.
+ """
+ with ops.name_scope(name, "%s_Size" % self._name,
+ [self._table_ref]) as name:
+ # pylint: disable=protected-access
+ return gen_lookup_ops._lookup_table_size(self._table_ref, name=name)
+
+ def lookup(self, keys, name=None):
+ """Looks up `keys` in a table, outputs the corresponding values.
+
+ The `default_value` is used for keys not present in the table.
+
+ Args:
+ keys: Keys to look up. Can be a tensor of any shape. Must match the
+ table's key_dtype.
+ name: A name for the operation (optional).
+
+ Returns:
+ A tensor containing the values in the same shape as `keys` using the
+ table's value type.
+
+ Raises:
+ TypeError: when `keys` do not match the table data types.
+ """
+ if keys.dtype != self._key_dtype:
+ raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
+ (self._key_dtype, keys.dtype))
+
+ with ops.name_scope(name, "%s_lookup_table_find" % self._name,
+ [self._table_ref, keys]) as name:
+ # pylint: disable=protected-access
+ values = gen_lookup_ops._lookup_table_find(
+ self._table_ref, keys, self._default_value, name=name)
+
+ if keys.get_shape().ndims is not None and keys.get_shape().ndims > 0:
+ values.set_shape(
+ tensor_shape.TensorShape([keys.get_shape().dims[0]]).concatenate(
+ self._value_shape))
+ return values
+
+ def insert(self, keys, values, name=None):
+ """Associates `keys` with `values`.
+
+ Args:
+ keys: Keys to insert. Can be a tensor of any shape. Must match the
+ table's key type.
+ values: Values to be associated with keys. Must be a tensor of the same
+ shape as `keys` and match the table's value type.
+ name: A name for the operation (optional).
+
+ Returns:
+ The created Operation.
+
+ Raises:
+ TypeError: when `keys` or `values` doesn't match the table data
+ types.
+ """
+ self.check_table_dtypes(keys.dtype, values.dtype)
+ with ops.name_scope(name, "%s_lookup_table_insert" % self._name,
+ [self._table_ref, keys, values]) as name:
+ # pylint: disable=protected-access
+ op = gen_lookup_ops._lookup_table_insert(
+ self._table_ref, keys, values, name=name)
+ return op
+
+ def export(self, name=None):
+ """Returns tensors of all keys and values in the table.
+
+ Args:
+ name: A name for the operation (optional).
+
+ Returns:
+ A pair of tensors with the first tensor containing all keys and the
+ second tensors containing all values in the table.
+ """
+ with ops.name_scope(name, "%s_lookup_table_export_values" % self._name,
+ [self._table_ref]) as name:
+ # pylint: disable=protected-access
+ exported_keys, exported_values = gen_lookup_ops._lookup_table_export(
+ self._table_ref, self._key_dtype, self._value_dtype, name=name)
+
+ exported_values.set_shape(exported_keys.get_shape().concatenate(
+ self._value_shape))
+ return exported_keys, exported_values
+
+ class _Saveable(BaseSaverBuilder.SaveableObject):
+ """SaveableObject implementation for MutableDenseHashTable."""
+
+ def __init__(self, table, name):
+ tensors = table.export()
+ specs = [
+ BaseSaverBuilder.SaveSpec(tensors[0], "", name + "-keys"),
+ BaseSaverBuilder.SaveSpec(tensors[1], "", name + "-values")
+ ]
+ # pylint: disable=protected-access
+ super(MutableDenseHashTable._Saveable, self).__init__(table, specs, name)
+
+ def restore(self, restored_tensors, unused_restored_shapes):
+ # pylint: disable=protected-access
+ return gen_lookup_ops._lookup_table_import(
+ self.op._table_ref, restored_tensors[0], restored_tensors[1])