diff options
author | Allen Lavoie <allenl@google.com> | 2018-08-20 10:46:13 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-20 10:49:31 -0700 |
commit | a5ba38b40215ee68aa34b19204e182162bfcb04d (patch) | |
tree | ed84b988be12b6865abe83050ac2cb66a8306903 /tensorflow/contrib/lookup | |
parent | 601c58a057c0488d8ba5ec38a13345a70606bb67 (diff) |
Object-based checkpointing+eager support for mutable hash tables
Small eager execution fixes: omit op names, and set a unique shared_name for each table by default to prevent automatic sharing (since there's no op name uniquification).
There are some TODOs about the shared_names, since eager execution's kernel caching will cache a new kernel for each shared_name. This is only an issue if tables are created in a loop (variables have the same issue; fixing that is still on my list too).
PiperOrigin-RevId: 209445076
Diffstat (limited to 'tensorflow/contrib/lookup')
-rw-r--r-- | tensorflow/contrib/lookup/lookup_ops.py | 50 | ||||
-rw-r--r-- | tensorflow/contrib/lookup/lookup_ops_test.py | 108 |
2 files changed, 151 insertions, 7 deletions
diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py index 291972cce3..f83765a48d 100644 --- a/tensorflow/contrib/lookup/lookup_ops.py +++ b/tensorflow/contrib/lookup/lookup_ops.py @@ -18,6 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools + +from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import gen_lookup_ops @@ -39,6 +42,7 @@ from tensorflow.python.ops.lookup_ops import TextFileIndex from tensorflow.python.ops.lookup_ops import TextFileInitializer from tensorflow.python.ops.lookup_ops import TextFileStringTableInitializer # pylint: enable=unused-import +from tensorflow.python.training.checkpointable import base as checkpointable from tensorflow.python.training.saver import BaseSaverBuilder from tensorflow.python.util.deprecation import deprecated @@ -285,7 +289,7 @@ def index_to_string(tensor, mapping, default_value="UNK", name=None): return table.lookup(tensor) -class MutableHashTable(LookupInterface): +class MutableHashTable(LookupInterface, checkpointable.CheckpointableBase): """A generic mutable hash table implementation. Data can be inserted by calling the insert method. It does not support @@ -336,6 +340,13 @@ class MutableHashTable(LookupInterface): dtype=value_dtype) self._value_shape = self._default_value.get_shape() + executing_eagerly = context.executing_eagerly() + if executing_eagerly and shared_name is None: + # TODO(allenl): This will leak memory due to kernel caching by the + # shared_name attribute value (but is better than the alternative of + # sharing everything by default when executing eagerly; hopefully creating + # tables in a loop is uncommon). + shared_name = "table_%d" % (ops.uid(),) # The table must be shared if checkpointing is requested for multi-worker # training to work correctly. 
Use the node name if no shared_name has been # explicitly specified. @@ -355,9 +366,12 @@ class MutableHashTable(LookupInterface): value_dtype=value_dtype, value_shape=self._default_value.get_shape(), name=name) + if executing_eagerly: + op_name = None + else: + op_name = self._table_ref.op.name.split("/")[-1] super(MutableHashTable, self).__init__(key_dtype, value_dtype, - self._table_ref.op.name.split( - "/")[-1]) + op_name) if checkpoint: saveable = MutableHashTable._Saveable(self, name) @@ -446,6 +460,10 @@ class MutableHashTable(LookupInterface): self._table_ref, self._key_dtype, self._value_dtype, name=name) return exported_keys, exported_values + def _gather_saveables_for_checkpoint(self): + """For object-based checkpointing.""" + return {"table": functools.partial(MutableHashTable._Saveable, table=self)} + class _Saveable(BaseSaverBuilder.SaveableObject): """SaveableObject implementation for MutableHashTable.""" @@ -458,14 +476,15 @@ class MutableHashTable(LookupInterface): # pylint: disable=protected-access super(MutableHashTable._Saveable, self).__init__(table, specs, name) - def restore(self, restored_tensors, unused_restored_shapes): + def restore(self, restored_tensors, restored_shapes): + del restored_shapes # unused # pylint: disable=protected-access with ops.colocate_with(self.op._table_ref): return gen_lookup_ops.lookup_table_import_v2( self.op._table_ref, restored_tensors[0], restored_tensors[1]) -class MutableDenseHashTable(LookupInterface): +class MutableDenseHashTable(LookupInterface, checkpointable.CheckpointableBase): """A generic mutable hash table implementation using tensors as backing store. Data can be inserted by calling the insert method. 
It does not support @@ -536,6 +555,13 @@ class MutableDenseHashTable(LookupInterface): use_node_name_sharing = checkpoint and shared_name is None empty_key = ops.convert_to_tensor( empty_key, dtype=key_dtype, name="empty_key") + executing_eagerly = context.executing_eagerly() + if executing_eagerly and shared_name is None: + # TODO(allenl): This will leak memory due to kernel caching by the + # shared_name attribute value (but is better than the alternative of + # sharing everything by default when executing eagerly; hopefully creating + # tables in a loop is uncommon). + shared_name = "table_%d" % (ops.uid(),) self._table_ref = gen_lookup_ops.mutable_dense_hash_table_v2( empty_key=empty_key, shared_name=shared_name, @@ -544,8 +570,12 @@ class MutableDenseHashTable(LookupInterface): value_shape=self._value_shape, initial_num_buckets=initial_num_buckets, name=name) + if executing_eagerly: + op_name = None + else: + op_name = self._table_ref.op.name.split("/")[-1] super(MutableDenseHashTable, self).__init__( - key_dtype, value_dtype, self._table_ref.op.name.split("/")[-1]) + key_dtype, value_dtype, op_name) if checkpoint: saveable = MutableDenseHashTable._Saveable(self, name) @@ -636,6 +666,11 @@ class MutableDenseHashTable(LookupInterface): return exported_keys, exported_values + def _gather_saveables_for_checkpoint(self): + """For object-based checkpointing.""" + return {"table": functools.partial( + MutableDenseHashTable._Saveable, table=self)} + class _Saveable(BaseSaverBuilder.SaveableObject): """SaveableObject implementation for MutableDenseHashTable.""" @@ -648,7 +683,8 @@ class MutableDenseHashTable(LookupInterface): # pylint: disable=protected-access super(MutableDenseHashTable._Saveable, self).__init__(table, specs, name) - def restore(self, restored_tensors, unused_restored_shapes): + def restore(self, restored_tensors, restored_shapes): + del restored_shapes # unused # pylint: disable=protected-access with ops.colocate_with(self.op._table_ref): return 
gen_lookup_ops.lookup_table_import_v2( diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py index 81257e1de5..f9b0358a36 100644 --- a/tensorflow/contrib/lookup/lookup_ops_test.py +++ b/tensorflow/contrib/lookup/lookup_ops_test.py @@ -38,6 +38,7 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.training import saver from tensorflow.python.training import server_lib +from tensorflow.python.training.checkpointable import util as checkpointable class HashTableOpTest(test.TestCase): @@ -383,6 +384,59 @@ class MutableHashTableOpTest(test.TestCase): output = table.lookup(input_string) self.assertAllEqual([-1, 0, 1, 2, -1], output.eval()) + @test_util.run_in_graph_and_eager_modes + def testObjectSaveRestore(self): + save_dir = os.path.join(self.get_temp_dir(), "save_restore") + save_prefix = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash") + + v0 = variables.Variable(10.0, name="v0") + v1 = variables.Variable(20.0, name="v1") + + default_val = -1 + keys = constant_op.constant(["b", "c", "d"], dtypes.string) + values = constant_op.constant([0, 1, 2], dtypes.int64) + table = lookup.MutableHashTable( + dtypes.string, dtypes.int64, default_val, name="t1", checkpoint=True) + + checkpoint = checkpointable.Checkpoint(table=table, v0=v0, v1=v1) + self.evaluate([v0.initializer, v1.initializer]) + + # Check that the parameter nodes have been initialized. 
+ self.assertEqual(10.0, self.evaluate(v0)) + self.assertEqual(20.0, self.evaluate(v1)) + + self.assertAllEqual(0, self.evaluate(table.size())) + self.evaluate(table.insert(keys, values)) + self.assertAllEqual(3, self.evaluate(table.size())) + + save_path = checkpoint.save(save_prefix) + del table, checkpoint, v0, v1 + + v0 = variables.Variable(-1.0, name="v0") + v1 = variables.Variable(-1.0, name="v1") + default_val = -1 + table = lookup.MutableHashTable( + dtypes.string, dtypes.int64, default_val, name="t1", checkpoint=True) + self.evaluate(table.insert( + constant_op.constant(["a", "c"], dtypes.string), + constant_op.constant([12, 24], dtypes.int64))) + self.assertAllEqual(2, self.evaluate(table.size())) + + checkpoint = checkpointable.Checkpoint(table=table, v0=v0, v1=v1) + + # Restore the saved values in the parameter nodes. + checkpoint.restore(save_path).run_restore_ops() + # Check that the parameter nodes have been restored. + self.assertEqual(10.0, self.evaluate(v0)) + self.assertEqual(20.0, self.evaluate(v1)) + + self.assertAllEqual(3, self.evaluate(table.size())) + + input_string = constant_op.constant(["a", "b", "c", "d", "e"], + dtypes.string) + output = table.lookup(input_string) + self.assertAllEqual([-1, 0, 1, 2, -1], self.evaluate(output)) + def testSharing(self): # Start a server to store the table state server = server_lib.Server( @@ -1010,6 +1064,60 @@ class MutableDenseHashTableOpTest(test.TestCase): output = table.lookup(input_string) self.assertAllEqual([-1, 0, 1, 2, -1], output.eval()) + @test_util.run_in_graph_and_eager_modes + def testObjectSaveRestore(self): + save_dir = os.path.join(self.get_temp_dir(), "save_restore") + save_prefix = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash") + + default_value = -1 + empty_key = 0 + keys = constant_op.constant([11, 12, 13], dtypes.int64) + values = constant_op.constant([0, 1, 2], dtypes.int64) + save_table = lookup.MutableDenseHashTable( + dtypes.int64, + dtypes.int64, + 
default_value=default_value, + empty_key=empty_key, + name="t1", + checkpoint=True, + initial_num_buckets=32) + + save_checkpoint = checkpointable.Checkpoint(table=save_table) + + self.assertAllEqual(0, self.evaluate(save_table.size())) + self.evaluate(save_table.insert(keys, values)) + self.assertAllEqual(3, self.evaluate(save_table.size())) + self.assertAllEqual(32, len(self.evaluate(save_table.export()[0]))) + + save_path = save_checkpoint.save(save_prefix) + del save_table, save_checkpoint + + load_table = lookup.MutableDenseHashTable( + dtypes.int64, + dtypes.int64, + default_value=default_value, + empty_key=empty_key, + name="t1", + checkpoint=True, + initial_num_buckets=64) + self.evaluate(load_table.insert( + constant_op.constant([11, 14], dtypes.int64), + constant_op.constant([12, 24], dtypes.int64))) + self.assertAllEqual(2, self.evaluate(load_table.size())) + self.assertAllEqual(64, len(self.evaluate(load_table.export()[0]))) + + restore_checkpoint = checkpointable.Checkpoint(table=load_table) + + # Restore the saved values in the parameter nodes. + restore_checkpoint.restore(save_path).run_restore_ops() + + self.assertAllEqual(3, self.evaluate(load_table.size())) + self.assertAllEqual(32, len(self.evaluate(load_table.export()[0]))) + + input_string = constant_op.constant([10, 11, 12, 13, 14], dtypes.int64) + output = load_table.lookup(input_string) + self.assertAllEqual([-1, 0, 1, 2, -1], self.evaluate(output)) + def testVectorSaveRestore(self): save_dir = os.path.join(self.get_temp_dir(), "vector_save_restore") save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash") |