diff options
-rw-r--r-- | tensorflow/core/kernels/save_restore_tensor.cc | 9 | ||||
-rw-r--r-- | tensorflow/core/util/tensor_slice_reader.h | 17 |
2 files changed, 17 insertions, 9 deletions
diff --git a/tensorflow/core/kernels/save_restore_tensor.cc b/tensorflow/core/kernels/save_restore_tensor.cc index 80d4901740..6b06cf650a 100644 --- a/tensorflow/core/kernels/save_restore_tensor.cc +++ b/tensorflow/core/kernels/save_restore_tensor.cc @@ -216,9 +216,12 @@ void RestoreTensor(OpKernelContext* context, if (output_shape.num_elements() == 0) return; -#define READER_COPY(T) \ - case DataTypeToEnum<T>::value: \ - reader->CopySliceData(tensor_name, slice_to_load, t->flat<T>().data()); \ +#define READER_COPY(T) \ + case DataTypeToEnum<T>::value: \ + OP_REQUIRES(context, \ + reader->CopySliceData(tensor_name, slice_to_load, \ + t->flat<T>().data()), \ + errors::InvalidArgument("Error copying slice data")); \ break; switch (type) { diff --git a/tensorflow/core/util/tensor_slice_reader.h b/tensorflow/core/util/tensor_slice_reader.h index eeb3129573..5932d59a15 100644 --- a/tensorflow/core/util/tensor_slice_reader.h +++ b/tensorflow/core/util/tensor_slice_reader.h @@ -165,13 +165,18 @@ bool TensorSliceReader::CopySliceData(const string& name, CHECK_GE(idx, 0) << "Failed to find the index for filename " << fname; // We read a record in the corresponding sstable const string key = EncodeTensorNameSlice(name, slice_s); - CHECK(sss_[idx]->Get(key, &value)) - << "Failed to seek to the record for tensor " << name << ", slice " - << slice_s.DebugString() << ": computed key = " << key; + if (!sss_[idx]->Get(key, &value)) { + VLOG(1) << "Failed to seek to the record for tensor " << name + << ", slice " << slice_s.DebugString() + << ": computed key = " << key; + return false; + } SavedTensorSlices sts; - CHECK(ParseProtoUnlimited(&sts, value)) - << "Failed to parse the record for tensor " << name << ", slice " - << slice_s.DebugString() << ": computed key = " << key; + if (!ParseProtoUnlimited(&sts, value)) { + VLOG(1) << "Failed to parse the record for tensor " << name << ", slice " + << slice_s.DebugString() << ": computed key = " << key; + return false; + } CopyDataFromTensorSliceToTensorSlice( tss->shape(), slice_s, slice, checkpoint::TensorProtoData<T>(sts.data().data()), data); |