diff options
author | Russell Power <power@google.com> | 2018-01-04 12:02:14 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-01-04 12:06:05 -0800 |
commit | 782519a152c81873878e30c7791ccff5f6f534d1 (patch) | |
tree | 2bf7e0d4a62dc59bca547f2256b1f0dde839c5b0 | |
parent | dd0996f48fc7c580809c80c652a4bf726d3b2f3c (diff) |
Expand all saveable operations to generate a single C++ restore call.
This allows us to avoid repeated index lookups and perform a sequential scan of the index in the common case where we are doing a full restore, or a restore from a sub-model. It also dramatically reduces excessive restore parallelism.
Testing with a checkpoint with 1000 100x100 tensors, restoring from CNS drops from ~1m to ~5 seconds.
PiperOrigin-RevId: 180827583
-rw-r--r-- | tensorflow/core/kernels/restore_op.cc | 4 | ||||
-rw-r--r-- | tensorflow/core/kernels/save_restore_tensor.cc | 84 | ||||
-rw-r--r-- | tensorflow/core/kernels/save_restore_tensor.h | 15 | ||||
-rw-r--r-- | tensorflow/core/kernels/save_restore_v2_ops.cc | 10 | ||||
-rw-r--r-- | tensorflow/core/util/tensor_bundle/tensor_bundle.cc | 48 | ||||
-rw-r--r-- | tensorflow/python/training/saver.py | 94 |
6 files changed, 139 insertions, 116 deletions
diff --git a/tensorflow/core/kernels/restore_op.cc b/tensorflow/core/kernels/restore_op.cc index 0593a07b80..d9bbcb14ab 100644 --- a/tensorflow/core/kernels/restore_op.cc +++ b/tensorflow/core/kernels/restore_op.cc @@ -41,7 +41,7 @@ class RestoreOp : public OpKernel { } void Compute(OpKernelContext* context) override { RestoreTensor(context, &checkpoint::OpenTableTensorSliceReader, - preferred_shard_, false); + preferred_shard_, false, 0); } private: @@ -67,7 +67,7 @@ class RestoreSliceOp : public OpKernel { } void Compute(OpKernelContext* context) override { RestoreTensor(context, &checkpoint::OpenTableTensorSliceReader, - preferred_shard_, true); + preferred_shard_, true, 0); } private: diff --git a/tensorflow/core/kernels/save_restore_tensor.cc b/tensorflow/core/kernels/save_restore_tensor.cc index 6b06cf650a..1700bcfca5 100644 --- a/tensorflow/core/kernels/save_restore_tensor.cc +++ b/tensorflow/core/kernels/save_restore_tensor.cc @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/kernels/save_restore_tensor.h" +#include <numeric> #include <unordered_map> - #include <utility> #include <vector> -#include "tensorflow/core/kernels/save_restore_tensor.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -85,7 +85,17 @@ void SaveTensors( Status s; auto tensor_names_flat = tensor_names_t.flat<string>(); - for (int i = 0; i < N; ++i) { + // Process tensors in sorted name order. This allows us to avoid seeking + // during restoration in the common case where we are restoring a full + // checkpoint. + std::vector<size_t> sorted_name_idx(tensor_names_flat.size()); + std::iota(sorted_name_idx.begin(), sorted_name_idx.end(), 0); + std::sort(sorted_name_idx.begin(), sorted_name_idx.end(), + [&tensor_names_flat](size_t a, size_t b) { + return tensor_names_flat(a) < tensor_names_flat(b); + }); + + for (size_t i : sorted_name_idx) { const string& name = tensor_names_flat(i); const Tensor& input = context->input(i + kFixedInputs); TensorShape shape(input.shape()); @@ -132,7 +142,7 @@ void SaveTensors( void RestoreTensor(OpKernelContext* context, checkpoint::TensorSliceReader::OpenTableFunction open_func, - int preferred_shard, bool restore_slice) { + int preferred_shard, bool restore_slice, int restore_index) { const Tensor& file_pattern_t = context->input(0); { const int64 size = file_pattern_t.NumElements(); @@ -145,26 +155,7 @@ void RestoreTensor(OpKernelContext* context, const string& file_pattern = file_pattern_t.flat<string>()(0); const Tensor& tensor_name_t = context->input(1); - { - const int64 size = tensor_name_t.NumElements(); - OP_REQUIRES( - context, size == 1, - errors::InvalidArgument( - "Input 1 (tensor_name) must be a string scalar; got a tensor of ", - size, "elements")); - } - const string& tensor_name = tensor_name_t.flat<string>()(0); - - const string* tensor_shape_and_slice_ptr = nullptr; - if (restore_slice) { - const Tensor& tensor_shape_and_slice_t = context->input(2); - OP_REQUIRES( - context, tensor_shape_and_slice_t.NumElements() == 1, - errors::InvalidArgument("Expected 1 element for the tensor " - "shape and slice but got ", - tensor_shape_and_slice_t.NumElements())); - tensor_shape_and_slice_ptr = tensor_shape_and_slice_t.flat<string>().data(); - } + const string& tensor_name = tensor_name_t.flat<string>()(restore_index); // If we cannot find a cached reader we will allocate our own. std::unique_ptr<checkpoint::TensorSliceReader> allocated_reader; @@ -187,7 +178,7 @@ void RestoreTensor(OpKernelContext* context, errors::NotFound("Tensor name \"", tensor_name, "\" not found in checkpoint files ", file_pattern)); OP_REQUIRES( - context, type == context->expected_output_dtype(0), + context, type == context->expected_output_dtype(restore_index), errors::InvalidArgument("Expected to restore a tensor of type ", DataTypeString(context->expected_output_dtype(0)), ", got a tensor of type ", DataTypeString(type), @@ -196,23 +187,26 @@ void RestoreTensor(OpKernelContext* context, // Shape of the output and slice to load. TensorShape output_shape(saved_shape); TensorSlice slice_to_load(saved_shape.dims()); - if (restore_slice && !tensor_shape_and_slice_ptr[0].empty()) { - const string& shape_spec = tensor_shape_and_slice_ptr[0]; - TensorShape parsed_shape; - OP_REQUIRES_OK( - context, checkpoint::ParseShapeAndSlice(shape_spec, &parsed_shape, - &slice_to_load, &output_shape)); - OP_REQUIRES( - context, parsed_shape.IsSameSize(saved_shape), - errors::InvalidArgument( - "Shape in shape_and_slice spec does not match the shape in the " - "save file: ", - parsed_shape.DebugString(), ", save file shape: ", - saved_shape.DebugString())); + if (restore_slice) { + const string& shape_spec = context->input(2).flat<string>()(restore_index); + if (!shape_spec.empty()) { + TensorShape parsed_shape; + OP_REQUIRES_OK(context, checkpoint::ParseShapeAndSlice( + shape_spec, &parsed_shape, &slice_to_load, + &output_shape)); + OP_REQUIRES( + context, parsed_shape.IsSameSize(saved_shape), + errors::InvalidArgument( + "Shape in shape_and_slice spec does not match the shape in the " + "save file: ", + parsed_shape.DebugString(), + ", save file shape: ", saved_shape.DebugString())); + } } Tensor* t = nullptr; - OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &t)); + OP_REQUIRES_OK(context, + context->allocate_output(restore_index, output_shape, &t)); if (output_shape.num_elements() == 0) return; @@ -239,9 +233,18 @@ Status RestoreTensorsV2(OpKernelContext* context, const Tensor& prefix, const Tensor& shape_and_slices, gtl::ArraySlice<DataType> dtypes) { const string& prefix_string = prefix.scalar<string>()(); + const auto& tensor_names_flat = tensor_names.flat<string>(); const auto& shape_and_slices_flat = shape_and_slices.flat<string>(); + // Sort lookup keys to improve locality when reading multiple tensors. + std::vector<size_t> sorted_name_idx(tensor_names_flat.size()); + std::iota(sorted_name_idx.begin(), sorted_name_idx.end(), 0); + std::sort(sorted_name_idx.begin(), sorted_name_idx.end(), + [&tensor_names_flat](size_t a, size_t b) { + return tensor_names_flat(a) < tensor_names_flat(b); + }); + BundleReader reader(Env::Default(), prefix_string); TF_RETURN_IF_ERROR(reader.status()); @@ -250,9 +253,10 @@ Status RestoreTensorsV2(OpKernelContext* context, const Tensor& prefix, // within a fixed memory budget. TensorShape restored_full_shape; Tensor* restored_tensor = nullptr; - for (size_t i = 0; i < tensor_names_flat.size(); ++i) { + for (auto i : sorted_name_idx) { const string& tensor_name = tensor_names_flat(i); const string& shape_and_slice = shape_and_slices_flat(i); + TF_RETURN_IF_ERROR( reader.LookupTensorShape(tensor_name, &restored_full_shape)); diff --git a/tensorflow/core/kernels/save_restore_tensor.h b/tensorflow/core/kernels/save_restore_tensor.h index 1e87e5c30b..5b74b586e8 100644 --- a/tensorflow/core/kernels/save_restore_tensor.h +++ b/tensorflow/core/kernels/save_restore_tensor.h @@ -37,18 +37,21 @@ void SaveTensors( checkpoint::TensorSliceWriter::CreateBuilderFunction builder_func, bool save_slices); -// Reads a tensor from the reader built from open_func() and produces it as -// context->output(0). "preferred_shard" is the same the TensorSliceReader -// preferred_shard parameter. +// Reads a single tensor from the reader built from open_func() and produces +// it as context->output(restore_index). "preferred_shard" is the same the +// TensorSliceReader preferred_shard parameter. // // context must have the following inputs: // 0: a single element string tensor that contains the file name. -// 1: a single element string tensor that names the output to be restored. +// 1: string tensor that names the outputs to be restored. // If restore_slice is true: -// 2: shape and slice specification of the tensor to restore. +// 2: shape and slice specification of the tensors to restore. +// +// restore_index indicates the variable name and slice to lookup +// in context(1) and (2). void RestoreTensor(OpKernelContext* context, checkpoint::TensorSliceReader::OpenTableFunction open_func, - int preferred_shard, bool restore_slice); + int preferred_shard, bool restore_slice, int restore_index); // V2 checkpoint format. diff --git a/tensorflow/core/kernels/save_restore_v2_ops.cc b/tensorflow/core/kernels/save_restore_v2_ops.cc index c665bc5b03..3acf290ea2 100644 --- a/tensorflow/core/kernels/save_restore_v2_ops.cc +++ b/tensorflow/core/kernels/save_restore_v2_ops.cc @@ -169,8 +169,14 @@ class RestoreV2 : public OpKernel { paths.empty()) { // Cannot find V2's metadata file, so "prefix_string" does not point to a // V2 checkpoint. Invokes the V1 read path instead. - RestoreTensor(context, &checkpoint::OpenTableTensorSliceReader, - /* preferred_shard */ -1, /* restore_slice */ true); + for (size_t i = 0; i < tensor_names.NumElements(); ++i) { + RestoreTensor(context, &checkpoint::OpenTableTensorSliceReader, + /* preferred_shard */ -1, /* restore_slice */ true, + /* restore_index */ i); + if (!context->status().ok()) { + return; + } + } return; } // If found, invokes the V2 reader. diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc index d0e54b7e47..eaec6f7c02 100644 --- a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc +++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc @@ -299,43 +299,6 @@ Status WriteVariantTensor(const Tensor& val, FileOutputBuffer* out, return Status::OK(); } -// Reads file[offset:offset+size) into destination[0:size). Each Read() copies -// at most "buffer_size" bytes. -// -// REQUIRES: "file" contains at least "offset + size" bytes. -// REQUIRES: "destination" contains at least "size" bytes. -// On error, "destination" may contain garbage. -Status ReadInputByChunk(const RandomAccessFile* file, size_t offset, - size_t size, size_t buffer_size, char* destination) { - if (size == 0) return Status::OK(); - CHECK_GT(size, 0); - CHECK_GT(buffer_size, 0); - size_t bytes_read = 0; - StringPiece result; - - while (bytes_read < size) { - const size_t desired_bytes = std::min(buffer_size, size - bytes_read); - Status status = file->Read(offset + bytes_read, desired_bytes, &result, - destination + bytes_read); - - if (!status.ok()) { - return status; - } else if (result.size() != desired_bytes) { - return errors::DataLoss("Requested ", desired_bytes, " bytes but read ", - result.size(), " bytes."); - } else if (result.data() == destination + bytes_read) { - // Data is already in the correct location. - } else { - // memmove is guaranteed to handle overlaps safely (although the src and - // dst buffers should not overlap for this function). - memmove(destination + bytes_read, result.data(), result.size()); - } - bytes_read += result.size(); - } - CHECK_EQ(bytes_read, size); - return Status::OK(); -} - // Returns whether "slice_spec" is a full slice, with respect to the full shape. // // This can happen say, when "slice_spec" is @@ -848,7 +811,7 @@ Status BundleReader::GetValue(const BundleEntryProto& entry, Tensor* val) { TF_RETURN_IF_ERROR(env_->NewRandomAccessFile( DataFilename(prefix_, entry.shard_id(), num_shards_), &file)); buffered_file = - new io::InputBuffer(file.release(), 256 << 10 /* 256KB buffer */); + new io::InputBuffer(file.release(), 1024 << 10 /* 1024KB buffer */); // The InputBuffer and RandomAccessFile objects are both released in dtor. data_[entry.shard_id()] = buffered_file; } @@ -857,13 +820,10 @@ Status BundleReader::GetValue(const BundleEntryProto& entry, Tensor* val) { TF_RETURN_IF_ERROR(buffered_file->Seek(entry.offset())); uint32 actual_crc32c = 0; if (DataTypeCanUseMemcpy(entry.dtype())) { - // Important: ReadInputByChunk() bounds the readahead as min(buffer, actual - // bytes needed). This is critical when reading small tensors, so we don't - // rely on io::InputBuffer's blind buffering here. char* backing_buffer = const_cast<char*>((ret->tensor_data().data())); - TF_RETURN_IF_ERROR(ReadInputByChunk(buffered_file->file(), entry.offset(), - entry.size(), 8 << 20 /* 8MB buffer */, - backing_buffer)); + size_t unused_bytes_read; + TF_RETURN_IF_ERROR(buffered_file->ReadNBytes(entry.size(), backing_buffer, + &unused_bytes_read)); actual_crc32c = crc32c::Value(backing_buffer, entry.size()); } else if (entry.dtype() == DT_VARIANT) { // Relies on io::InputBuffer's buffering, because we issue many neighboring diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py index 2330229d56..2c59b82ebe 100644 --- a/tensorflow/python/training/saver.py +++ b/tensorflow/python/training/saver.py @@ -241,6 +241,34 @@ class BaseSaverBuilder(object): else: raise RuntimeError("Unexpected write_version: " + self._write_version) + def bulk_restore(self, filename_tensor, saveables, preferred_shard, + restore_sequentially): + """Restore all tensors contained in saveables. + + By default, this issues separate calls to `restore_op` for each saveable. + Subclasses may override to load multiple saveables in a single call. + + Args: + filename_tensor: String Tensor. + saveables: List of BaseSaverBuilder.SaveableObject objects. + preferred_shard: Int. Shard to open first when loading a sharded file. + restore_sequentially: Bool. If true, each restore is sequential. + + Returns: + A list of Tensors resulting from reading 'saveable' from + 'filename'. + + """ + all_tensors = [] + assign_ops = [] + for saveable in saveables: + restore_control_inputs = assign_ops[-1:] if restore_sequentially else [] + with ops.device(_set_cpu0(saveable.device) if saveable.device else None): + with ops.control_dependencies(restore_control_inputs): + all_tensors.extend( + self.restore_op(filename_tensor, saveable, preferred_shard)) + return all_tensors + # pylint: disable=unused-argument def restore_op(self, filename_tensor, saveable, preferred_shard): """Create ops to restore 'saveable'. @@ -416,30 +444,32 @@ class BaseSaverBuilder(object): Returns: An Operation that restores the variables. """ + all_tensors = self.bulk_restore(filename_tensor, saveables, preferred_shard, + restore_sequentially) + assign_ops = [] + idx = 0 + # Load and optionally reshape on the CPU, as string tensors are not + # available on the GPU. + # TODO(touts): Re-enable restore on GPU when we can support annotating + # string tensors as "HostMemory" inputs. for saveable in saveables: - restore_control_inputs = assign_ops[-1:] if restore_sequentially else [] - # Load and optionally reshape on the CPU, as string tensors are not - # available on the GPU. - # TODO(touts): Re-enable restore on GPU when we can support annotating - # string tensors as "HostMemory" inputs. - with ops.device(_set_cpu0(saveable.device) if saveable.device else None): - with ops.control_dependencies(restore_control_inputs): - tensors = self.restore_op(filename_tensor, saveable, preferred_shard) - shapes = None - if reshape: - # Compute the shapes, let the restore op decide if and how to do - # the reshape. - shapes = [] - for spec in saveable.specs: - v = spec.tensor - shape = v.get_shape() - if not shape.is_fully_defined(): - shape = array_ops.shape(v) - shapes.append(shape) - assign_ops.append(saveable.restore(tensors, shapes)) - - # Create a Noop that has control dependencies from all the updates. + shapes = None + if reshape: + # Compute the shapes, let the restore op decide if and how to do + # the reshape. + shapes = [] + for spec in saveable.specs: + v = spec.tensor + shape = v.get_shape() + if not shape.is_fully_defined(): + shape = array_ops.shape(v) + shapes.append(shape) + saveable_tensors = all_tensors[idx:idx + len(saveable.specs)] + idx += len(saveable.specs) + assign_ops.append(saveable.restore(saveable_tensors, shapes)) + + # Create a Noop that has control dependencies from all the updates. return control_flow_ops.group(*assign_ops, name=name) def _AddShardedRestoreOps(self, filename_tensor, per_device, @@ -797,6 +827,25 @@ class BaseSaverBuilder(object): version=self._write_version) +class BulkSaverBuilder(BaseSaverBuilder): + """SaverBuilder with support for bulk restoring multiple saveables.""" + + def bulk_restore(self, filename_tensor, saveables, preferred_shard, + restore_sequentially): + + # Ignored: bulk restore is internally sequential. + del restore_sequentially + restore_specs = [] + for saveable in saveables: + for spec in saveable.specs: + restore_specs.append((spec.name, spec.slice_spec, spec.tensor.dtype)) + + names, slices, dtypes = zip(*restore_specs) + # Load all tensors onto CPU 0 for compatibility with existing code. + with ops.device("cpu:0"): + return io_ops.restore_v2(filename_tensor, names, slices, dtypes) + + def _get_saver_or_default(): """Returns the saver from SAVERS collection, or creates a default one. @@ -1261,6 +1310,7 @@ class Saver(object): if not self.saver_def or context.in_eager_mode(): if self._builder is None: self._builder = BaseSaverBuilder(self._write_version) + if self._var_list is None: # pylint: disable=protected-access self._var_list = variables._all_saveable_objects() |