aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/string_util.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-25 11:25:10 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-25 11:28:52 -0700
commite959f7d25e5218b54172d1590fcd3d1d23d7eaf3 (patch)
treed105bd1268622c32f652d61fea54614f7205eb82 /tensorflow/contrib/lite/string_util.cc
parentaa2c22c7fb47b4f042e2e7f75460d2b8bd9db961 (diff)
Serialize strings properly when using TOCO for model conversion.
PiperOrigin-RevId: 194270132
Diffstat (limited to 'tensorflow/contrib/lite/string_util.cc')
-rw-r--r--tensorflow/contrib/lite/string_util.cc45
1 files changed, 31 insertions, 14 deletions
diff --git a/tensorflow/contrib/lite/string_util.cc b/tensorflow/contrib/lite/string_util.cc
index cd41299d38..a89776b29f 100644
--- a/tensorflow/contrib/lite/string_util.cc
+++ b/tensorflow/contrib/lite/string_util.cc
@@ -24,7 +24,10 @@ namespace tflite {
namespace {
// Convenient method to get pointer to int32_t.
-int32_t* GetIntPtr(char* ptr) { return reinterpret_cast<int32_t*>(ptr); }
+const int32_t* GetIntPtr(const char* ptr) {
+ return reinterpret_cast<const int32_t*>(ptr);
+}
+
} // namespace
void DynamicBuffer::AddString(const char* str, size_t len) {
@@ -64,7 +67,7 @@ void DynamicBuffer::AddJoinedString(const std::vector<StringRef>& strings,
offset_.push_back(offset_.back() + total_len);
}
-void DynamicBuffer::WriteToTensor(TfLiteTensor* tensor) {
+int DynamicBuffer::WriteToBuffer(char** buffer) {
// Allocate sufficient memory to tensor buffer.
int32_t num_strings = offset_.size() - 1;
// Total bytes include:
@@ -75,43 +78,57 @@ void DynamicBuffer::WriteToTensor(TfLiteTensor* tensor) {
int32_t bytes = data_.size() // size of content
+ sizeof(int32_t) * (num_strings + 2); // size of header
- // Output tensor will take over the ownership of tensor_buffer, and free it
- // during Interpreter destruction.
- char* tensor_buffer = static_cast<char*>(malloc(bytes));
+ // Caller will take ownership of buffer.
+ *buffer = reinterpret_cast<char*>(malloc(bytes));
// Set num of string
- memcpy(tensor_buffer, &num_strings, sizeof(int32_t));
+ memcpy(*buffer, &num_strings, sizeof(int32_t));
// Set offset of strings.
int32_t start = sizeof(int32_t) * (num_strings + 2);
for (int i = 0; i < offset_.size(); i++) {
int32_t offset = start + offset_[i];
- memcpy(tensor_buffer + sizeof(int32_t) * (i + 1), &offset, sizeof(int32_t));
+ memcpy(*buffer + sizeof(int32_t) * (i + 1), &offset, sizeof(int32_t));
}
// Copy data of strings.
- memcpy(tensor_buffer + start, data_.data(), data_.size());
+ memcpy(*buffer + start, data_.data(), data_.size());
+ return bytes;
+}
+
+void DynamicBuffer::WriteToTensor(TfLiteTensor* tensor) {
+ char* tensor_buffer;
+ int bytes = WriteToBuffer(&tensor_buffer);
// Set tensor content pointer to tensor_buffer, and release original data.
auto dims = TfLiteIntArrayCreate(1);
- dims->data[0] = num_strings;
+ dims->data[0] = offset_.size() - 1; // Store number of strings.
TfLiteTensorReset(tensor->type, tensor->name, dims, tensor->params,
tensor_buffer, bytes, kTfLiteDynamic, tensor->allocation,
tensor);
}
+int GetStringCount(const char* raw_buffer) {
+ // The first integers in the raw buffer is the number of strings.
+ return *GetIntPtr(raw_buffer);
+}
+
int GetStringCount(const TfLiteTensor* tensor) {
// The first integers in the raw buffer is the number of strings.
- return *GetIntPtr(tensor->data.raw);
+ return GetStringCount(tensor->data.raw);
}
-StringRef GetString(const TfLiteTensor* tensor, int string_index) {
- int32_t* offset =
- GetIntPtr(tensor->data.raw + sizeof(int32_t) * (string_index + 1));
+StringRef GetString(const char* raw_buffer, int string_index) {
+ const int32_t* offset =
+ GetIntPtr(raw_buffer + sizeof(int32_t) * (string_index + 1));
return {
- tensor->data.raw + (*offset),
+ raw_buffer + (*offset),
(*(offset + 1)) - (*offset),
};
}
+StringRef GetString(const TfLiteTensor* tensor, int string_index) {
+ return GetString(tensor->data.raw, string_index);
+}
+
} // namespace tflite