diff options
author | Russell Power <power@google.com> | 2018-01-04 12:02:14 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-01-04 12:06:05 -0800 |
commit | 782519a152c81873878e30c7791ccff5f6f534d1 (patch) | |
tree | 2bf7e0d4a62dc59bca547f2256b1f0dde839c5b0 /tensorflow/core/kernels/save_restore_tensor.h | |
parent | dd0996f48fc7c580809c80c652a4bf726d3b2f3c (diff) |
Expand all saveable operations to generate a single C++ restore call.
This allows us to avoid repeated index lookups and perform a sequential scan of the index in the common case where we are doing a full restore, or a restore from a sub-model. It also dramatically reduces excessive restore parallelism.
Testing with a checkpoint with 1000 100x100 tensors, restoring from CNS drops from ~1m to ~5 seconds.
PiperOrigin-RevId: 180827583
Diffstat (limited to 'tensorflow/core/kernels/save_restore_tensor.h')
-rw-r--r-- | tensorflow/core/kernels/save_restore_tensor.h | 15 |
1 files changed, 9 insertions, 6 deletions
diff --git a/tensorflow/core/kernels/save_restore_tensor.h b/tensorflow/core/kernels/save_restore_tensor.h index 1e87e5c30b..5b74b586e8 100644 --- a/tensorflow/core/kernels/save_restore_tensor.h +++ b/tensorflow/core/kernels/save_restore_tensor.h @@ -37,18 +37,21 @@ void SaveTensors( checkpoint::TensorSliceWriter::CreateBuilderFunction builder_func, bool save_slices); -// Reads a tensor from the reader built from open_func() and produces it as -// context->output(0). "preferred_shard" is the same the TensorSliceReader -// preferred_shard parameter. +// Reads a single tensor from the reader built from open_func() and produces +// it as context->output(restore_index). "preferred_shard" is the same the +// TensorSliceReader preferred_shard parameter. // // context must have the following inputs: // 0: a single element string tensor that contains the file name. -// 1: a single element string tensor that names the output to be restored. +// 1: string tensor that names the outputs to be restored. // If restore_slice is true: -// 2: shape and slice specification of the tensor to restore. +// 2: shape and slice specification of the tensors to restore. +// +// restore_index indicates the variable name and slice to lookup +// in context(1) and (2). void RestoreTensor(OpKernelContext* context, checkpoint::TensorSliceReader::OpenTableFunction open_func, - int preferred_shard, bool restore_slice); + int preferred_shard, bool restore_slice, int restore_index); // V2 checkpoint format. |