aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/interpreter.cc
diff options
context:
space:
mode:
authorGravatar Yu-Cheng Ling <ycling@google.com>2018-03-19 14:27:52 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-19 14:37:01 -0700
commitff43dff34ab525dd333128c73ebfb0f9723c34c0 (patch)
tree395c315428a0f084cfce47a562154ec15474e047 /tensorflow/contrib/lite/interpreter.cc
parente613e0844a95814457f3530eedb9baf812cf1e87 (diff)
TFLite Delegate: Add an `allow_dynamic_tensors` parameter.
PiperOrigin-RevId: 189641833
Diffstat (limited to 'tensorflow/contrib/lite/interpreter.cc')
-rw-r--r--tensorflow/contrib/lite/interpreter.cc80
1 files changed, 71 insertions, 9 deletions
diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc
index cee57bba5e..937c185b0a 100644
--- a/tensorflow/contrib/lite/interpreter.cc
+++ b/tensorflow/contrib/lite/interpreter.cc
@@ -356,7 +356,11 @@ TfLiteStatus Interpreter::AllocateTensors() {
}
TF_LITE_ENSURE_STATUS(PrepareOpsAndTensors());
- invokable_ = true;
+ if (state_ == kStateUninvokable) {
+ state_ = kStateInvokable;
+ }
+ TF_LITE_ENSURE(&context_, state_ == kStateInvokable ||
+ state_ == kStateInvokableAndImmutable);
return kTfLiteOk;
}
@@ -364,7 +368,12 @@ TfLiteStatus Interpreter::AddNodeWithParameters(
const std::vector<int>& inputs, const std::vector<int>& outputs,
const char* init_data, size_t init_data_size, void* builtin_data,
const TfLiteRegistration* registration, int* node_index) {
- invokable_ = false;
+ if (state_ == kStateInvokableAndImmutable) {
+ ReportError(&context_,
+ "AddNodeWithParameters is disallowed when graph is immutable.");
+ return kTfLiteError;
+ }
+ state_ = kStateUninvokable;
std::unique_ptr<void, decltype(free)*> builtin_data_deleter(builtin_data,
free);
@@ -420,12 +429,17 @@ TfLiteStatus Interpreter::AddNodeWithParameters(
TfLiteStatus Interpreter::ResizeInputTensor(int tensor_index,
const std::vector<int>& dims) {
+ if (state_ == kStateInvokableAndImmutable) {
+ ReportError(&context_,
+ "ResizeInputTensor is disallowed when graph is immutable.");
+ return kTfLiteError;
+ }
+ state_ = kStateUninvokable;
+
// TODO(aselle): All bounds checks can be implemented as one-sided bounds
// checks by casting to unsigned for efficiency. Profile before doing this.
-
TF_LITE_ENSURE(&context_,
tensor_index < context_.tensors_size && tensor_index >= 0);
- invokable_ = false;
TfLiteIntArray* dims_lite = ConvertVectorToTfLiteIntArray(dims);
return ResizeTensorImpl(&context_.tensors[tensor_index], dims_lite);
}
@@ -490,7 +504,7 @@ TfLiteStatus Interpreter::Invoke() {
ReportError(&context_, "Invoke called on model that is not consistent.");
return kTfLiteError;
}
- if (!invokable_) {
+ if (state_ == kStateUninvokable) {
ReportError(&context_, "Invoke called on model that is not ready.");
return kTfLiteError;
}
@@ -622,6 +636,13 @@ TfLiteStatus Interpreter::SetTensorParametersReadOnly(
int tensor_index, TfLiteType type, const char* name, const int rank,
const int* dims, TfLiteQuantizationParams quantization, const char* buffer,
size_t bytes, const Allocation* allocation) {
+ if (state_ == kStateInvokableAndImmutable) {
+ ReportError(
+ &context_,
+ "SetTensorParametersReadOnly is disallowed when graph is immutable.");
+ return kTfLiteError;
+ }
+
TF_LITE_ENSURE(&context_,
tensor_index < context_.tensors_size && tensor_index >= 0);
// For most tensors we know exactly how much memory is necessary so we can
@@ -645,7 +666,7 @@ TfLiteStatus Interpreter::SetTensorParametersReadOnly(
tensor.allocation_type = kTfLiteMmapRo;
tensor.allocation = allocation;
} else {
- invokable_ = false;
+ state_ = kStateUninvokable;
TfLiteTensorReset(type, name, ConvertArrayToTfLiteIntArray(rank, dims),
quantization, const_cast<char*>(buffer), bytes,
kTfLiteMmapRo, allocation, &tensor);
@@ -660,7 +681,12 @@ TfLiteStatus Interpreter::SetTensorParametersReadOnly(
TfLiteStatus Interpreter::SetTensorParametersReadWrite(
int tensor_index, TfLiteType type, const char* name, const int rank,
const int* dims, TfLiteQuantizationParams quantization) {
- invokable_ = false;
+ if (state_ == kStateInvokableAndImmutable) {
+ ReportError(
+ &context_,
+ "SetTensorParametersReadWrite is disallowed when graph is immutable.");
+ return kTfLiteError;
+ }
TF_LITE_ENSURE(&context_,
tensor_index < context_.tensors_size && tensor_index >= 0);
size_t required_bytes = 0;
@@ -738,19 +764,55 @@ void Interpreter::SetNumThreads(int num_threads) {
context_.recommended_num_threads = num_threads;
}
-TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate) {
+TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate,
+ bool allow_dynamic_tensors) {
+ if (!allow_dynamic_tensors) {
+ int last_execution_plan_index_prepared;
+ TF_LITE_ENSURE_OK(&context_, PrepareOpsStartingAt(
+ 0, &last_execution_plan_index_prepared));
+
+ bool has_dynamic_tensors = true;
+ // Dynamic tensors exist if not all nodes can be prepared.
+ if (last_execution_plan_index_prepared + 1 == execution_plan_.size()) {
+ // If all the nodes can be prepared, check if the last node has dynamic
+ // tensors.
+ int node_index = execution_plan_[last_execution_plan_index_prepared];
+ TfLiteNode& node = nodes_and_registration_[node_index].first;
+ if (!HasDynamicTensor(context_, node.outputs)) {
+ has_dynamic_tensors = false;
+ }
+ }
+ if (has_dynamic_tensors) {
+ ReportError(&context_, "Attempting to resize a fixed-size tensor.");
+ return kTfLiteError;
+ }
+ }
+
// TODO(aselle): Consider if it is worth storing pointers to delegates.
- // Setup additional context interface
+ // Setup additional context interface.
context_.GetNodeAndRegistration = GetNodeAndRegistration;
context_.ReplaceSubgraphsWithDelegateKernels =
ReplaceSubgraphsWithDelegateKernels;
context_.GetExecutionPlan = GetExecutionPlan;
TfLiteStatus status = delegate->Prepare(&context_, delegate);
+
// Remove additional context info.
SetForbiddenContextFunction(&context_.GetNodeAndRegistration);
SetForbiddenContextFunction(&context_.ReplaceSubgraphsWithDelegateKernels);
SetForbiddenContextFunction(&context_.GetExecutionPlan);
+
+ TF_LITE_ENSURE_OK(&context_, status);
+
+ if (!allow_dynamic_tensors) {
+ TF_LITE_ENSURE_OK(&context_, AllocateTensors());
+ TF_LITE_ENSURE(&context_, state_ == kStateInvokable ||
+ state_ == kStateInvokableAndImmutable);
+ // After using a delegate which doesn't support dynamic tensors, make the
+ // entire graph immutable.
+ state_ = kStateInvokableAndImmutable;
+ }
+
return status;
}