diff options
author | 2018-02-08 13:17:53 -0800 | |
---|---|---|
committer | 2018-02-08 13:22:08 -0800 | |
commit | caced55cbc205a9423a480cae0bb9e7a9a10f3a1 (patch) | |
tree | 392754e09b8e28225707da44d38a5ffc006c72c2 /tensorflow/c/c_api.cc | |
parent | 785ac4418d60e3f69115c2a05ee989d620635a71 (diff) |
C API: Fixes #7394
Ideally, when TF_NewTensor is provided with invalid arguments it would provide
a detailed error message. However, for now, to keep the existing API, signal
failure by returning nullptr.
PiperOrigin-RevId: 185040858
Diffstat (limited to 'tensorflow/c/c_api.cc')
-rw-r--r-- | tensorflow/c/c_api.cc | 20 |
1 files changed, 13 insertions, 7 deletions
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index b10af0f060..85f1d1639b 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -64,6 +64,7 @@ using tensorflow::AllocationDescription; using tensorflow::DataType; using tensorflow::Graph; using tensorflow::GraphDef; +using tensorflow::mutex_lock; using tensorflow::NameRangeMap; using tensorflow::NameRangesForNode; using tensorflow::NewSession; @@ -77,6 +78,7 @@ using tensorflow::RunMetadata; using tensorflow::RunOptions; using tensorflow::Session; using tensorflow::Status; +using tensorflow::string; using tensorflow::Tensor; using tensorflow::TensorBuffer; using tensorflow::TensorId; @@ -87,8 +89,6 @@ using tensorflow::error::Code; using tensorflow::errors::FailedPrecondition; using tensorflow::errors::InvalidArgument; using tensorflow::gtl::ArraySlice; -using tensorflow::mutex_lock; -using tensorflow::string; using tensorflow::strings::StrCat; extern "C" { @@ -199,11 +199,11 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims, reinterpret_cast<intptr_t>(data) % EIGEN_MAX_ALIGN_BYTES != 0) { // TF_STRING and TF_RESOURCE tensors have a different representation in // TF_Tensor than they do in tensorflow::Tensor. So a copy here is a waste - // (any alignement requirements will be taken care of by TF_TensorToTensor + // (any alignment requirements will be taken care of by TF_TensorToTensor // and TF_TensorFromTensor). // - // Other types have the same represntation, so copy only if it is safe to do - // so. + // Other types have the same representation, so copy only if it is safe to + // do so. buf->data_ = allocate_tensor("TF_NewTensor", len); std::memcpy(buf->data_, data, len); buf->deallocator_ = deallocate_buffer; @@ -215,7 +215,13 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims, buf->deallocator_ = deallocator; buf->deallocator_arg_ = deallocator_arg; } - return new TF_Tensor{dtype, TensorShape(dimvec), buf}; + TF_Tensor* ret = new TF_Tensor{dtype, TensorShape(dimvec), buf}; + size_t elem_size = TF_DataTypeSize(dtype); + if (elem_size > 0 && len < (elem_size * ret->shape.num_elements())) { + delete ret; + return nullptr; + } + return ret; } TF_Tensor* TF_TensorMaybeMove(TF_Tensor* tensor) { @@ -2148,7 +2154,7 @@ Status CopyGraph(Graph* src_graph, Graph* dst_graph, opts.return_tensors.push_back(ToTensorId(nodes_to_return[i])); } - // TOOD(skyewm): change to OutputTensor + // TODO(skyewm): change to OutputTensor tensorflow::ImportGraphDefResults results; TF_RETURN_IF_ERROR( ImportGraphDef(opts, gdef, dst_graph, dst_refiner, &results)); |