aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/interpreter.cc
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@google.com>2018-03-12 13:35:03 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-12 13:39:26 -0700
commit2057bf784770c55ab56bdbe5b96c233afbed50ce (patch)
treecc4b65fad3a34e4670b5919eb94c7c6ce2b789f9 /tensorflow/contrib/lite/interpreter.cc
parent7a6af158e972bfef4b23bf6812b5895abcdc5aef (diff)
[TFLite] Don't require a std::vector for Interpreter::SetTensorParameters*.
PiperOrigin-RevId: 188770522
Diffstat (limited to 'tensorflow/contrib/lite/interpreter.cc')
-rw-r--r--tensorflow/contrib/lite/interpreter.cc27
1 files changed, 14 insertions, 13 deletions
diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc
index bbcd318efd..f03c1c9fe9 100644
--- a/tensorflow/contrib/lite/interpreter.cc
+++ b/tensorflow/contrib/lite/interpreter.cc
@@ -575,9 +575,9 @@ TfLiteStatus Interpreter::GetNodeAndRegistration(
}
TfLiteStatus Interpreter::SetTensorParametersReadOnly(
- int tensor_index, TfLiteType type, const char* name,
- const std::vector<int>& dims, TfLiteQuantizationParams quantization,
- const char* buffer, size_t bytes, const Allocation* allocation) {
+ int tensor_index, TfLiteType type, const char* name, const int rank,
+ const int* dims, TfLiteQuantizationParams quantization, const char* buffer,
+ size_t bytes, const Allocation* allocation) {
TF_LITE_ENSURE(&context_,
tensor_index < context_.tensors_size && tensor_index >= 0);
// For most tensors we know exactly how much memory is necessary so we can
@@ -585,23 +585,24 @@ TfLiteStatus Interpreter::SetTensorParametersReadOnly(
// because their sizes change with the contents of the individual strings.
if (type != kTfLiteString) {
size_t required_bytes;
- TF_LITE_ENSURE_OK(&context_, BytesRequired(type, dims.data(), dims.size(),
- &required_bytes));
+ TF_LITE_ENSURE_OK(&context_,
+ BytesRequired(type, dims, rank, &required_bytes));
TF_LITE_ENSURE_EQ(&context_, required_bytes, bytes);
}
TfLiteTensor& tensor = context_.tensors[tensor_index];
- if (type == tensor.type && EqualVectorAndTfLiteIntArray(tensor.dims, dims)) {
+ if (type == tensor.type &&
+ EqualArrayAndTfLiteIntArray(tensor.dims, rank, dims)) {
// Fast path which does not invalidate the invokable property.
TfLiteTensorDataFree(&tensor);
tensor.data.raw = const_cast<char*>(buffer);
- if (!tensor.dims) tensor.dims = ConvertVectorToTfLiteIntArray(dims);
+ if (!tensor.dims) tensor.dims = ConvertArrayToTfLiteIntArray(rank, dims);
tensor.params = quantization;
tensor.allocation_type = kTfLiteMmapRo;
tensor.allocation = allocation;
} else {
invokable_ = false;
- TfLiteTensorReset(type, name, ConvertVectorToTfLiteIntArray(dims),
+ TfLiteTensorReset(type, name, ConvertArrayToTfLiteIntArray(rank, dims),
quantization, const_cast<char*>(buffer), bytes,
kTfLiteMmapRo, allocation, &tensor);
}
@@ -613,8 +614,8 @@ TfLiteStatus Interpreter::SetTensorParametersReadOnly(
// bytes. The lifetime of buffer must be ensured to be greater or equal
// to Interpreter.
TfLiteStatus Interpreter::SetTensorParametersReadWrite(
- int tensor_index, TfLiteType type, const char* name,
- const std::vector<int>& dims, TfLiteQuantizationParams quantization) {
+ int tensor_index, TfLiteType type, const char* name, const int rank,
+ const int* dims, TfLiteQuantizationParams quantization) {
invokable_ = false;
TF_LITE_ENSURE(&context_,
tensor_index < context_.tensors_size && tensor_index >= 0);
@@ -624,10 +625,10 @@ TfLiteStatus Interpreter::SetTensorParametersReadWrite(
// many bytes we will need based on the dimensions. String tensors are
// allocated dynamically and we can't know ahead of time how much space
// they will require.
- TF_LITE_ENSURE_OK(&context_, BytesRequired(type, dims.data(), dims.size(),
- &required_bytes));
+ TF_LITE_ENSURE_OK(&context_,
+ BytesRequired(type, dims, rank, &required_bytes));
}
- TfLiteTensorReset(type, name, ConvertVectorToTfLiteIntArray(dims),
+ TfLiteTensorReset(type, name, ConvertArrayToTfLiteIntArray(rank, dims),
quantization,
/*buffer=*/nullptr, required_bytes,
type == kTfLiteString ? kTfLiteDynamic : kTfLiteArenaRw,