diff options
author | 2018-06-13 14:48:22 -0700 | |
---|---|---|
committer | 2018-06-13 14:50:40 -0700 | |
commit | a3273e090f7ea8401ea283ad052350aeffa5fdc1 (patch) | |
tree | 9497eab175c8d8ed52a466e5869d33433e6333da /tensorflow/contrib/lite/model.cc | |
parent | 2f7f04a7a03003e8fe345667ddf0b088032f0e03 (diff) |
Variable Tensor API for TF Lite.
PiperOrigin-RevId: 200457602
Diffstat (limited to 'tensorflow/contrib/lite/model.cc')
-rw-r--r-- | tensorflow/contrib/lite/model.cc | 23 |
1 files changed, 21 insertions, 2 deletions
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc index cd7b9bdabf..bc62e4cc2d 100644 --- a/tensorflow/contrib/lite/model.cc +++ b/tensorflow/contrib/lite/model.cc @@ -852,7 +852,16 @@ TfLiteStatus InterpreterBuilder::ParseTensors( const char* buffer_ptr; TF_LITE_ENSURE_STATUS(get_readonly_data(&buffer_ptr, &buffer_size)); + bool is_variable = tensor->is_variable(); if (buffer_ptr) { + if (is_variable) { + error_reporter_->Report( + "Tensor %d is a variable tensor with buffer. " + "It's not supported now.\n", + i); + status = kTfLiteError; + } + if (interpreter->SetTensorParametersReadOnly( i, type, get_name(tensor), dims, quantization, buffer_ptr, buffer_size, allocation_) != kTfLiteOk) { @@ -861,8 +870,9 @@ TfLiteStatus InterpreterBuilder::ParseTensors( status = kTfLiteError; } } else { - if (interpreter->SetTensorParametersReadWrite( - i, type, get_name(tensor), dims, quantization) != kTfLiteOk) { + if (interpreter->SetTensorParametersReadWrite(i, type, get_name(tensor), + dims, quantization, + is_variable) != kTfLiteOk) { error_reporter_->Report("Tensor %d is invalidly specified in schema.\n", i); status = kTfLiteError; @@ -946,6 +956,15 @@ TfLiteStatus InterpreterBuilder::operator()( if (ParseTensors(buffers, tensors, interpreter->get()) != kTfLiteOk) return cleanup_and_error(); + std::vector<int> variables; + for (int i = 0; i < (*interpreter)->tensors_size(); ++i) { + auto* tensor = (*interpreter)->tensor(i); + if (tensor->is_variable) { + variables.push_back(i); + } + } + (**interpreter).SetVariables(variables); + return kTfLiteOk; } |