author     Allen Lavoie <allenl@google.com>              2018-08-20 10:46:13 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-08-20 10:49:31 -0700
commit  a5ba38b40215ee68aa34b19204e182162bfcb04d (patch)
tree    ed84b988be12b6865abe83050ac2cb66a8306903 /tensorflow/contrib/lookup
parent  601c58a057c0488d8ba5ec38a13345a70606bb67 (diff)
Object-based checkpointing+eager support for mutable hash tables
Includes small eager execution fixes: op names are omitted, and a unique shared_name is set for each table by default to prevent automatic sharing (since there is no op-name uniquification under eager execution). There are TODOs about the shared_names, since eager execution's kernel caching will cache a new kernel for each shared_name; this is only an issue when tables are created in a loop (variables have the same issue, and fixing that is still on my list).

PiperOrigin-RevId: 209445076
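For orientation, a hedged usage sketch (not part of the commit) of the object-based checkpointing path this change enables for MutableHashTable. It assumes eager execution and the same checkpointable utilities the tests below import; the path and table name are illustrative.

import tensorflow as tf
from tensorflow.contrib import lookup
from tensorflow.python.training.checkpointable import util as checkpointable

tf.enable_eager_execution()

# Build a table and register it on a Checkpoint object.
table = lookup.MutableHashTable(
    tf.string, tf.int64, default_value=-1, name="t1", checkpoint=True)
table.insert(tf.constant(["a", "b"]), tf.constant([0, 1], tf.int64))
checkpoint = checkpointable.Checkpoint(table=table)
save_path = checkpoint.save("/tmp/table_ckpt/hash")  # illustrative path

# A fresh table restores its contents through the same object graph.
restored = lookup.MutableHashTable(
    tf.string, tf.int64, default_value=-1, name="t1", checkpoint=True)
status = checkpointable.Checkpoint(table=restored).restore(save_path)
status.run_restore_ops()  # immediate under eager execution; needed in graph mode
print(restored.lookup(tf.constant(["a", "c"])))  # -> [0, -1]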
Diffstat (limited to 'tensorflow/contrib/lookup')
-rw-r--r--  tensorflow/contrib/lookup/lookup_ops.py        50
-rw-r--r--  tensorflow/contrib/lookup/lookup_ops_test.py  108
2 files changed, 151 insertions(+), 7 deletions(-)
diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py
index 291972cce3..f83765a48d 100644
--- a/tensorflow/contrib/lookup/lookup_ops.py
+++ b/tensorflow/contrib/lookup/lookup_ops.py
@@ -18,6 +18,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import functools
+
+from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_lookup_ops
@@ -39,6 +42,7 @@ from tensorflow.python.ops.lookup_ops import TextFileIndex
from tensorflow.python.ops.lookup_ops import TextFileInitializer
from tensorflow.python.ops.lookup_ops import TextFileStringTableInitializer
# pylint: enable=unused-import
+from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.saver import BaseSaverBuilder
from tensorflow.python.util.deprecation import deprecated
@@ -285,7 +289,7 @@ def index_to_string(tensor, mapping, default_value="UNK", name=None):
return table.lookup(tensor)
-class MutableHashTable(LookupInterface):
+class MutableHashTable(LookupInterface, checkpointable.CheckpointableBase):
"""A generic mutable hash table implementation.
Data can be inserted by calling the insert method. It does not support
@@ -336,6 +340,13 @@ class MutableHashTable(LookupInterface):
dtype=value_dtype)
self._value_shape = self._default_value.get_shape()
+ executing_eagerly = context.executing_eagerly()
+ if executing_eagerly and shared_name is None:
+ # TODO(allenl): This will leak memory due to kernel caching by the
+ # shared_name attribute value (but is better than the alternative of
+ # sharing everything by default when executing eagerly; hopefully creating
+ # tables in a loop is uncommon).
+ shared_name = "table_%d" % (ops.uid(),)
# The table must be shared if checkpointing is requested for multi-worker
# training to work correctly. Use the node name if no shared_name has been
# explicitly specified.
@@ -355,9 +366,12 @@ class MutableHashTable(LookupInterface):
value_dtype=value_dtype,
value_shape=self._default_value.get_shape(),
name=name)
+ if executing_eagerly:
+ op_name = None
+ else:
+ op_name = self._table_ref.op.name.split("/")[-1]
super(MutableHashTable, self).__init__(key_dtype, value_dtype,
- self._table_ref.op.name.split(
- "/")[-1])
+ op_name)
if checkpoint:
saveable = MutableHashTable._Saveable(self, name)
@@ -446,6 +460,10 @@ class MutableHashTable(LookupInterface):
self._table_ref, self._key_dtype, self._value_dtype, name=name)
return exported_keys, exported_values
+ def _gather_saveables_for_checkpoint(self):
+ """For object-based checkpointing."""
+ return {"table": functools.partial(MutableHashTable._Saveable, table=self)}
+
class _Saveable(BaseSaverBuilder.SaveableObject):
"""SaveableObject implementation for MutableHashTable."""
@@ -458,14 +476,15 @@ class MutableHashTable(LookupInterface):
# pylint: disable=protected-access
super(MutableHashTable._Saveable, self).__init__(table, specs, name)
- def restore(self, restored_tensors, unused_restored_shapes):
+ def restore(self, restored_tensors, restored_shapes):
+ del restored_shapes # unused
# pylint: disable=protected-access
with ops.colocate_with(self.op._table_ref):
return gen_lookup_ops.lookup_table_import_v2(
self.op._table_ref, restored_tensors[0], restored_tensors[1])
-class MutableDenseHashTable(LookupInterface):
+class MutableDenseHashTable(LookupInterface, checkpointable.CheckpointableBase):
"""A generic mutable hash table implementation using tensors as backing store.
Data can be inserted by calling the insert method. It does not support
@@ -536,6 +555,13 @@ class MutableDenseHashTable(LookupInterface):
use_node_name_sharing = checkpoint and shared_name is None
empty_key = ops.convert_to_tensor(
empty_key, dtype=key_dtype, name="empty_key")
+ executing_eagerly = context.executing_eagerly()
+ if executing_eagerly and shared_name is None:
+ # TODO(allenl): This will leak memory due to kernel caching by the
+ # shared_name attribute value (but is better than the alternative of
+ # sharing everything by default when executing eagerly; hopefully creating
+ # tables in a loop is uncommon).
+ shared_name = "table_%d" % (ops.uid(),)
self._table_ref = gen_lookup_ops.mutable_dense_hash_table_v2(
empty_key=empty_key,
shared_name=shared_name,
@@ -544,8 +570,12 @@ class MutableDenseHashTable(LookupInterface):
value_shape=self._value_shape,
initial_num_buckets=initial_num_buckets,
name=name)
+ if executing_eagerly:
+ op_name = None
+ else:
+ op_name = self._table_ref.op.name.split("/")[-1]
super(MutableDenseHashTable, self).__init__(
- key_dtype, value_dtype, self._table_ref.op.name.split("/")[-1])
+ key_dtype, value_dtype, op_name)
if checkpoint:
saveable = MutableDenseHashTable._Saveable(self, name)
@@ -636,6 +666,11 @@ class MutableDenseHashTable(LookupInterface):
return exported_keys, exported_values
+ def _gather_saveables_for_checkpoint(self):
+ """For object-based checkpointing."""
+ return {"table": functools.partial(
+ MutableDenseHashTable._Saveable, table=self)}
+
class _Saveable(BaseSaverBuilder.SaveableObject):
"""SaveableObject implementation for MutableDenseHashTable."""
@@ -648,7 +683,8 @@ class MutableDenseHashTable(LookupInterface):
# pylint: disable=protected-access
super(MutableDenseHashTable._Saveable, self).__init__(table, specs, name)
- def restore(self, restored_tensors, unused_restored_shapes):
+ def restore(self, restored_tensors, restored_shapes):
+ del restored_shapes # unused
# pylint: disable=protected-access
with ops.colocate_with(self.op._table_ref):
return gen_lookup_ops.lookup_table_import_v2(
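A hedged sketch (not part of the commit) of the behavior the default per-table shared_name added above is meant to guarantee under eager execution: independently constructed tables must not silently share backing state. Values are illustrative, and eager execution is assumed to be enabled as in the first sketch.

# Each table receives a distinct default shared_name ("table_<uid>").
t1 = lookup.MutableHashTable(tf.string, tf.int64, default_value=-1)
t2 = lookup.MutableHashTable(tf.string, tf.int64, default_value=-1)
t1.insert(tf.constant(["k"]), tf.constant([7], tf.int64))

# t2 is unaffected by inserts into t1; it falls back to the default value.
print(t2.lookup(tf.constant(["k"])))  # -> [-1]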
diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py
index 81257e1de5..f9b0358a36 100644
--- a/tensorflow/contrib/lookup/lookup_ops_test.py
+++ b/tensorflow/contrib/lookup/lookup_ops_test.py
@@ -38,6 +38,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import saver
from tensorflow.python.training import server_lib
+from tensorflow.python.training.checkpointable import util as checkpointable
class HashTableOpTest(test.TestCase):
@@ -383,6 +384,59 @@ class MutableHashTableOpTest(test.TestCase):
output = table.lookup(input_string)
self.assertAllEqual([-1, 0, 1, 2, -1], output.eval())
+ @test_util.run_in_graph_and_eager_modes
+ def testObjectSaveRestore(self):
+ save_dir = os.path.join(self.get_temp_dir(), "save_restore")
+ save_prefix = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
+
+ v0 = variables.Variable(10.0, name="v0")
+ v1 = variables.Variable(20.0, name="v1")
+
+ default_val = -1
+ keys = constant_op.constant(["b", "c", "d"], dtypes.string)
+ values = constant_op.constant([0, 1, 2], dtypes.int64)
+ table = lookup.MutableHashTable(
+ dtypes.string, dtypes.int64, default_val, name="t1", checkpoint=True)
+
+ checkpoint = checkpointable.Checkpoint(table=table, v0=v0, v1=v1)
+ self.evaluate([v0.initializer, v1.initializer])
+
+ # Check that the parameter nodes have been initialized.
+ self.assertEqual(10.0, self.evaluate(v0))
+ self.assertEqual(20.0, self.evaluate(v1))
+
+ self.assertAllEqual(0, self.evaluate(table.size()))
+ self.evaluate(table.insert(keys, values))
+ self.assertAllEqual(3, self.evaluate(table.size()))
+
+ save_path = checkpoint.save(save_prefix)
+ del table, checkpoint, v0, v1
+
+ v0 = variables.Variable(-1.0, name="v0")
+ v1 = variables.Variable(-1.0, name="v1")
+ default_val = -1
+ table = lookup.MutableHashTable(
+ dtypes.string, dtypes.int64, default_val, name="t1", checkpoint=True)
+ self.evaluate(table.insert(
+ constant_op.constant(["a", "c"], dtypes.string),
+ constant_op.constant([12, 24], dtypes.int64)))
+ self.assertAllEqual(2, self.evaluate(table.size()))
+
+ checkpoint = checkpointable.Checkpoint(table=table, v0=v0, v1=v1)
+
+ # Restore the saved values in the parameter nodes.
+ checkpoint.restore(save_path).run_restore_ops()
+ # Check that the parameter nodes have been restored.
+ self.assertEqual(10.0, self.evaluate(v0))
+ self.assertEqual(20.0, self.evaluate(v1))
+
+ self.assertAllEqual(3, self.evaluate(table.size()))
+
+ input_string = constant_op.constant(["a", "b", "c", "d", "e"],
+ dtypes.string)
+ output = table.lookup(input_string)
+ self.assertAllEqual([-1, 0, 1, 2, -1], self.evaluate(output))
+
def testSharing(self):
# Start a server to store the table state
server = server_lib.Server(
@@ -1010,6 +1064,60 @@ class MutableDenseHashTableOpTest(test.TestCase):
output = table.lookup(input_string)
self.assertAllEqual([-1, 0, 1, 2, -1], output.eval())
+ @test_util.run_in_graph_and_eager_modes
+ def testObjectSaveRestore(self):
+ save_dir = os.path.join(self.get_temp_dir(), "save_restore")
+ save_prefix = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
+
+ default_value = -1
+ empty_key = 0
+ keys = constant_op.constant([11, 12, 13], dtypes.int64)
+ values = constant_op.constant([0, 1, 2], dtypes.int64)
+ save_table = lookup.MutableDenseHashTable(
+ dtypes.int64,
+ dtypes.int64,
+ default_value=default_value,
+ empty_key=empty_key,
+ name="t1",
+ checkpoint=True,
+ initial_num_buckets=32)
+
+ save_checkpoint = checkpointable.Checkpoint(table=save_table)
+
+ self.assertAllEqual(0, self.evaluate(save_table.size()))
+ self.evaluate(save_table.insert(keys, values))
+ self.assertAllEqual(3, self.evaluate(save_table.size()))
+ self.assertAllEqual(32, len(self.evaluate(save_table.export()[0])))
+
+ save_path = save_checkpoint.save(save_prefix)
+ del save_table, save_checkpoint
+
+ load_table = lookup.MutableDenseHashTable(
+ dtypes.int64,
+ dtypes.int64,
+ default_value=default_value,
+ empty_key=empty_key,
+ name="t1",
+ checkpoint=True,
+ initial_num_buckets=64)
+ self.evaluate(load_table.insert(
+ constant_op.constant([11, 14], dtypes.int64),
+ constant_op.constant([12, 24], dtypes.int64)))
+ self.assertAllEqual(2, self.evaluate(load_table.size()))
+ self.assertAllEqual(64, len(self.evaluate(load_table.export()[0])))
+
+ restore_checkpoint = checkpointable.Checkpoint(table=load_table)
+
+ # Restore the saved values in the parameter nodes.
+ restore_checkpoint.restore(save_path).run_restore_ops()
+
+ self.assertAllEqual(3, self.evaluate(load_table.size()))
+ self.assertAllEqual(32, len(self.evaluate(load_table.export()[0])))
+
+ input_string = constant_op.constant([10, 11, 12, 13, 14], dtypes.int64)
+ output = load_table.lookup(input_string)
+ self.assertAllEqual([-1, 0, 1, 2, -1], self.evaluate(output))
+
def testVectorSaveRestore(self):
save_dir = os.path.join(self.get_temp_dir(), "vector_save_restore")
save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
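Finally, a hedged sketch (not part of the commit) showing that MutableDenseHashTable mirrors the path above, with the extra empty_key and initial_num_buckets constructor arguments exercised by the new test. It assumes the imports and eager setup from the first sketch; the path is illustrative.

dense = lookup.MutableDenseHashTable(
    tf.int64, tf.int64, default_value=-1, empty_key=0,
    name="t1", checkpoint=True, initial_num_buckets=32)
dense.insert(tf.constant([11, 12], dtypes.int64),
             tf.constant([0, 1], dtypes.int64))

# Save and restore through the same object-based Checkpoint API.
checkpoint = checkpointable.Checkpoint(table=dense)
path = checkpoint.save("/tmp/dense_table_ckpt/hash")  # illustrative path
checkpoint.restore(path).run_restore_ops()  # graph mode needs the explicit run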