diff options
-rw-r--r-- | tensorflow/c/c_api.cc | 6 | ||||
-rw-r--r-- | tensorflow/c/c_api.h | 19 |
2 files changed, 15 insertions, 10 deletions
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 2cb5a1cc3a..cf524c8bc9 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -220,6 +220,10 @@ void* TF_TensorData(const TF_Tensor* t) { return t->buffer->data(); } size_t TF_StringEncode(const char* src, size_t src_len, char* dst, size_t dst_len, TF_Status* status) { const size_t sz = TF_StringEncodedSize(src_len); + if (sz < src_len) { + status->status = InvalidArgument("src string is too large to encode"); + return 0; + } if (dst_len < sz) { status->status = InvalidArgument("dst_len (", dst_len, ") too small to encode a ", @@ -428,10 +432,10 @@ TF_Tensor* TF_Tensor_EncodeStrings(const Tensor& src) { const tensorflow::string& s = srcarray(i); size_t consumed = TF_StringEncode(s.data(), s.size(), dst, dst_len, &status); + CHECK(status.status.ok()); dst += consumed; dst_len -= consumed; } - CHECK(status.status.ok()); CHECK_EQ(dst, base + size); auto dims = src.shape().dim_sizes(); diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index 828ef23ebb..3b969551e4 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -235,27 +235,28 @@ extern size_t TF_TensorByteSize(const TF_Tensor*); extern void* TF_TensorData(const TF_Tensor*); // -------------------------------------------------------------------------- -// Encode the string "src" ("src_len" bytes long) into "dst" in the format -// required by TF_STRING tensors. Does not write to memory more than "dst_len" -// bytes beyond "*dst". "dst_len" should be at least +// Encode the string `src` (`src_len` bytes long) into `dst` in the format +// required by TF_STRING tensors. Does not write to memory more than `dst_len` +// bytes beyond `*dst`. `dst_len` should be at least // TF_StringEncodedSize(src_len). // // On success returns the size in bytes of the encoded string. +// Returns an error into `status` otherwise. extern size_t TF_StringEncode(const char* src, size_t src_len, char* dst, size_t dst_len, TF_Status* status); // Decode a string encoded using TF_StringEncode. // -// On success, sets "*dst" to the start of the decoded string and "*dst_len" to -// its length. Returns the number of bytes starting at "src" consumed while -// decoding. "*dst" points to memory within the encoded buffer. On failure, -// "*dst" and "*dst_len" are undefined. +// On success, sets `*dst` to the start of the decoded string and `*dst_len` to +// its length. Returns the number of bytes starting at `src` consumed while +// decoding. `*dst` points to memory within the encoded buffer. On failure, +// `*dst` and `*dst_len` are undefined and an error is set in `status`. // -// Does not read memory pointed to by "limit" or beyond. +// Does not read memory more than `src_len` bytes beyond `src`. extern size_t TF_DecodeString(const char* src, size_t src_len, char** dst, size_t* dst_len, TF_Status* status); -// Return the size in bytes required to encode a string "len" bytes long into a +// Return the size in bytes required to encode a string `len` bytes long into a // TF_STRING tensor. extern size_t TF_StringEncodedSize(size_t len); |