aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core
diff options
context:
space:
mode:
authorGravatar Zongheng Yang <zongheng@google.com>2016-10-11 10:21:00 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-10-11 11:31:29 -0700
commitacff1637e487ecf9021d82f1a05b2dbb5f9048e6 (patch)
tree24d4274c9c62a14c7977c2e9345188b9e54b318a /tensorflow/core
parentb718fd6ad8cebc470fcc7e53bb6902168edd5587 (diff)
TF Checkpoint V2: factor out the V2 read path into save_restore_tensor.h.
Change: 135818288
Diffstat (limited to 'tensorflow/core')
-rw-r--r--tensorflow/core/kernels/BUILD1
-rw-r--r--tensorflow/core/kernels/save_restore_tensor.cc59
-rw-r--r--tensorflow/core/kernels/save_restore_tensor.h17
-rw-r--r--tensorflow/core/kernels/save_restore_v2_ops.cc52
4 files changed, 79 insertions, 50 deletions
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index ba07cedee0..36b4def0ef 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -292,6 +292,7 @@ cc_library(
":bounds_check",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
+ "//tensorflow/core/util/tensor_bundle",
],
)
diff --git a/tensorflow/core/kernels/save_restore_tensor.cc b/tensorflow/core/kernels/save_restore_tensor.cc
index 9e0b59f125..e9fd108735 100644
--- a/tensorflow/core/kernels/save_restore_tensor.cc
+++ b/tensorflow/core/kernels/save_restore_tensor.cc
@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
#include "tensorflow/core/util/tensor_slice_reader.h"
#include "tensorflow/core/util/tensor_slice_reader_cache.h"
#include "tensorflow/core/util/tensor_slice_writer.h"
@@ -229,4 +230,62 @@ void RestoreTensor(OpKernelContext* context,
#undef READER_COPY
}
+Status RestoreTensorsV2(OpKernelContext* context, const Tensor& prefix,
+ const Tensor& tensor_names,
+ const Tensor& shape_and_slices,
+ gtl::ArraySlice<DataType> dtypes) {
+ const string& prefix_string = prefix.scalar<string>()();
+ const auto& tensor_names_flat = tensor_names.flat<string>();
+ const auto& shape_and_slices_flat = shape_and_slices.flat<string>();
+
+ BundleReader reader(Env::Default(), prefix_string);
+ TF_RETURN_IF_ERROR(reader.status());
+
+ // TODO(zongheng): potential optimization: one Seek() in first lookup.
+ // TODO(zongheng): consider measuring speed and issuing concurrent lookups
+ // within a fixed memory budget.
+ TensorShape restored_full_shape;
+ Tensor* restored_tensor = nullptr;
+ for (size_t i = 0; i < tensor_names_flat.size(); ++i) {
+ const string& tensor_name = tensor_names_flat(i);
+ const string& shape_and_slice = shape_and_slices_flat(i);
+ TF_RETURN_IF_ERROR(
+ reader.LookupTensorShape(tensor_name, &restored_full_shape));
+
+ if (shape_and_slice.empty()) {
+ // Lookup the full tensor.
+ TF_RETURN_IF_ERROR(
+ context->allocate_output(i, restored_full_shape, &restored_tensor));
+ TF_RETURN_IF_ERROR(reader.Lookup(tensor_name, restored_tensor));
+ } else {
+ // Lookup the slice.
+ TensorShape parsed_full_shape;
+ TensorSlice parsed_slice;
+ TensorShape parsed_slice_shape;
+
+ TF_RETURN_IF_ERROR(
+ checkpoint::ParseShapeAndSlice(shape_and_slice, &parsed_full_shape,
+ &parsed_slice, &parsed_slice_shape));
+ if (!restored_full_shape.IsSameSize(parsed_full_shape)) {
+ return errors::InvalidArgument(
+ "Shape in shape_and_slice spec ", parsed_full_shape.DebugString(),
+ " does not match the shape stored in checkpoint: ",
+ restored_full_shape.DebugString());
+ }
+
+ TF_RETURN_IF_ERROR(
+ context->allocate_output(i, parsed_slice_shape, &restored_tensor));
+ TF_RETURN_IF_ERROR(
+ reader.LookupSlice(tensor_name, parsed_slice, restored_tensor));
+ }
+ if (dtypes[i] != restored_tensor->dtype()) {
+ return errors::InvalidArgument("Expected dtype ",
+ DataTypeString(dtypes[i]),
+ " does not equal restored dtype ",
+ DataTypeString(restored_tensor->dtype()));
+ }
+ }
+ return Status::OK();
+}
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/save_restore_tensor.h b/tensorflow/core/kernels/save_restore_tensor.h
index 34cd78c973..1e87e5c30b 100644
--- a/tensorflow/core/kernels/save_restore_tensor.h
+++ b/tensorflow/core/kernels/save_restore_tensor.h
@@ -23,6 +23,8 @@ namespace tensorflow {
class OpKernelContext;
+// Legacy / V1 checkpoint format.
+
// Save input tensors in *context to a writer built from builder_func().
// context must have the following inputs:
// 0: a single element string tensor that contains the file name.
@@ -48,6 +50,21 @@ void RestoreTensor(OpKernelContext* context,
checkpoint::TensorSliceReader::OpenTableFunction open_func,
int preferred_shard, bool restore_slice);
+// V2 checkpoint format.
+
+// Invokes the V2 checkpoint read path to read tensors.
+//
+// "context" is only used for allocating outputs. In particular, the inputs are
+// explicitly provided and not accessed via the "input(i)" methods.
+// REQUIRES:
+// * "prefix" has 1 element, DT_STRING.
+// * "tensor_names" and "shape_and_slices" shaped {N}, both DT_STRING.
+// * "dtypes" has N elements, the datatypes of the to-restore tensors.
+Status RestoreTensorsV2(OpKernelContext* context, const Tensor& prefix,
+ const Tensor& tensor_names,
+ const Tensor& shape_and_slices,
+ gtl::ArraySlice<DataType> dtypes);
+
} // namespace tensorflow
#endif // TENSORFLOW_KERNELS_SAVE_RESTORE_TENSOR_H_
diff --git a/tensorflow/core/kernels/save_restore_v2_ops.cc b/tensorflow/core/kernels/save_restore_v2_ops.cc
index 200f8a689d..4618536b64 100644
--- a/tensorflow/core/kernels/save_restore_v2_ops.cc
+++ b/tensorflow/core/kernels/save_restore_v2_ops.cc
@@ -155,8 +155,6 @@ class RestoreV2 : public OpKernel {
shape_and_slices);
const string& prefix_string = prefix.scalar<string>()();
- const auto& tensor_names_flat = tensor_names.flat<string>();
- const auto& shape_and_slices_flat = shape_and_slices.flat<string>();
// Intention: we plan to use the RestoreV2 op as a backward-compatible
// reader as we upgrade to the V2 format. This allows transparent upgrade.
@@ -172,55 +170,9 @@ class RestoreV2 : public OpKernel {
/* preferred_shard */ -1, /* restore_slice */ true);
return;
}
-
// If found, invokes the V2 reader.
- BundleReader reader(env, prefix_string);
- OP_REQUIRES_OK(context, reader.status());
- VLOG(1) << "BundleReader, prefix: " << prefix_string;
-
- // TODO(zongheng): potential optimization: one Seek() in first lookup.
- // TODO(zongheng): consider measuring speed and issuing concurrent lookups
- // within a fixed memory budget.
- TensorShape restored_full_shape;
- Tensor* restored_tensor = nullptr;
- for (size_t i = 0; i < tensor_names_flat.size(); ++i) {
- const string& tensor_name = tensor_names_flat(i);
- const string& shape_and_slice = shape_and_slices_flat(i);
- OP_REQUIRES_OK(
- context, reader.LookupTensorShape(tensor_name, &restored_full_shape));
-
- if (shape_and_slice.empty()) {
- // Lookup the full tensor.
- OP_REQUIRES_OK(context, context->allocate_output(i, restored_full_shape,
- &restored_tensor));
- OP_REQUIRES_OK(context, reader.Lookup(tensor_name, restored_tensor));
- } else {
- // Lookup the slice.
- TensorShape parsed_full_shape;
- TensorSlice parsed_slice;
- TensorShape parsed_slice_shape;
-
- OP_REQUIRES_OK(context, checkpoint::ParseShapeAndSlice(
- shape_and_slice, &parsed_full_shape,
- &parsed_slice, &parsed_slice_shape));
- OP_REQUIRES(context, restored_full_shape.IsSameSize(parsed_full_shape),
- errors::InvalidArgument(
- "Shape in shape_and_slice spec ",
- parsed_full_shape.DebugString(),
- " does not match the shape stored in checkpoint: ",
- restored_full_shape.DebugString()));
-
- OP_REQUIRES_OK(context, context->allocate_output(i, parsed_slice_shape,
- &restored_tensor));
- OP_REQUIRES_OK(context, reader.LookupSlice(tensor_name, parsed_slice,
- restored_tensor));
- }
- OP_REQUIRES(
- context, dtypes_[i] == restored_tensor->dtype(),
- errors::InvalidArgument("Expected dtype ", DataTypeString(dtypes_[i]),
- " does not equal restored dtype ",
- DataTypeString(restored_tensor->dtype())));
- }
+ OP_REQUIRES_OK(context, RestoreTensorsV2(context, prefix, tensor_names,
+ shape_and_slices, dtypes_));
}
private: