diff options
-rw-r--r-- | tensorflow/core/kernels/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/core/kernels/save_restore_tensor.cc | 59 | ||||
-rw-r--r-- | tensorflow/core/kernels/save_restore_tensor.h | 17 | ||||
-rw-r--r-- | tensorflow/core/kernels/save_restore_v2_ops.cc | 52 | ||||
-rw-r--r-- | tensorflow/python/training/saver.py | 7 |
5 files changed, 85 insertions, 51 deletions
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index ba07cedee0..36b4def0ef 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -292,6 +292,7 @@ cc_library( ":bounds_check", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core/util/tensor_bundle", ], ) diff --git a/tensorflow/core/kernels/save_restore_tensor.cc b/tensorflow/core/kernels/save_restore_tensor.cc index 9e0b59f125..e9fd108735 100644 --- a/tensorflow/core/kernels/save_restore_tensor.cc +++ b/tensorflow/core/kernels/save_restore_tensor.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h" #include "tensorflow/core/util/tensor_slice_reader.h" #include "tensorflow/core/util/tensor_slice_reader_cache.h" #include "tensorflow/core/util/tensor_slice_writer.h" @@ -229,4 +230,62 @@ void RestoreTensor(OpKernelContext* context, #undef READER_COPY } +Status RestoreTensorsV2(OpKernelContext* context, const Tensor& prefix, + const Tensor& tensor_names, + const Tensor& shape_and_slices, + gtl::ArraySlice<DataType> dtypes) { + const string& prefix_string = prefix.scalar<string>()(); + const auto& tensor_names_flat = tensor_names.flat<string>(); + const auto& shape_and_slices_flat = shape_and_slices.flat<string>(); + + BundleReader reader(Env::Default(), prefix_string); + TF_RETURN_IF_ERROR(reader.status()); + + // TODO(zongheng): potential optimization: one Seek() in first lookup. + // TODO(zongheng): consider measuring speed and issuing concurrent lookups + // within a fixed memory budget. 
+ TensorShape restored_full_shape; + Tensor* restored_tensor = nullptr; + for (size_t i = 0; i < tensor_names_flat.size(); ++i) { + const string& tensor_name = tensor_names_flat(i); + const string& shape_and_slice = shape_and_slices_flat(i); + TF_RETURN_IF_ERROR( + reader.LookupTensorShape(tensor_name, &restored_full_shape)); + + if (shape_and_slice.empty()) { + // Lookup the full tensor. + TF_RETURN_IF_ERROR( + context->allocate_output(i, restored_full_shape, &restored_tensor)); + TF_RETURN_IF_ERROR(reader.Lookup(tensor_name, restored_tensor)); + } else { + // Lookup the slice. + TensorShape parsed_full_shape; + TensorSlice parsed_slice; + TensorShape parsed_slice_shape; + + TF_RETURN_IF_ERROR( + checkpoint::ParseShapeAndSlice(shape_and_slice, &parsed_full_shape, + &parsed_slice, &parsed_slice_shape)); + if (!restored_full_shape.IsSameSize(parsed_full_shape)) { + return errors::InvalidArgument( + "Shape in shape_and_slice spec ", parsed_full_shape.DebugString(), + " does not match the shape stored in checkpoint: ", + restored_full_shape.DebugString()); + } + + TF_RETURN_IF_ERROR( + context->allocate_output(i, parsed_slice_shape, &restored_tensor)); + TF_RETURN_IF_ERROR( + reader.LookupSlice(tensor_name, parsed_slice, restored_tensor)); + } + if (dtypes[i] != restored_tensor->dtype()) { + return errors::InvalidArgument("Expected dtype ", + DataTypeString(dtypes[i]), + " does not equal restored dtype ", + DataTypeString(restored_tensor->dtype())); + } + } + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/core/kernels/save_restore_tensor.h b/tensorflow/core/kernels/save_restore_tensor.h index 34cd78c973..1e87e5c30b 100644 --- a/tensorflow/core/kernels/save_restore_tensor.h +++ b/tensorflow/core/kernels/save_restore_tensor.h @@ -23,6 +23,8 @@ namespace tensorflow { class OpKernelContext; +// Legacy / V1 checkpoint format. + // Save input tensors in *context to a writer built from builder_func(). 
// context must have the following inputs: // 0: a single element string tensor that contains the file name. @@ -48,6 +50,21 @@ void RestoreTensor(OpKernelContext* context, checkpoint::TensorSliceReader::OpenTableFunction open_func, int preferred_shard, bool restore_slice); +// V2 checkpoint format. + +// Invokes the V2 checkpoint read path to read tensors. +// +// "context" is only used for allocating outputs. In particular, the inputs are +// explicitly provided and not accessed via the "input(i)" methods. +// REQUIRES: +// * "prefix" has 1 element, DT_STRING. +// * "tensor_names" and "shape_and_slices" shaped {N}, both DT_STRING. +// * "dtypes" has N elements, the datatypes of the to-restore tensors. +Status RestoreTensorsV2(OpKernelContext* context, const Tensor& prefix, + const Tensor& tensor_names, + const Tensor& shape_and_slices, + gtl::ArraySlice<DataType> dtypes); + } // namespace tensorflow #endif // TENSORFLOW_KERNELS_SAVE_RESTORE_TENSOR_H_ diff --git a/tensorflow/core/kernels/save_restore_v2_ops.cc b/tensorflow/core/kernels/save_restore_v2_ops.cc index 200f8a689d..4618536b64 100644 --- a/tensorflow/core/kernels/save_restore_v2_ops.cc +++ b/tensorflow/core/kernels/save_restore_v2_ops.cc @@ -155,8 +155,6 @@ class RestoreV2 : public OpKernel { shape_and_slices); const string& prefix_string = prefix.scalar<string>()(); - const auto& tensor_names_flat = tensor_names.flat<string>(); - const auto& shape_and_slices_flat = shape_and_slices.flat<string>(); // Intention: we plan to use the RestoreV2 op as a backward-compatible // reader as we upgrade to the V2 format. This allows transparent upgrade. @@ -172,55 +170,9 @@ class RestoreV2 : public OpKernel { /* preferred_shard */ -1, /* restore_slice */ true); return; } - // If found, invokes the V2 reader. 
- BundleReader reader(env, prefix_string); - OP_REQUIRES_OK(context, reader.status()); - VLOG(1) << "BundleReader, prefix: " << prefix_string; - - // TODO(zongheng): potential optimization: one Seek() in first lookup. - // TODO(zongheng): consider measuring speed and issuing concurrent lookups - // within a fixed memory budget. - TensorShape restored_full_shape; - Tensor* restored_tensor = nullptr; - for (size_t i = 0; i < tensor_names_flat.size(); ++i) { - const string& tensor_name = tensor_names_flat(i); - const string& shape_and_slice = shape_and_slices_flat(i); - OP_REQUIRES_OK( - context, reader.LookupTensorShape(tensor_name, &restored_full_shape)); - - if (shape_and_slice.empty()) { - // Lookup the full tensor. - OP_REQUIRES_OK(context, context->allocate_output(i, restored_full_shape, - &restored_tensor)); - OP_REQUIRES_OK(context, reader.Lookup(tensor_name, restored_tensor)); - } else { - // Lookup the slice. - TensorShape parsed_full_shape; - TensorSlice parsed_slice; - TensorShape parsed_slice_shape; - - OP_REQUIRES_OK(context, checkpoint::ParseShapeAndSlice( - shape_and_slice, &parsed_full_shape, - &parsed_slice, &parsed_slice_shape)); - OP_REQUIRES(context, restored_full_shape.IsSameSize(parsed_full_shape), - errors::InvalidArgument( - "Shape in shape_and_slice spec ", - parsed_full_shape.DebugString(), - " does not match the shape stored in checkpoint: ", - restored_full_shape.DebugString())); - - OP_REQUIRES_OK(context, context->allocate_output(i, parsed_slice_shape, - &restored_tensor)); - OP_REQUIRES_OK(context, reader.LookupSlice(tensor_name, parsed_slice, - restored_tensor)); - } - OP_REQUIRES( - context, dtypes_[i] == restored_tensor->dtype(), - errors::InvalidArgument("Expected dtype ", DataTypeString(dtypes_[i]), - " does not equal restored dtype ", - DataTypeString(restored_tensor->dtype()))); - } + OP_REQUIRES_OK(context, RestoreTensorsV2(context, prefix, tensor_names, + shape_and_slices, dtypes_)); } private: diff --git 
a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py index 6cf6598f34..a7fee96681 100644 --- a/tensorflow/python/training/saver.py +++ b/tensorflow/python/training/saver.py @@ -1348,11 +1348,16 @@ class Saver(object): save_path: Path where parameters were previously saved. Raises: - ValueError: If the given `save_path` does not point to a file. + ValueError: DEPRECATED, do not rely on this Error. If the given + `save_path` does not point to a file. """ if self._is_empty: return + # NOTE(zongheng): checking at the Python layer prevents the underlying + # restore Op from handling more than one checkpoint format, potentially. This + # is a DEPRECATED error and might be removed in the future. + # # Performs this check only for V1, as the V2 restore op can read either a # V1 ckpt or a V2 ckpt, making this check invalid. if self.saver_def.version == saver_pb2.SaverDef.V1: |