author     Russell Power <power@google.com>                  2018-01-04 12:02:14 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-01-04 12:06:05 -0800
commit     782519a152c81873878e30c7791ccff5f6f534d1 (patch)
tree       2bf7e0d4a62dc59bca547f2256b1f0dde839c5b0
parent     dd0996f48fc7c580809c80c652a4bf726d3b2f3c (diff)
Expand all saveable operations to generate a single C++ restore call.
This allows us to avoid repeated index lookups and perform a sequential scan
of the index in the common case where we are doing a full restore, or a
restore from a sub-model. It also dramatically reduces excessive restore
parallelism.

Testing with a checkpoint containing 1000 100x100 tensors, restoring from CNS
drops from ~1 minute to ~5 seconds.

PiperOrigin-RevId: 180827583
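The Python-side half of this change is easiest to see in isolation. The sketch
below is not part of the patch; Spec and Saveable are simplified stand-ins for
the real TensorFlow classes. It shows how per-saveable restore ops collapse
into one flattened (name, slice, dtype) list that a single RestoreV2 op can
read in one pass:

# Illustrative sketch only: flattening all saveables into one restore call,
# mirroring BulkSaverBuilder.bulk_restore added in the diff below.
import collections

Spec = collections.namedtuple("Spec", ["name", "slice_spec", "dtype"])
Saveable = collections.namedtuple("Saveable", ["specs"])

def flatten_restore_specs(saveables):
  # One (name, slice, dtype) triple per spec, across every saveable, so a
  # single RestoreV2 op can look them all up in one scan of the index.
  restore_specs = []
  for saveable in saveables:
    for spec in saveable.specs:
      restore_specs.append((spec.name, spec.slice_spec, spec.dtype))
  names, slices, dtypes = zip(*restore_specs)
  return names, slices, dtypes

names, slices, dtypes = flatten_restore_specs([
    Saveable([Spec("v", "", "float32")]),
    Saveable([Spec("w", "", "float32")]),
])
# io_ops.restore_v2(filename_tensor, names, slices, dtypes) would then be
# issued once, instead of one restore op (and one index lookup) per saveable.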
-rw-r--r--  tensorflow/core/kernels/restore_op.cc                  4
-rw-r--r--  tensorflow/core/kernels/save_restore_tensor.cc        84
-rw-r--r--  tensorflow/core/kernels/save_restore_tensor.h         15
-rw-r--r--  tensorflow/core/kernels/save_restore_v2_ops.cc        10
-rw-r--r--  tensorflow/core/util/tensor_bundle/tensor_bundle.cc   48
-rw-r--r--  tensorflow/python/training/saver.py                   94
6 files changed, 139 insertions(+), 116 deletions(-)
diff --git a/tensorflow/core/kernels/restore_op.cc b/tensorflow/core/kernels/restore_op.cc
index 0593a07b80..d9bbcb14ab 100644
--- a/tensorflow/core/kernels/restore_op.cc
+++ b/tensorflow/core/kernels/restore_op.cc
@@ -41,7 +41,7 @@ class RestoreOp : public OpKernel {
}
void Compute(OpKernelContext* context) override {
RestoreTensor(context, &checkpoint::OpenTableTensorSliceReader,
- preferred_shard_, false);
+ preferred_shard_, false, 0);
}
private:
@@ -67,7 +67,7 @@ class RestoreSliceOp : public OpKernel {
}
void Compute(OpKernelContext* context) override {
RestoreTensor(context, &checkpoint::OpenTableTensorSliceReader,
- preferred_shard_, true);
+ preferred_shard_, true, 0);
}
private:
diff --git a/tensorflow/core/kernels/save_restore_tensor.cc b/tensorflow/core/kernels/save_restore_tensor.cc
index 6b06cf650a..1700bcfca5 100644
--- a/tensorflow/core/kernels/save_restore_tensor.cc
+++ b/tensorflow/core/kernels/save_restore_tensor.cc
@@ -13,11 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/core/kernels/save_restore_tensor.h"
+#include <numeric>
#include <unordered_map>
-
#include <utility>
#include <vector>
-#include "tensorflow/core/kernels/save_restore_tensor.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -85,7 +85,17 @@ void SaveTensors(
Status s;
auto tensor_names_flat = tensor_names_t.flat<string>();
- for (int i = 0; i < N; ++i) {
+ // Process tensors in sorted name order. This allows us to avoid seeking
+ // during restoration in the common case where we are restoring a full
+ // checkpoint.
+ std::vector<size_t> sorted_name_idx(tensor_names_flat.size());
+ std::iota(sorted_name_idx.begin(), sorted_name_idx.end(), 0);
+ std::sort(sorted_name_idx.begin(), sorted_name_idx.end(),
+ [&tensor_names_flat](size_t a, size_t b) {
+ return tensor_names_flat(a) < tensor_names_flat(b);
+ });
+
+ for (size_t i : sorted_name_idx) {
const string& name = tensor_names_flat(i);
const Tensor& input = context->input(i + kFixedInputs);
TensorShape shape(input.shape());
@@ -132,7 +142,7 @@ void SaveTensors(
void RestoreTensor(OpKernelContext* context,
checkpoint::TensorSliceReader::OpenTableFunction open_func,
- int preferred_shard, bool restore_slice) {
+ int preferred_shard, bool restore_slice, int restore_index) {
const Tensor& file_pattern_t = context->input(0);
{
const int64 size = file_pattern_t.NumElements();
@@ -145,26 +155,7 @@ void RestoreTensor(OpKernelContext* context,
const string& file_pattern = file_pattern_t.flat<string>()(0);
const Tensor& tensor_name_t = context->input(1);
- {
- const int64 size = tensor_name_t.NumElements();
- OP_REQUIRES(
- context, size == 1,
- errors::InvalidArgument(
- "Input 1 (tensor_name) must be a string scalar; got a tensor of ",
- size, "elements"));
- }
- const string& tensor_name = tensor_name_t.flat<string>()(0);
-
- const string* tensor_shape_and_slice_ptr = nullptr;
- if (restore_slice) {
- const Tensor& tensor_shape_and_slice_t = context->input(2);
- OP_REQUIRES(
- context, tensor_shape_and_slice_t.NumElements() == 1,
- errors::InvalidArgument("Expected 1 element for the tensor "
- "shape and slice but got ",
- tensor_shape_and_slice_t.NumElements()));
- tensor_shape_and_slice_ptr = tensor_shape_and_slice_t.flat<string>().data();
- }
+ const string& tensor_name = tensor_name_t.flat<string>()(restore_index);
// If we cannot find a cached reader we will allocate our own.
std::unique_ptr<checkpoint::TensorSliceReader> allocated_reader;
@@ -187,7 +178,7 @@ void RestoreTensor(OpKernelContext* context,
errors::NotFound("Tensor name \"", tensor_name,
"\" not found in checkpoint files ", file_pattern));
OP_REQUIRES(
- context, type == context->expected_output_dtype(0),
+ context, type == context->expected_output_dtype(restore_index),
errors::InvalidArgument("Expected to restore a tensor of type ",
DataTypeString(context->expected_output_dtype(0)),
", got a tensor of type ", DataTypeString(type),
@@ -196,23 +187,26 @@ void RestoreTensor(OpKernelContext* context,
// Shape of the output and slice to load.
TensorShape output_shape(saved_shape);
TensorSlice slice_to_load(saved_shape.dims());
- if (restore_slice && !tensor_shape_and_slice_ptr[0].empty()) {
- const string& shape_spec = tensor_shape_and_slice_ptr[0];
- TensorShape parsed_shape;
- OP_REQUIRES_OK(
- context, checkpoint::ParseShapeAndSlice(shape_spec, &parsed_shape,
- &slice_to_load, &output_shape));
- OP_REQUIRES(
- context, parsed_shape.IsSameSize(saved_shape),
- errors::InvalidArgument(
- "Shape in shape_and_slice spec does not match the shape in the "
- "save file: ",
- parsed_shape.DebugString(), ", save file shape: ",
- saved_shape.DebugString()));
+ if (restore_slice) {
+ const string& shape_spec = context->input(2).flat<string>()(restore_index);
+ if (!shape_spec.empty()) {
+ TensorShape parsed_shape;
+ OP_REQUIRES_OK(context, checkpoint::ParseShapeAndSlice(
+ shape_spec, &parsed_shape, &slice_to_load,
+ &output_shape));
+ OP_REQUIRES(
+ context, parsed_shape.IsSameSize(saved_shape),
+ errors::InvalidArgument(
+ "Shape in shape_and_slice spec does not match the shape in the "
+ "save file: ",
+ parsed_shape.DebugString(),
+ ", save file shape: ", saved_shape.DebugString()));
+ }
}
Tensor* t = nullptr;
- OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &t));
+ OP_REQUIRES_OK(context,
+ context->allocate_output(restore_index, output_shape, &t));
if (output_shape.num_elements() == 0) return;
@@ -239,9 +233,18 @@ Status RestoreTensorsV2(OpKernelContext* context, const Tensor& prefix,
const Tensor& shape_and_slices,
gtl::ArraySlice<DataType> dtypes) {
const string& prefix_string = prefix.scalar<string>()();
+
const auto& tensor_names_flat = tensor_names.flat<string>();
const auto& shape_and_slices_flat = shape_and_slices.flat<string>();
+ // Sort lookup keys to improve locality when reading multiple tensors.
+ std::vector<size_t> sorted_name_idx(tensor_names_flat.size());
+ std::iota(sorted_name_idx.begin(), sorted_name_idx.end(), 0);
+ std::sort(sorted_name_idx.begin(), sorted_name_idx.end(),
+ [&tensor_names_flat](size_t a, size_t b) {
+ return tensor_names_flat(a) < tensor_names_flat(b);
+ });
+
BundleReader reader(Env::Default(), prefix_string);
TF_RETURN_IF_ERROR(reader.status());
@@ -250,9 +253,10 @@ Status RestoreTensorsV2(OpKernelContext* context, const Tensor& prefix,
// within a fixed memory budget.
TensorShape restored_full_shape;
Tensor* restored_tensor = nullptr;
- for (size_t i = 0; i < tensor_names_flat.size(); ++i) {
+ for (auto i : sorted_name_idx) {
const string& tensor_name = tensor_names_flat(i);
const string& shape_and_slice = shape_and_slices_flat(i);
+
TF_RETURN_IF_ERROR(
reader.LookupTensorShape(tensor_name, &restored_full_shape));
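The C++ hunks above build the ordering with std::iota and std::sort over an
index vector. The same permutation expressed in Python, purely as an
illustration of the idea and not code from the patch:

# Illustration only: the name-order permutation built by std::iota + std::sort
# in the C++ change, written in plain Python.
def sorted_name_idx(tensor_names):
  # Indices 0..N-1 reordered so tensor_names[i] is visited in ascending order,
  # which turns the checkpoint-index lookups into a sequential scan.
  return sorted(range(len(tensor_names)), key=lambda i: tensor_names[i])

print(sorted_name_idx(["layer2/bias", "layer1/kernel", "global_step"]))
# -> [2, 1, 0]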
diff --git a/tensorflow/core/kernels/save_restore_tensor.h b/tensorflow/core/kernels/save_restore_tensor.h
index 1e87e5c30b..5b74b586e8 100644
--- a/tensorflow/core/kernels/save_restore_tensor.h
+++ b/tensorflow/core/kernels/save_restore_tensor.h
@@ -37,18 +37,21 @@ void SaveTensors(
checkpoint::TensorSliceWriter::CreateBuilderFunction builder_func,
bool save_slices);
-// Reads a tensor from the reader built from open_func() and produces it as
-// context->output(0). "preferred_shard" is the same the TensorSliceReader
-// preferred_shard parameter.
+// Reads a single tensor from the reader built from open_func() and produces
+// it as context->output(restore_index). "preferred_shard" is the same as the
+// TensorSliceReader preferred_shard parameter.
//
// context must have the following inputs:
// 0: a single element string tensor that contains the file name.
-// 1: a single element string tensor that names the output to be restored.
+// 1: string tensor that names the outputs to be restored.
// If restore_slice is true:
-// 2: shape and slice specification of the tensor to restore.
+// 2: shape and slice specification of the tensors to restore.
+//
+// restore_index indicates the variable name and slice to look up
+// in inputs 1 and 2.
void RestoreTensor(OpKernelContext* context,
checkpoint::TensorSliceReader::OpenTableFunction open_func,
- int preferred_shard, bool restore_slice);
+ int preferred_shard, bool restore_slice, int restore_index);
// V2 checkpoint format.
diff --git a/tensorflow/core/kernels/save_restore_v2_ops.cc b/tensorflow/core/kernels/save_restore_v2_ops.cc
index c665bc5b03..3acf290ea2 100644
--- a/tensorflow/core/kernels/save_restore_v2_ops.cc
+++ b/tensorflow/core/kernels/save_restore_v2_ops.cc
@@ -169,8 +169,14 @@ class RestoreV2 : public OpKernel {
paths.empty()) {
// Cannot find V2's metadata file, so "prefix_string" does not point to a
// V2 checkpoint. Invokes the V1 read path instead.
- RestoreTensor(context, &checkpoint::OpenTableTensorSliceReader,
- /* preferred_shard */ -1, /* restore_slice */ true);
+ for (size_t i = 0; i < tensor_names.NumElements(); ++i) {
+ RestoreTensor(context, &checkpoint::OpenTableTensorSliceReader,
+ /* preferred_shard */ -1, /* restore_slice */ true,
+ /* restore_index */ i);
+ if (!context->status().ok()) {
+ return;
+ }
+ }
return;
}
// If found, invokes the V2 reader.
diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc
index d0e54b7e47..eaec6f7c02 100644
--- a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc
+++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc
@@ -299,43 +299,6 @@ Status WriteVariantTensor(const Tensor& val, FileOutputBuffer* out,
return Status::OK();
}
-// Reads file[offset:offset+size) into destination[0:size). Each Read() copies
-// at most "buffer_size" bytes.
-//
-// REQUIRES: "file" contains at least "offset + size" bytes.
-// REQUIRES: "destination" contains at least "size" bytes.
-// On error, "destination" may contain garbage.
-Status ReadInputByChunk(const RandomAccessFile* file, size_t offset,
- size_t size, size_t buffer_size, char* destination) {
- if (size == 0) return Status::OK();
- CHECK_GT(size, 0);
- CHECK_GT(buffer_size, 0);
- size_t bytes_read = 0;
- StringPiece result;
-
- while (bytes_read < size) {
- const size_t desired_bytes = std::min(buffer_size, size - bytes_read);
- Status status = file->Read(offset + bytes_read, desired_bytes, &result,
- destination + bytes_read);
-
- if (!status.ok()) {
- return status;
- } else if (result.size() != desired_bytes) {
- return errors::DataLoss("Requested ", desired_bytes, " bytes but read ",
- result.size(), " bytes.");
- } else if (result.data() == destination + bytes_read) {
- // Data is already in the correct location.
- } else {
- // memmove is guaranteed to handle overlaps safely (although the src and
- // dst buffers should not overlap for this function).
- memmove(destination + bytes_read, result.data(), result.size());
- }
- bytes_read += result.size();
- }
- CHECK_EQ(bytes_read, size);
- return Status::OK();
-}
-
// Returns whether "slice_spec" is a full slice, with respect to the full shape.
//
// This can happen say, when "slice_spec" is
@@ -848,7 +811,7 @@ Status BundleReader::GetValue(const BundleEntryProto& entry, Tensor* val) {
TF_RETURN_IF_ERROR(env_->NewRandomAccessFile(
DataFilename(prefix_, entry.shard_id(), num_shards_), &file));
buffered_file =
- new io::InputBuffer(file.release(), 256 << 10 /* 256KB buffer */);
+ new io::InputBuffer(file.release(), 1024 << 10 /* 1024KB buffer */);
// The InputBuffer and RandomAccessFile objects are both released in dtor.
data_[entry.shard_id()] = buffered_file;
}
@@ -857,13 +820,10 @@ Status BundleReader::GetValue(const BundleEntryProto& entry, Tensor* val) {
TF_RETURN_IF_ERROR(buffered_file->Seek(entry.offset()));
uint32 actual_crc32c = 0;
if (DataTypeCanUseMemcpy(entry.dtype())) {
- // Important: ReadInputByChunk() bounds the readahead as min(buffer, actual
- // bytes needed). This is critical when reading small tensors, so we don't
- // rely on io::InputBuffer's blind buffering here.
char* backing_buffer = const_cast<char*>((ret->tensor_data().data()));
- TF_RETURN_IF_ERROR(ReadInputByChunk(buffered_file->file(), entry.offset(),
- entry.size(), 8 << 20 /* 8MB buffer */,
- backing_buffer));
+ size_t unused_bytes_read;
+ TF_RETURN_IF_ERROR(buffered_file->ReadNBytes(entry.size(), backing_buffer,
+ &unused_bytes_read));
actual_crc32c = crc32c::Value(backing_buffer, entry.size());
} else if (entry.dtype() == DT_VARIANT) {
// Relies on io::InputBuffer's buffering, because we issue many neighboring
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index 2330229d56..2c59b82ebe 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -241,6 +241,34 @@ class BaseSaverBuilder(object):
else:
raise RuntimeError("Unexpected write_version: " + self._write_version)
+ def bulk_restore(self, filename_tensor, saveables, preferred_shard,
+ restore_sequentially):
+ """Restore all tensors contained in saveables.
+
+ By default, this issues separate calls to `restore_op` for each saveable.
+ Subclasses may override to load multiple saveables in a single call.
+
+ Args:
+ filename_tensor: String Tensor.
+ saveables: List of BaseSaverBuilder.SaveableObject objects.
+ preferred_shard: Int. Shard to open first when loading a sharded file.
+ restore_sequentially: Bool. If true, each restore is sequential.
+
+ Returns:
+ A list of Tensors resulting from reading 'saveables' from
+ 'filename_tensor'.
+
+ """
+ all_tensors = []
+ assign_ops = []
+ for saveable in saveables:
+ restore_control_inputs = assign_ops[-1:] if restore_sequentially else []
+ with ops.device(_set_cpu0(saveable.device) if saveable.device else None):
+ with ops.control_dependencies(restore_control_inputs):
+ all_tensors.extend(
+ self.restore_op(filename_tensor, saveable, preferred_shard))
+ return all_tensors
+
# pylint: disable=unused-argument
def restore_op(self, filename_tensor, saveable, preferred_shard):
"""Create ops to restore 'saveable'.
@@ -416,30 +444,32 @@ class BaseSaverBuilder(object):
Returns:
An Operation that restores the variables.
"""
+ all_tensors = self.bulk_restore(filename_tensor, saveables, preferred_shard,
+ restore_sequentially)
+
assign_ops = []
+ idx = 0
+ # Load and optionally reshape on the CPU, as string tensors are not
+ # available on the GPU.
+ # TODO(touts): Re-enable restore on GPU when we can support annotating
+ # string tensors as "HostMemory" inputs.
for saveable in saveables:
- restore_control_inputs = assign_ops[-1:] if restore_sequentially else []
- # Load and optionally reshape on the CPU, as string tensors are not
- # available on the GPU.
- # TODO(touts): Re-enable restore on GPU when we can support annotating
- # string tensors as "HostMemory" inputs.
- with ops.device(_set_cpu0(saveable.device) if saveable.device else None):
- with ops.control_dependencies(restore_control_inputs):
- tensors = self.restore_op(filename_tensor, saveable, preferred_shard)
- shapes = None
- if reshape:
- # Compute the shapes, let the restore op decide if and how to do
- # the reshape.
- shapes = []
- for spec in saveable.specs:
- v = spec.tensor
- shape = v.get_shape()
- if not shape.is_fully_defined():
- shape = array_ops.shape(v)
- shapes.append(shape)
- assign_ops.append(saveable.restore(tensors, shapes))
-
- # Create a Noop that has control dependencies from all the updates.
+ shapes = None
+ if reshape:
+ # Compute the shapes, let the restore op decide if and how to do
+ # the reshape.
+ shapes = []
+ for spec in saveable.specs:
+ v = spec.tensor
+ shape = v.get_shape()
+ if not shape.is_fully_defined():
+ shape = array_ops.shape(v)
+ shapes.append(shape)
+ saveable_tensors = all_tensors[idx:idx + len(saveable.specs)]
+ idx += len(saveable.specs)
+ assign_ops.append(saveable.restore(saveable_tensors, shapes))
+
+ # Create a Noop that has control dependencies from all the updates.
return control_flow_ops.group(*assign_ops, name=name)
def _AddShardedRestoreOps(self, filename_tensor, per_device,
@@ -797,6 +827,25 @@ class BaseSaverBuilder(object):
version=self._write_version)
+class BulkSaverBuilder(BaseSaverBuilder):
+ """SaverBuilder with support for bulk restoring multiple saveables."""
+
+ def bulk_restore(self, filename_tensor, saveables, preferred_shard,
+ restore_sequentially):
+
+ # Ignored: bulk restore is internally sequential.
+ del restore_sequentially
+ restore_specs = []
+ for saveable in saveables:
+ for spec in saveable.specs:
+ restore_specs.append((spec.name, spec.slice_spec, spec.tensor.dtype))
+
+ names, slices, dtypes = zip(*restore_specs)
+ # Load all tensors onto CPU 0 for compatibility with existing code.
+ with ops.device("cpu:0"):
+ return io_ops.restore_v2(filename_tensor, names, slices, dtypes)
+
+
def _get_saver_or_default():
"""Returns the saver from SAVERS collection, or creates a default one.
@@ -1261,6 +1310,7 @@ class Saver(object):
if not self.saver_def or context.in_eager_mode():
if self._builder is None:
self._builder = BaseSaverBuilder(self._write_version)
+
if self._var_list is None:
# pylint: disable=protected-access
self._var_list = variables._all_saveable_objects()
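For completeness, a hedged usage sketch of the new builder. It assumes
BulkSaverBuilder is reachable via tensorflow.python.training.saver at this
revision and that tf.train.Saver's existing builder argument is the way to opt
in; the checkpoint path is hypothetical:

# Usage sketch under the assumptions stated above; TF 1.x graph-mode API.
import tensorflow as tf
from tensorflow.python.training import saver as saver_lib

v = tf.get_variable("v", shape=[100, 100])
w = tf.get_variable("w", shape=[100, 100])

# With BulkSaverBuilder, all saveables are restored through a single
# RestoreV2 op placed on cpu:0, rather than one restore op per variable.
bulk_saver = tf.train.Saver(var_list=[v, w],
                            builder=saver_lib.BulkSaverBuilder())

with tf.Session() as sess:
  bulk_saver.restore(sess, "/tmp/model.ckpt")  # hypothetical checkpoint path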