aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/save_restore_tensor.h
diff options
context:
space:
mode:
authorGravatar Russell Power <power@google.com>2018-01-04 12:02:14 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-04 12:06:05 -0800
commit782519a152c81873878e30c7791ccff5f6f534d1 (patch)
tree2bf7e0d4a62dc59bca547f2256b1f0dde839c5b0 /tensorflow/core/kernels/save_restore_tensor.h
parentdd0996f48fc7c580809c80c652a4bf726d3b2f3c (diff)
Expand all saveable operations to generate a single C++ restore call.
This allows us to avoid repeated index lookups and perform a sequential scan of the index in the common case where we are doing a full restore, or a restore from a sub-model. It also dramatically reduces excessive restore parallelism. Testing with a checkpoint with 1000 100x100 tensors, restoring from CNS drops from ~1m to ~5 seconds. PiperOrigin-RevId: 180827583
Diffstat (limited to 'tensorflow/core/kernels/save_restore_tensor.h')
-rw-r--r--tensorflow/core/kernels/save_restore_tensor.h15
1 files changed, 9 insertions, 6 deletions
diff --git a/tensorflow/core/kernels/save_restore_tensor.h b/tensorflow/core/kernels/save_restore_tensor.h
index 1e87e5c30b..5b74b586e8 100644
--- a/tensorflow/core/kernels/save_restore_tensor.h
+++ b/tensorflow/core/kernels/save_restore_tensor.h
@@ -37,18 +37,21 @@ void SaveTensors(
checkpoint::TensorSliceWriter::CreateBuilderFunction builder_func,
bool save_slices);
-// Reads a tensor from the reader built from open_func() and produces it as
-// context->output(0). "preferred_shard" is the same the TensorSliceReader
-// preferred_shard parameter.
+// Reads a single tensor from the reader built from open_func() and produces
+// it as context->output(restore_index). "preferred_shard" is the same the
+// TensorSliceReader preferred_shard parameter.
//
// context must have the following inputs:
// 0: a single element string tensor that contains the file name.
-// 1: a single element string tensor that names the output to be restored.
+// 1: string tensor that names the outputs to be restored.
// If restore_slice is true:
-// 2: shape and slice specification of the tensor to restore.
+// 2: shape and slice specification of the tensors to restore.
+//
+// restore_index indicates the variable name and slice to lookup
+// in context(1) and (2).
void RestoreTensor(OpKernelContext* context,
checkpoint::TensorSliceReader::OpenTableFunction open_func,
- int preferred_shard, bool restore_slice);
+ int preferred_shard, bool restore_slice, int restore_index);
// V2 checkpoint format.