diff options
author | Yu-Cheng Ling <ycling@google.com> | 2018-06-13 14:48:22 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-13 14:50:40 -0700 |
commit | a3273e090f7ea8401ea283ad052350aeffa5fdc1 (patch) | |
tree | 9497eab175c8d8ed52a466e5869d33433e6333da /tensorflow/contrib/lite/interpreter.cc | |
parent | 2f7f04a7a03003e8fe345667ddf0b088032f0e03 (diff) |
Variable Tensor API for TF Lite.
PiperOrigin-RevId: 200457602
Diffstat (limited to 'tensorflow/contrib/lite/interpreter.cc')
-rw-r--r-- | tensorflow/contrib/lite/interpreter.cc | 55 |
1 files changed, 49 insertions, 6 deletions
diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc index 2f8205444d..3287f9c4fd 100644 --- a/tensorflow/contrib/lite/interpreter.cc +++ b/tensorflow/contrib/lite/interpreter.cc @@ -82,6 +82,9 @@ class InterpreterInfo : public GraphInfo { const std::vector<int>& outputs() const override { return interpreter_->outputs(); } + const std::vector<int>& variables() const override { + return interpreter_->variables(); + } public: Interpreter* interpreter_; @@ -302,6 +305,13 @@ TfLiteStatus Interpreter::SetOutputs(std::vector<int> outputs) { return kTfLiteOk; } +TfLiteStatus Interpreter::SetVariables(std::vector<int> variables) { + TF_LITE_ENSURE_OK(&context_, CheckTensorIndices("variables", variables.data(), + variables.size())); + variables_ = std::move(variables); + return kTfLiteOk; +} + TfLiteStatus Interpreter::CheckTensorIndices(const char* label, const int* indices, int length) { // Making sure kOptionalTensor is not re-defined to something other than -1. @@ -370,6 +380,7 @@ TfLiteStatus Interpreter::AllocateTensors() { } TF_LITE_ENSURE_STATUS(PrepareOpsAndTensors()); + if (state_ == kStateUninvokable) { state_ = kStateInvokable; } @@ -378,6 +389,25 @@ TfLiteStatus Interpreter::AllocateTensors() { return kTfLiteOk; } +// TODO(ycling): Consider to provide other functions to initialize variable +// tensors to non-zero values. +TfLiteStatus Interpreter::ResetVariableTensorsToZero() { + for (auto& tensor : tensors_) { + if (!tensor.is_variable) { + continue; + } + + // Variable tensors have to be `kTfLiteArenaRwPersistent`, and must be + // allocated after the initial `PrepareOpsAndTensors()` is called. + TF_LITE_ENSURE_EQ(&context_, tensor.allocation_type, + kTfLiteArenaRwPersistent); + TF_LITE_ENSURE(&context_, tensor.data.raw != nullptr); + + memset(tensor.data.raw, 0, tensor.bytes); + } + return kTfLiteOk; +} + TfLiteStatus Interpreter::AddNodeWithParameters( const std::vector<int>& inputs, const std::vector<int>& outputs, const char* init_data, size_t init_data_size, void* builtin_data, @@ -690,7 +720,7 @@ TfLiteStatus Interpreter::SetTensorParametersReadOnly( state_ = kStateUninvokable; TfLiteTensorReset(type, name, ConvertArrayToTfLiteIntArray(rank, dims), quantization, const_cast<char*>(buffer), bytes, - kTfLiteMmapRo, allocation, &tensor); + kTfLiteMmapRo, allocation, false, &tensor); } return kTfLiteOk; } @@ -701,7 +731,7 @@ TfLiteStatus Interpreter::SetTensorParametersReadOnly( // to Interpreter. TfLiteStatus Interpreter::SetTensorParametersReadWrite( int tensor_index, TfLiteType type, const char* name, const size_t rank, - const int* dims, TfLiteQuantizationParams quantization) { + const int* dims, TfLiteQuantizationParams quantization, bool is_variable) { if (state_ == kStateInvokableAndImmutable) { ReportError( &context_, @@ -719,11 +749,23 @@ TfLiteStatus Interpreter::SetTensorParametersReadWrite( TF_LITE_ENSURE_OK(&context_, BytesRequired(type, dims, rank, &required_bytes)); } + + TfLiteAllocationType allocation_type = kTfLiteArenaRw; + if (type == kTfLiteString) { + if (is_variable) { + // We don't have a real use case for string variable tensor. + ReportError(&context_, "String variable tensor isn't supported."); + return kTfLiteError; + } + allocation_type = kTfLiteDynamic; + } else if (is_variable) { + allocation_type = kTfLiteArenaRwPersistent; + } + TfLiteTensorReset(type, name, ConvertArrayToTfLiteIntArray(rank, dims), quantization, - /*buffer=*/nullptr, required_bytes, - type == kTfLiteString ? kTfLiteDynamic : kTfLiteArenaRw, - nullptr, &context_.tensors[tensor_index]); + /*buffer=*/nullptr, required_bytes, allocation_type, + nullptr, is_variable, &context_.tensors[tensor_index]); return kTfLiteOk; } @@ -739,7 +781,8 @@ TfLiteStatus Interpreter::ResizeTensorImpl(TfLiteTensor* tensor, TfLiteIntArray* new_size) { // Note that in theory we could resize kTfLiteArenaRwPersistent tensors too. if (tensor->allocation_type == kTfLiteArenaRw || - tensor->allocation_type == kTfLiteDynamic) { + tensor->allocation_type == kTfLiteDynamic || + tensor->allocation_type == kTfLiteArenaRwPersistent) { if (tensor->type != kTfLiteString) { size_t bytesRequired; TfLiteStatus status = BytesRequired(tensor->type, new_size->data, |