diff options
author | Yu-Cheng Ling <ycling@google.com> | 2018-03-19 14:27:52 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-19 14:37:01 -0700 |
commit | ff43dff34ab525dd333128c73ebfb0f9723c34c0 (patch) | |
tree | 395c315428a0f084cfce47a562154ec15474e047 /tensorflow/contrib/lite/interpreter.cc | |
parent | e613e0844a95814457f3530eedb9baf812cf1e87 (diff) |
TFLite Delegate: Add an `allow_dynamic_tensors` parameter.
PiperOrigin-RevId: 189641833
Diffstat (limited to 'tensorflow/contrib/lite/interpreter.cc')
-rw-r--r-- | tensorflow/contrib/lite/interpreter.cc | 80 |
1 files changed, 71 insertions, 9 deletions
diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc index cee57bba5e..937c185b0a 100644 --- a/tensorflow/contrib/lite/interpreter.cc +++ b/tensorflow/contrib/lite/interpreter.cc @@ -356,7 +356,11 @@ TfLiteStatus Interpreter::AllocateTensors() { } TF_LITE_ENSURE_STATUS(PrepareOpsAndTensors()); - invokable_ = true; + if (state_ == kStateUninvokable) { + state_ = kStateInvokable; + } + TF_LITE_ENSURE(&context_, state_ == kStateInvokable || + state_ == kStateInvokableAndImmutable); return kTfLiteOk; } @@ -364,7 +368,12 @@ TfLiteStatus Interpreter::AddNodeWithParameters( const std::vector<int>& inputs, const std::vector<int>& outputs, const char* init_data, size_t init_data_size, void* builtin_data, const TfLiteRegistration* registration, int* node_index) { - invokable_ = false; + if (state_ == kStateInvokableAndImmutable) { + ReportError(&context_, + "AddNodeWithParameters is disallowed when graph is immutable."); + return kTfLiteError; + } + state_ = kStateUninvokable; std::unique_ptr<void, decltype(free)*> builtin_data_deleter(builtin_data, free); @@ -420,12 +429,17 @@ TfLiteStatus Interpreter::AddNodeWithParameters( TfLiteStatus Interpreter::ResizeInputTensor(int tensor_index, const std::vector<int>& dims) { + if (state_ == kStateInvokableAndImmutable) { + ReportError(&context_, + "ResizeInputTensor is disallowed when graph is immutable."); + return kTfLiteError; + } + state_ = kStateUninvokable; + // TODO(aselle): All bounds checks can be implemented as one-sided bounds // checks by casting to unsigned for efficiency. Profile before doing this. - TF_LITE_ENSURE(&context_, tensor_index < context_.tensors_size && tensor_index >= 0); - invokable_ = false; TfLiteIntArray* dims_lite = ConvertVectorToTfLiteIntArray(dims); return ResizeTensorImpl(&context_.tensors[tensor_index], dims_lite); } @@ -490,7 +504,7 @@ TfLiteStatus Interpreter::Invoke() { ReportError(&context_, "Invoke called on model that is not consistent."); return kTfLiteError; } - if (!invokable_) { + if (state_ == kStateUninvokable) { ReportError(&context_, "Invoke called on model that is not ready."); return kTfLiteError; } @@ -622,6 +636,13 @@ TfLiteStatus Interpreter::SetTensorParametersReadOnly( int tensor_index, TfLiteType type, const char* name, const int rank, const int* dims, TfLiteQuantizationParams quantization, const char* buffer, size_t bytes, const Allocation* allocation) { + if (state_ == kStateInvokableAndImmutable) { + ReportError( + &context_, + "SetTensorParametersReadOnly is disallowed when graph is immutable."); + return kTfLiteError; + } + TF_LITE_ENSURE(&context_, tensor_index < context_.tensors_size && tensor_index >= 0); // For most tensors we know exactly how much memory is necessary so we can @@ -645,7 +666,7 @@ TfLiteStatus Interpreter::SetTensorParametersReadOnly( tensor.allocation_type = kTfLiteMmapRo; tensor.allocation = allocation; } else { - invokable_ = false; + state_ = kStateUninvokable; TfLiteTensorReset(type, name, ConvertArrayToTfLiteIntArray(rank, dims), quantization, const_cast<char*>(buffer), bytes, kTfLiteMmapRo, allocation, &tensor); @@ -660,7 +681,12 @@ TfLiteStatus Interpreter::SetTensorParametersReadOnly( TfLiteStatus Interpreter::SetTensorParametersReadWrite( int tensor_index, TfLiteType type, const char* name, const int rank, const int* dims, TfLiteQuantizationParams quantization) { - invokable_ = false; + if (state_ == kStateInvokableAndImmutable) { + ReportError( + &context_, + "SetTensorParametersReadWrite is disallowed when graph is immutable."); + return kTfLiteError; + } TF_LITE_ENSURE(&context_, tensor_index < context_.tensors_size && tensor_index >= 0); size_t required_bytes = 0; @@ -738,19 +764,55 @@ void Interpreter::SetNumThreads(int num_threads) { context_.recommended_num_threads = num_threads; } -TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate) { +TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate, + bool allow_dynamic_tensors) { + if (!allow_dynamic_tensors) { + int last_execution_plan_index_prepared; + TF_LITE_ENSURE_OK(&context_, PrepareOpsStartingAt( + 0, &last_execution_plan_index_prepared)); + + bool has_dynamic_tensors = true; + // Dynamic tensors exist if not all nodes can be prepared. + if (last_execution_plan_index_prepared + 1 == execution_plan_.size()) { + // If all the nodes can be prepared, check if the last node has dynamic + // tensors. + int node_index = execution_plan_[last_execution_plan_index_prepared]; + TfLiteNode& node = nodes_and_registration_[node_index].first; + if (!HasDynamicTensor(context_, node.outputs)) { + has_dynamic_tensors = false; + } + } + if (has_dynamic_tensors) { + ReportError(&context_, "Attempting to resize a fixed-size tensor."); + return kTfLiteError; + } + } + // TODO(aselle): Consider if it is worth storing pointers to delegates. - // Setup additional context interface + // Setup additional context interface. context_.GetNodeAndRegistration = GetNodeAndRegistration; context_.ReplaceSubgraphsWithDelegateKernels = ReplaceSubgraphsWithDelegateKernels; context_.GetExecutionPlan = GetExecutionPlan; TfLiteStatus status = delegate->Prepare(&context_, delegate); + // Remove additional context info. SetForbiddenContextFunction(&context_.GetNodeAndRegistration); SetForbiddenContextFunction(&context_.ReplaceSubgraphsWithDelegateKernels); SetForbiddenContextFunction(&context_.GetExecutionPlan); + + TF_LITE_ENSURE_OK(&context_, status); + + if (!allow_dynamic_tensors) { + TF_LITE_ENSURE_OK(&context_, AllocateTensors()); + TF_LITE_ENSURE(&context_, state_ == kStateInvokable || + state_ == kStateInvokableAndImmutable); + // After using a delegate which doesn't support dynamic tensors, make the + // entire graph immutable. + state_ = kStateInvokableAndImmutable; + } + return status; } |