author | 2018-01-04 12:02:14 -0800
---|---
committer | 2018-01-04 12:06:05 -0800
commit | 782519a152c81873878e30c7791ccff5f6f534d1 (patch)
tree | 2bf7e0d4a62dc59bca547f2256b1f0dde839c5b0 /tensorflow/python/training/saver.py
parent | dd0996f48fc7c580809c80c652a4bf726d3b2f3c (diff)
Expand all saveable operations to generate a single C++ restore call.
This allows us to avoid repeated index lookups and instead perform a sequential scan of the index in the common case where we are doing a full restore, or a restore of a sub-model. It also dramatically reduces excessive restore parallelism.
Testing with a checkpoint containing 1000 100x100 tensors, restore time from CNS drops from ~1 minute to ~5 seconds.
PiperOrigin-RevId: 180827583
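For context on how the new code path is selected: the last hunk below shows that `BaseSaverBuilder` remains the default, so callers opt in by passing the new builder through the `builder` argument of `tf.train.Saver`. A minimal sketch, assuming TF 1.x graph mode; the variable and checkpoint path are hypothetical:

```python
# A minimal sketch, assuming TF 1.x graph mode. The variable and the
# checkpoint path are hypothetical; `builder` is the Saver constructor
# argument through which a custom SaverBuilder is supplied.
import tensorflow as tf
from tensorflow.python.training.saver import BulkSaverBuilder

v = tf.get_variable("v", shape=[100, 100])

# Opt in to the single-restore_v2-call path via the bulk builder.
saver = tf.train.Saver(builder=BulkSaverBuilder())

with tf.Session() as sess:
  saver.restore(sess, "/tmp/model.ckpt")  # hypothetical checkpoint
```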
Diffstat (limited to 'tensorflow/python/training/saver.py')
-rw-r--r-- | tensorflow/python/training/saver.py | 94
1 file changed, 72 insertions(+), 22 deletions(-)
```diff
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index 2330229d56..2c59b82ebe 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -241,6 +241,34 @@ class BaseSaverBuilder(object):
     else:
       raise RuntimeError("Unexpected write_version: " + self._write_version)
 
+  def bulk_restore(self, filename_tensor, saveables, preferred_shard,
+                   restore_sequentially):
+    """Restore all tensors contained in saveables.
+
+    By default, this issues separate calls to `restore_op` for each saveable.
+    Subclasses may override to load multiple saveables in a single call.
+
+    Args:
+      filename_tensor: String Tensor.
+      saveables: List of BaseSaverBuilder.SaveableObject objects.
+      preferred_shard: Int. Shard to open first when loading a sharded file.
+      restore_sequentially: Bool. If true, each restore is sequential.
+
+    Returns:
+      A list of Tensors resulting from reading 'saveable' from
+        'filename'.
+
+    """
+    all_tensors = []
+    assign_ops = []
+    for saveable in saveables:
+      restore_control_inputs = assign_ops[-1:] if restore_sequentially else []
+      with ops.device(_set_cpu0(saveable.device) if saveable.device else None):
+        with ops.control_dependencies(restore_control_inputs):
+          all_tensors.extend(
+              self.restore_op(filename_tensor, saveable, preferred_shard))
+    return all_tensors
+
   # pylint: disable=unused-argument
   def restore_op(self, filename_tensor, saveable, preferred_shard):
     """Create ops to restore 'saveable'.
@@ -416,30 +444,32 @@ class BaseSaverBuilder(object):
     Returns:
       An Operation that restores the variables.
     """
+    all_tensors = self.bulk_restore(filename_tensor, saveables, preferred_shard,
+                                    restore_sequentially)
+
     assign_ops = []
+    idx = 0
+    # Load and optionally reshape on the CPU, as string tensors are not
+    # available on the GPU.
+    # TODO(touts): Re-enable restore on GPU when we can support annotating
+    # string tensors as "HostMemory" inputs.
     for saveable in saveables:
-      restore_control_inputs = assign_ops[-1:] if restore_sequentially else []
-      # Load and optionally reshape on the CPU, as string tensors are not
-      # available on the GPU.
-      # TODO(touts): Re-enable restore on GPU when we can support annotating
-      # string tensors as "HostMemory" inputs.
-      with ops.device(_set_cpu0(saveable.device) if saveable.device else None):
-        with ops.control_dependencies(restore_control_inputs):
-          tensors = self.restore_op(filename_tensor, saveable, preferred_shard)
-          shapes = None
-          if reshape:
-            # Compute the shapes, let the restore op decide if and how to do
-            # the reshape.
-            shapes = []
-            for spec in saveable.specs:
-              v = spec.tensor
-              shape = v.get_shape()
-              if not shape.is_fully_defined():
-                shape = array_ops.shape(v)
-              shapes.append(shape)
-          assign_ops.append(saveable.restore(tensors, shapes))
-
-    # Create a Noop that has control dependencies from all the updates.
+      shapes = None
+      if reshape:
+        # Compute the shapes, let the restore op decide if and how to do
+        # the reshape.
+        shapes = []
+        for spec in saveable.specs:
+          v = spec.tensor
+          shape = v.get_shape()
+          if not shape.is_fully_defined():
+            shape = array_ops.shape(v)
+          shapes.append(shape)
+      saveable_tensors = all_tensors[idx:idx + len(saveable.specs)]
+      idx += len(saveable.specs)
+      assign_ops.append(saveable.restore(saveable_tensors, shapes))
+
+    # Create a Noop that has control dependencies from all the updates.
     return control_flow_ops.group(*assign_ops, name=name)
 
   def _AddShardedRestoreOps(self, filename_tensor, per_device,
@@ -797,6 +827,25 @@ class BaseSaverBuilder(object):
         version=self._write_version)
 
 
+class BulkSaverBuilder(BaseSaverBuilder):
+  """SaverBuilder with support for bulk restoring multiple saveables."""
+
+  def bulk_restore(self, filename_tensor, saveables, preferred_shard,
+                   restore_sequentially):
+
+    # Ignored: bulk restore is internally sequential.
+    del restore_sequentially
+    restore_specs = []
+    for saveable in saveables:
+      for spec in saveable.specs:
+        restore_specs.append((spec.name, spec.slice_spec, spec.tensor.dtype))
+
+    names, slices, dtypes = zip(*restore_specs)
+    # Load all tensors onto CPU 0 for compatibility with existing code.
+    with ops.device("cpu:0"):
+      return io_ops.restore_v2(filename_tensor, names, slices, dtypes)
+
+
 def _get_saver_or_default():
   """Returns the saver from SAVERS collection, or creates a default one.
@@ -1261,6 +1310,7 @@ class Saver(object):
     if not self.saver_def or context.in_eager_mode():
       if self._builder is None:
         self._builder = BaseSaverBuilder(self._write_version)
+
       if self._var_list is None:
         # pylint: disable=protected-access
         self._var_list = variables._all_saveable_objects()
```
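The bookkeeping worth noting in `_AddRestoreOps` is how the flat tensor list returned by `bulk_restore` is mapped back onto saveables: each saveable consumes `len(saveable.specs)` consecutive tensors. A standalone plain-Python sketch of that slicing, with made-up spec counts:

```python
# Plain-Python sketch of the idx / len(saveable.specs) slicing in
# _AddRestoreOps. spec_counts and all_tensors are made-up stand-ins for
# the saveables' spec lists and the flat output of bulk_restore.
spec_counts = [1, 3, 2]       # e.g. three saveables with 1, 3, and 2 specs
all_tensors = list(range(6))  # restore_v2 returns one tensor per spec

idx = 0
per_saveable = []
for count in spec_counts:
  per_saveable.append(all_tensors[idx:idx + count])
  idx += count

assert per_saveable == [[0], [1, 2, 3], [4, 5]]
```

Because `restore_v2` preserves the order of the names it is given, this positional slicing is all that is needed to reassociate tensors with their saveables after the single bulk read.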