path: root/tensorflow/python/training/saver.py
author    Russell Power <power@google.com>  2018-01-04 12:02:14 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-01-04 12:06:05 -0800
commit    782519a152c81873878e30c7791ccff5f6f534d1 (patch)
tree      2bf7e0d4a62dc59bca547f2256b1f0dde839c5b0 /tensorflow/python/training/saver.py
parent    dd0996f48fc7c580809c80c652a4bf726d3b2f3c (diff)
Expand all saveable operations to generate a single C++ restore call.
This allows us to avoid repeated index lookups and perform a sequential scan of the index in the common case where we are doing a full restore, or a restore from a sub-model. It also dramatically reduces excessive restore parallelism.

Testing with a checkpoint containing 1000 100x100 tensors, restoring from CNS drops from ~1 minute to ~5 seconds.

PiperOrigin-RevId: 180827583
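As a rough illustration of the change (not part of the commit), the sketch below shows the bulk path: gather one (name, slice_spec, dtype) triple per tensor and issue a single restore_v2 op, rather than one restore op per saveable. The helper name bulk_restore_sketch and the standalone setup are hypothetical; the internal modules (ops, io_ops) are the ones saver.py already imports.

    # Minimal sketch of the bulk-restore idea, assuming TF 1.x internal APIs.
    # bulk_restore_sketch is a hypothetical helper, not the method added here.
    from tensorflow.python.framework import ops
    from tensorflow.python.ops import io_ops

    def bulk_restore_sketch(filename_tensor, saveables):
      # Collect one (name, slice_spec, dtype) triple per tensor to restore.
      restore_specs = []
      for saveable in saveables:
        for spec in saveable.specs:
          restore_specs.append((spec.name, spec.slice_spec, spec.tensor.dtype))
      names, slices, dtypes = zip(*restore_specs)
      # A single restore_v2 op reads every tensor in one sequential scan of the
      # checkpoint index, instead of one index lookup per saveable.
      with ops.device("cpu:0"):
        return io_ops.restore_v2(filename_tensor, names, slices, dtypes)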
Diffstat (limited to 'tensorflow/python/training/saver.py')
-rw-r--r--  tensorflow/python/training/saver.py  94
1 file changed, 72 insertions(+), 22 deletions(-)
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index 2330229d56..2c59b82ebe 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -241,6 +241,34 @@ class BaseSaverBuilder(object):
    else:
      raise RuntimeError("Unexpected write_version: " + self._write_version)

+  def bulk_restore(self, filename_tensor, saveables, preferred_shard,
+                   restore_sequentially):
+    """Restore all tensors contained in saveables.
+
+    By default, this issues separate calls to `restore_op` for each saveable.
+    Subclasses may override to load multiple saveables in a single call.
+
+    Args:
+      filename_tensor: String Tensor.
+      saveables: List of BaseSaverBuilder.SaveableObject objects.
+      preferred_shard: Int. Shard to open first when loading a sharded file.
+      restore_sequentially: Bool. If true, each restore is sequential.
+
+    Returns:
+      A list of Tensors resulting from reading 'saveable' from
+        'filename'.
+
+    """
+    all_tensors = []
+    assign_ops = []
+    for saveable in saveables:
+      restore_control_inputs = assign_ops[-1:] if restore_sequentially else []
+      with ops.device(_set_cpu0(saveable.device) if saveable.device else None):
+        with ops.control_dependencies(restore_control_inputs):
+          all_tensors.extend(
+              self.restore_op(filename_tensor, saveable, preferred_shard))
+    return all_tensors
+
  # pylint: disable=unused-argument
  def restore_op(self, filename_tensor, saveable, preferred_shard):
    """Create ops to restore 'saveable'.
@@ -416,30 +444,32 @@ class BaseSaverBuilder(object):
    Returns:
      An Operation that restores the variables.
    """
+    all_tensors = self.bulk_restore(filename_tensor, saveables, preferred_shard,
+                                    restore_sequentially)
+
    assign_ops = []
+    idx = 0
+    # Load and optionally reshape on the CPU, as string tensors are not
+    # available on the GPU.
+    # TODO(touts): Re-enable restore on GPU when we can support annotating
+    # string tensors as "HostMemory" inputs.
    for saveable in saveables:
-      restore_control_inputs = assign_ops[-1:] if restore_sequentially else []
-      # Load and optionally reshape on the CPU, as string tensors are not
-      # available on the GPU.
-      # TODO(touts): Re-enable restore on GPU when we can support annotating
-      # string tensors as "HostMemory" inputs.
-      with ops.device(_set_cpu0(saveable.device) if saveable.device else None):
-        with ops.control_dependencies(restore_control_inputs):
-          tensors = self.restore_op(filename_tensor, saveable, preferred_shard)
-          shapes = None
-          if reshape:
-            # Compute the shapes, let the restore op decide if and how to do
-            # the reshape.
-            shapes = []
-            for spec in saveable.specs:
-              v = spec.tensor
-              shape = v.get_shape()
-              if not shape.is_fully_defined():
-                shape = array_ops.shape(v)
-              shapes.append(shape)
-          assign_ops.append(saveable.restore(tensors, shapes))
-
-    # Create a Noop that has control dependencies from all the updates.
+      shapes = None
+      if reshape:
+        # Compute the shapes, let the restore op decide if and how to do
+        # the reshape.
+        shapes = []
+        for spec in saveable.specs:
+          v = spec.tensor
+          shape = v.get_shape()
+          if not shape.is_fully_defined():
+            shape = array_ops.shape(v)
+          shapes.append(shape)
+      saveable_tensors = all_tensors[idx:idx + len(saveable.specs)]
+      idx += len(saveable.specs)
+      assign_ops.append(saveable.restore(saveable_tensors, shapes))
+
+    # Create a Noop that has control dependencies from all the updates.
    return control_flow_ops.group(*assign_ops, name=name)

  def _AddShardedRestoreOps(self, filename_tensor, per_device,
@@ -797,6 +827,25 @@ class BaseSaverBuilder(object):
        version=self._write_version)


+class BulkSaverBuilder(BaseSaverBuilder):
+  """SaverBuilder with support for bulk restoring multiple saveables."""
+
+  def bulk_restore(self, filename_tensor, saveables, preferred_shard,
+                   restore_sequentially):
+
+    # Ignored: bulk restore is internally sequential.
+    del restore_sequentially
+    restore_specs = []
+    for saveable in saveables:
+      for spec in saveable.specs:
+        restore_specs.append((spec.name, spec.slice_spec, spec.tensor.dtype))
+
+    names, slices, dtypes = zip(*restore_specs)
+    # Load all tensors onto CPU 0 for compatibility with existing code.
+    with ops.device("cpu:0"):
+      return io_ops.restore_v2(filename_tensor, names, slices, dtypes)
+
+
def _get_saver_or_default():
  """Returns the saver from SAVERS collection, or creates a default one.
@@ -1261,6 +1310,7 @@ class Saver(object):
    if not self.saver_def or context.in_eager_mode():
      if self._builder is None:
        self._builder = BaseSaverBuilder(self._write_version)
+
      if self._var_list is None:
        # pylint: disable=protected-access
        self._var_list = variables._all_saveable_objects()
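Hypothetical usage sketch (not part of this commit): a Saver constructed with the new BulkSaverBuilder routes restores through the bulk path added above, so all variables come back through one restore_v2 call. The variable shape and checkpoint path below are made up for illustration.

    import tensorflow as tf
    from tensorflow.python.training import saver as saver_lib

    v = tf.get_variable("v", shape=[100, 100])
    # Passing a BulkSaverBuilder makes restore() issue a single restore_v2 op
    # for all saveables instead of one restore op per variable.
    saver = saver_lib.Saver(builder=saver_lib.BulkSaverBuilder())
    with tf.Session() as sess:
      saver.restore(sess, "/tmp/model.ckpt")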