diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-09-02 14:35:12 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-09-02 14:39:02 -0700 |
commit | ddba1e0aadabe26063a28c5d1c48e2cfce44e30f (patch) | |
tree | 8d6682e2d3a860c6c24ddfc99c9cbfbed93830fd | |
parent | 7d5cbd78a54319eeb45bca2e239ec037997dad20 (diff) |
Replace CHECKs in v1 checkpoint loading codepath with returning errors.
PiperOrigin-RevId: 167392822
-rw-r--r-- | tensorflow/core/kernels/save_restore_tensor.cc | 9 | ||||
-rw-r--r-- | tensorflow/core/util/tensor_slice_reader.h | 17 |
2 files changed, 17 insertions, 9 deletions
diff --git a/tensorflow/core/kernels/save_restore_tensor.cc b/tensorflow/core/kernels/save_restore_tensor.cc index 80d4901740..6b06cf650a 100644 --- a/tensorflow/core/kernels/save_restore_tensor.cc +++ b/tensorflow/core/kernels/save_restore_tensor.cc @@ -216,9 +216,12 @@ void RestoreTensor(OpKernelContext* context, if (output_shape.num_elements() == 0) return; -#define READER_COPY(T) \ - case DataTypeToEnum<T>::value: \ - reader->CopySliceData(tensor_name, slice_to_load, t->flat<T>().data()); \ +#define READER_COPY(T) \ + case DataTypeToEnum<T>::value: \ + OP_REQUIRES(context, \ + reader->CopySliceData(tensor_name, slice_to_load, \ + t->flat<T>().data()), \ + errors::InvalidArgument("Error copying slice data")); \ break; switch (type) { diff --git a/tensorflow/core/util/tensor_slice_reader.h b/tensorflow/core/util/tensor_slice_reader.h index eeb3129573..5932d59a15 100644 --- a/tensorflow/core/util/tensor_slice_reader.h +++ b/tensorflow/core/util/tensor_slice_reader.h @@ -165,13 +165,18 @@ bool TensorSliceReader::CopySliceData(const string& name, CHECK_GE(idx, 0) << "Failed to find the index for filename " << fname; // We read a record in the corresponding sstable const string key = EncodeTensorNameSlice(name, slice_s); - CHECK(sss_[idx]->Get(key, &value)) - << "Failed to seek to the record for tensor " << name << ", slice " - << slice_s.DebugString() << ": computed key = " << key; + if (!sss_[idx]->Get(key, &value)) { + VLOG(1) << "Failed to seek to the record for tensor " << name + << ", slice " << slice_s.DebugString() + << ": computed key = " << key; + return false; + } SavedTensorSlices sts; - CHECK(ParseProtoUnlimited(&sts, value)) - << "Failed to parse the record for tensor " << name << ", slice " - << slice_s.DebugString() << ": computed key = " << key; + if (!ParseProtoUnlimited(&sts, value)) { + VLOG(1) << "Failed to parse the record for tensor " << name << ", slice " + << slice_s.DebugString() << ": computed key = " << key; + return false; + } CopyDataFromTensorSliceToTensorSlice( tss->shape(), slice_s, slice, checkpoint::TensorProtoData<T>(sts.data().data()), data); |