aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/model.cc
diff options
context:
space:
mode:
authorGravatar Yu-Cheng Ling <ycling@google.com>2018-06-13 14:48:22 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-13 14:50:40 -0700
commita3273e090f7ea8401ea283ad052350aeffa5fdc1 (patch)
tree9497eab175c8d8ed52a466e5869d33433e6333da /tensorflow/contrib/lite/model.cc
parent2f7f04a7a03003e8fe345667ddf0b088032f0e03 (diff)
Variable Tensor API for TF Lite.
PiperOrigin-RevId: 200457602
Diffstat (limited to 'tensorflow/contrib/lite/model.cc')
-rw-r--r--tensorflow/contrib/lite/model.cc23
1 files changed, 21 insertions, 2 deletions
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index cd7b9bdabf..bc62e4cc2d 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -852,7 +852,16 @@ TfLiteStatus InterpreterBuilder::ParseTensors(
const char* buffer_ptr;
TF_LITE_ENSURE_STATUS(get_readonly_data(&buffer_ptr, &buffer_size));
+ bool is_variable = tensor->is_variable();
if (buffer_ptr) {
+ if (is_variable) {
+ error_reporter_->Report(
+ "Tensor %d is a variable tensor with buffer. "
+ "It's not supported now.\n",
+ i);
+ status = kTfLiteError;
+ }
+
if (interpreter->SetTensorParametersReadOnly(
i, type, get_name(tensor), dims, quantization, buffer_ptr,
buffer_size, allocation_) != kTfLiteOk) {
@@ -861,8 +870,9 @@ TfLiteStatus InterpreterBuilder::ParseTensors(
status = kTfLiteError;
}
} else {
- if (interpreter->SetTensorParametersReadWrite(
- i, type, get_name(tensor), dims, quantization) != kTfLiteOk) {
+ if (interpreter->SetTensorParametersReadWrite(i, type, get_name(tensor),
+ dims, quantization,
+ is_variable) != kTfLiteOk) {
error_reporter_->Report("Tensor %d is invalidly specified in schema.\n",
i);
status = kTfLiteError;
@@ -946,6 +956,15 @@ TfLiteStatus InterpreterBuilder::operator()(
if (ParseTensors(buffers, tensors, interpreter->get()) != kTfLiteOk)
return cleanup_and_error();
+ std::vector<int> variables;
+ for (int i = 0; i < (*interpreter)->tensors_size(); ++i) {
+ auto* tensor = (*interpreter)->tensor(i);
+ if (tensor->is_variable) {
+ variables.push_back(i);
+ }
+ }
+ (**interpreter).SetVariables(variables);
+
return kTfLiteOk;
}