about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
author    Eugene Brevdo <ebrevdo@google.com>  2018-03-12 13:35:03 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-03-12 13:39:26 -0700
commit 2057bf784770c55ab56bdbe5b96c233afbed50ce (patch)
tree   cc4b65fad3a34e4670b5919eb94c7c6ce2b789f9
parent 7a6af158e972bfef4b23bf6812b5895abcdc5aef (diff)
[TFLite] Don't require a std::vector for Interpreter::SetTensorParameters*.
PiperOrigin-RevId: 188770522
-rw-r--r--  tensorflow/contrib/lite/interpreter.cc  27
-rw-r--r--  tensorflow/contrib/lite/interpreter.h   22
-rw-r--r--  tensorflow/contrib/lite/util.cc         16
-rw-r--r--  tensorflow/contrib/lite/util.h           8
4 files changed, 48 insertions, 25 deletions
diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc
index bbcd318efd..f03c1c9fe9 100644
--- a/tensorflow/contrib/lite/interpreter.cc
+++ b/tensorflow/contrib/lite/interpreter.cc
@@ -575,9 +575,9 @@ TfLiteStatus Interpreter::GetNodeAndRegistration(
}
TfLiteStatus Interpreter::SetTensorParametersReadOnly(
- int tensor_index, TfLiteType type, const char* name,
- const std::vector<int>& dims, TfLiteQuantizationParams quantization,
- const char* buffer, size_t bytes, const Allocation* allocation) {
+ int tensor_index, TfLiteType type, const char* name, const int rank,
+ const int* dims, TfLiteQuantizationParams quantization, const char* buffer,
+ size_t bytes, const Allocation* allocation) {
TF_LITE_ENSURE(&context_,
tensor_index < context_.tensors_size && tensor_index >= 0);
// For most tensors we know exactly how much memory is necessary so we can
@@ -585,23 +585,24 @@ TfLiteStatus Interpreter::SetTensorParametersReadOnly(
// because their sizes change with the contents of the individual strings.
if (type != kTfLiteString) {
size_t required_bytes;
- TF_LITE_ENSURE_OK(&context_, BytesRequired(type, dims.data(), dims.size(),
- &required_bytes));
+ TF_LITE_ENSURE_OK(&context_,
+ BytesRequired(type, dims, rank, &required_bytes));
TF_LITE_ENSURE_EQ(&context_, required_bytes, bytes);
}
TfLiteTensor& tensor = context_.tensors[tensor_index];
- if (type == tensor.type && EqualVectorAndTfLiteIntArray(tensor.dims, dims)) {
+ if (type == tensor.type &&
+ EqualArrayAndTfLiteIntArray(tensor.dims, rank, dims)) {
// Fast path which does not invalidate the invokable property.
TfLiteTensorDataFree(&tensor);
tensor.data.raw = const_cast<char*>(buffer);
- if (!tensor.dims) tensor.dims = ConvertVectorToTfLiteIntArray(dims);
+ if (!tensor.dims) tensor.dims = ConvertArrayToTfLiteIntArray(rank, dims);
tensor.params = quantization;
tensor.allocation_type = kTfLiteMmapRo;
tensor.allocation = allocation;
} else {
invokable_ = false;
- TfLiteTensorReset(type, name, ConvertVectorToTfLiteIntArray(dims),
+ TfLiteTensorReset(type, name, ConvertArrayToTfLiteIntArray(rank, dims),
quantization, const_cast<char*>(buffer), bytes,
kTfLiteMmapRo, allocation, &tensor);
}
@@ -613,8 +614,8 @@ TfLiteStatus Interpreter::SetTensorParametersReadOnly(
// bytes. The lifetime of buffer must be ensured to be greater or equal
// to Interpreter.
TfLiteStatus Interpreter::SetTensorParametersReadWrite(
- int tensor_index, TfLiteType type, const char* name,
- const std::vector<int>& dims, TfLiteQuantizationParams quantization) {
+ int tensor_index, TfLiteType type, const char* name, const int rank,
+ const int* dims, TfLiteQuantizationParams quantization) {
invokable_ = false;
TF_LITE_ENSURE(&context_,
tensor_index < context_.tensors_size && tensor_index >= 0);
@@ -624,10 +625,10 @@ TfLiteStatus Interpreter::SetTensorParametersReadWrite(
// many bytes we will need based on the dimensions. String tensors are
// allocated dynamically and we can't know ahead of time how much space
// they will require.
- TF_LITE_ENSURE_OK(&context_, BytesRequired(type, dims.data(), dims.size(),
- &required_bytes));
+ TF_LITE_ENSURE_OK(&context_,
+ BytesRequired(type, dims, rank, &required_bytes));
}
- TfLiteTensorReset(type, name, ConvertVectorToTfLiteIntArray(dims),
+ TfLiteTensorReset(type, name, ConvertArrayToTfLiteIntArray(rank, dims),
quantization,
/*buffer=*/nullptr, required_bytes,
type == kTfLiteString ? kTfLiteDynamic : kTfLiteArenaRw,
diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h
index f2d4a05164..7c5a195815 100644
--- a/tensorflow/contrib/lite/interpreter.h
+++ b/tensorflow/contrib/lite/interpreter.h
@@ -134,18 +134,34 @@ class Interpreter {
// This variant assumes an external buffer has been allocated of size
// bytes. The lifetime of buffer must be ensured to be greater or equal
// to Interpreter.
- TfLiteStatus SetTensorParametersReadOnly(
+ inline TfLiteStatus SetTensorParametersReadOnly(
int tensor_index, TfLiteType type, const char* name,
const std::vector<int>& dims, TfLiteQuantizationParams quantization,
+ const char* buffer, size_t bytes,
+ const Allocation* allocation = nullptr) {
+ return SetTensorParametersReadOnly(tensor_index, type, name, dims.size(),
+ dims.data(), quantization, buffer, bytes,
+ allocation);
+ };
+
+ TfLiteStatus SetTensorParametersReadOnly(
+ int tensor_index, TfLiteType type, const char* name, const int rank,
+ const int* dims, TfLiteQuantizationParams quantization,
const char* buffer, size_t bytes, const Allocation* allocation = nullptr);
// Set description of inputs/outputs/data/fptrs for node `node_index`.
// This variant assumes an external buffer has been allocated of size
// bytes. The lifetime of buffer must be ensured to be greater or equal
// to Interpreter.
- TfLiteStatus SetTensorParametersReadWrite(
+ inline TfLiteStatus SetTensorParametersReadWrite(
int tensor_index, TfLiteType type, const char* name,
- const std::vector<int>& dims, TfLiteQuantizationParams quantization);
+ const std::vector<int>& dims, TfLiteQuantizationParams quantization) {
+ return SetTensorParametersReadWrite(tensor_index, type, name, dims.size(),
+ dims.data(), quantization);
+ }
+ TfLiteStatus SetTensorParametersReadWrite(
+ int tensor_index, TfLiteType type, const char* name, const int rank,
+ const int* dims, TfLiteQuantizationParams quantization);
// Functions to access tensor data
diff --git a/tensorflow/contrib/lite/util.cc b/tensorflow/contrib/lite/util.cc
index b7f31e2731..fb4af07d06 100644
--- a/tensorflow/contrib/lite/util.cc
+++ b/tensorflow/contrib/lite/util.cc
@@ -17,17 +17,21 @@ limitations under the License.
namespace tflite {
TfLiteIntArray* ConvertVectorToTfLiteIntArray(const std::vector<int>& input) {
- TfLiteIntArray* output = TfLiteIntArrayCreate(input.size());
- for (size_t i = 0; i < input.size(); i++) {
- output->data[i] = input[i];
+ return ConvertArrayToTfLiteIntArray(input.size(), input.data());
+}
+
+TfLiteIntArray* ConvertArrayToTfLiteIntArray(const int rank, const int* dims) {
+ TfLiteIntArray* output = TfLiteIntArrayCreate(rank);
+ for (size_t i = 0; i < rank; i++) {
+ output->data[i] = dims[i];
}
return output;
}
-bool EqualVectorAndTfLiteIntArray(const TfLiteIntArray* a,
- const std::vector<int>& b) {
+bool EqualArrayAndTfLiteIntArray(const TfLiteIntArray* a, const int b_size,
+ const int* b) {
if (!a) return false;
- if (a->size != b.size()) return false;
+ if (a->size != b_size) return false;
for (int i = 0; i < a->size; ++i) {
if (a->data[i] != b[i]) return false;
}
diff --git a/tensorflow/contrib/lite/util.h b/tensorflow/contrib/lite/util.h
index f505d82a11..a34db35823 100644
--- a/tensorflow/contrib/lite/util.h
+++ b/tensorflow/contrib/lite/util.h
@@ -29,9 +29,11 @@ namespace tflite {
// Converts a `std::vector` to a `TfLiteIntArray`.
TfLiteIntArray* ConvertVectorToTfLiteIntArray(const std::vector<int>& input);
-// Checks whether a `TfLiteIntArray` and `std::vector` have matching elements.
-bool EqualVectorAndTfLiteIntArray(const TfLiteIntArray* a,
- const std::vector<int>& b);
+TfLiteIntArray* ConvertArrayToTfLiteIntArray(const int rank, const int* dims);
+
+// Checks whether a `TfLiteIntArray` and an int array have matching elements.
+bool EqualArrayAndTfLiteIntArray(const TfLiteIntArray* a, const int b_size,
+ const int* b);
} // namespace tflite