aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-09-02 14:35:12 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-02 14:39:02 -0700
commitddba1e0aadabe26063a28c5d1c48e2cfce44e30f (patch)
tree8d6682e2d3a860c6c24ddfc99c9cbfbed93830fd
parent7d5cbd78a54319eeb45bca2e239ec037997dad20 (diff)
Replace CHECKs in v1 checkpoint loading codepath with returning errors.
PiperOrigin-RevId: 167392822
-rw-r--r--tensorflow/core/kernels/save_restore_tensor.cc9
-rw-r--r--tensorflow/core/util/tensor_slice_reader.h17
2 files changed, 17 insertions, 9 deletions
diff --git a/tensorflow/core/kernels/save_restore_tensor.cc b/tensorflow/core/kernels/save_restore_tensor.cc
index 80d4901740..6b06cf650a 100644
--- a/tensorflow/core/kernels/save_restore_tensor.cc
+++ b/tensorflow/core/kernels/save_restore_tensor.cc
@@ -216,9 +216,12 @@ void RestoreTensor(OpKernelContext* context,
if (output_shape.num_elements() == 0) return;
-#define READER_COPY(T) \
- case DataTypeToEnum<T>::value: \
- reader->CopySliceData(tensor_name, slice_to_load, t->flat<T>().data()); \
+#define READER_COPY(T) \
+ case DataTypeToEnum<T>::value: \
+ OP_REQUIRES(context, \
+ reader->CopySliceData(tensor_name, slice_to_load, \
+ t->flat<T>().data()), \
+ errors::InvalidArgument("Error copying slice data")); \
break;
switch (type) {
diff --git a/tensorflow/core/util/tensor_slice_reader.h b/tensorflow/core/util/tensor_slice_reader.h
index eeb3129573..5932d59a15 100644
--- a/tensorflow/core/util/tensor_slice_reader.h
+++ b/tensorflow/core/util/tensor_slice_reader.h
@@ -165,13 +165,18 @@ bool TensorSliceReader::CopySliceData(const string& name,
CHECK_GE(idx, 0) << "Failed to find the index for filename " << fname;
// We read a record in the corresponding sstable
const string key = EncodeTensorNameSlice(name, slice_s);
- CHECK(sss_[idx]->Get(key, &value))
- << "Failed to seek to the record for tensor " << name << ", slice "
- << slice_s.DebugString() << ": computed key = " << key;
+ if (!sss_[idx]->Get(key, &value)) {
+ VLOG(1) << "Failed to seek to the record for tensor " << name
+ << ", slice " << slice_s.DebugString()
+ << ": computed key = " << key;
+ return false;
+ }
SavedTensorSlices sts;
- CHECK(ParseProtoUnlimited(&sts, value))
- << "Failed to parse the record for tensor " << name << ", slice "
- << slice_s.DebugString() << ": computed key = " << key;
+ if (!ParseProtoUnlimited(&sts, value)) {
+ VLOG(1) << "Failed to parse the record for tensor " << name << ", slice "
+ << slice_s.DebugString() << ": computed key = " << key;
+ return false;
+ }
CopyDataFromTensorSliceToTensorSlice(
tss->shape(), slice_s, slice,
checkpoint::TensorProtoData<T>(sts.data().data()), data);