aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/interpreter.cc
diff options
context:
space:
mode:
authorGravatar Yu-Cheng Ling <ycling@google.com>2018-06-13 14:48:22 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-13 14:50:40 -0700
commita3273e090f7ea8401ea283ad052350aeffa5fdc1 (patch)
tree9497eab175c8d8ed52a466e5869d33433e6333da /tensorflow/contrib/lite/interpreter.cc
parent2f7f04a7a03003e8fe345667ddf0b088032f0e03 (diff)
Variable Tensor API for TF Lite.
PiperOrigin-RevId: 200457602
Diffstat (limited to 'tensorflow/contrib/lite/interpreter.cc')
-rw-r--r--tensorflow/contrib/lite/interpreter.cc55
1 files changed, 49 insertions, 6 deletions
diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc
index 2f8205444d..3287f9c4fd 100644
--- a/tensorflow/contrib/lite/interpreter.cc
+++ b/tensorflow/contrib/lite/interpreter.cc
@@ -82,6 +82,9 @@ class InterpreterInfo : public GraphInfo {
const std::vector<int>& outputs() const override {
return interpreter_->outputs();
}
+ const std::vector<int>& variables() const override {
+ return interpreter_->variables();
+ }
public:
Interpreter* interpreter_;
@@ -302,6 +305,13 @@ TfLiteStatus Interpreter::SetOutputs(std::vector<int> outputs) {
return kTfLiteOk;
}
+TfLiteStatus Interpreter::SetVariables(std::vector<int> variables) {
+ TF_LITE_ENSURE_OK(&context_, CheckTensorIndices("variables", variables.data(),
+ variables.size()));
+ variables_ = std::move(variables);
+ return kTfLiteOk;
+}
+
TfLiteStatus Interpreter::CheckTensorIndices(const char* label,
const int* indices, int length) {
// Making sure kOptionalTensor is not re-defined to something other than -1.
@@ -370,6 +380,7 @@ TfLiteStatus Interpreter::AllocateTensors() {
}
TF_LITE_ENSURE_STATUS(PrepareOpsAndTensors());
+
if (state_ == kStateUninvokable) {
state_ = kStateInvokable;
}
@@ -378,6 +389,25 @@ TfLiteStatus Interpreter::AllocateTensors() {
return kTfLiteOk;
}
+// TODO(ycling): Consider to provide other functions to initialize variable
+// tensors to non-zero values.
+TfLiteStatus Interpreter::ResetVariableTensorsToZero() {
+ for (auto& tensor : tensors_) {
+ if (!tensor.is_variable) {
+ continue;
+ }
+
+ // Variable tensors have to be `kTfLiteArenaRwPersistent`, and must be
+ // allocated after the initial `PrepareOpsAndTensors()` is called.
+ TF_LITE_ENSURE_EQ(&context_, tensor.allocation_type,
+ kTfLiteArenaRwPersistent);
+ TF_LITE_ENSURE(&context_, tensor.data.raw != nullptr);
+
+ memset(tensor.data.raw, 0, tensor.bytes);
+ }
+ return kTfLiteOk;
+}
+
TfLiteStatus Interpreter::AddNodeWithParameters(
const std::vector<int>& inputs, const std::vector<int>& outputs,
const char* init_data, size_t init_data_size, void* builtin_data,
@@ -690,7 +720,7 @@ TfLiteStatus Interpreter::SetTensorParametersReadOnly(
state_ = kStateUninvokable;
TfLiteTensorReset(type, name, ConvertArrayToTfLiteIntArray(rank, dims),
quantization, const_cast<char*>(buffer), bytes,
- kTfLiteMmapRo, allocation, &tensor);
+ kTfLiteMmapRo, allocation, false, &tensor);
}
return kTfLiteOk;
}
@@ -701,7 +731,7 @@ TfLiteStatus Interpreter::SetTensorParametersReadOnly(
// to Interpreter.
TfLiteStatus Interpreter::SetTensorParametersReadWrite(
int tensor_index, TfLiteType type, const char* name, const size_t rank,
- const int* dims, TfLiteQuantizationParams quantization) {
+ const int* dims, TfLiteQuantizationParams quantization, bool is_variable) {
if (state_ == kStateInvokableAndImmutable) {
ReportError(
&context_,
@@ -719,11 +749,23 @@ TfLiteStatus Interpreter::SetTensorParametersReadWrite(
TF_LITE_ENSURE_OK(&context_,
BytesRequired(type, dims, rank, &required_bytes));
}
+
+ TfLiteAllocationType allocation_type = kTfLiteArenaRw;
+ if (type == kTfLiteString) {
+ if (is_variable) {
+ // We don't have a real use case for string variable tensor.
+ ReportError(&context_, "String variable tensor isn't supported.");
+ return kTfLiteError;
+ }
+ allocation_type = kTfLiteDynamic;
+ } else if (is_variable) {
+ allocation_type = kTfLiteArenaRwPersistent;
+ }
+
TfLiteTensorReset(type, name, ConvertArrayToTfLiteIntArray(rank, dims),
quantization,
- /*buffer=*/nullptr, required_bytes,
- type == kTfLiteString ? kTfLiteDynamic : kTfLiteArenaRw,
- nullptr, &context_.tensors[tensor_index]);
+ /*buffer=*/nullptr, required_bytes, allocation_type,
+ nullptr, is_variable, &context_.tensors[tensor_index]);
return kTfLiteOk;
}
@@ -739,7 +781,8 @@ TfLiteStatus Interpreter::ResizeTensorImpl(TfLiteTensor* tensor,
TfLiteIntArray* new_size) {
// Note that in theory we could resize kTfLiteArenaRwPersistent tensors too.
if (tensor->allocation_type == kTfLiteArenaRw ||
- tensor->allocation_type == kTfLiteDynamic) {
+ tensor->allocation_type == kTfLiteDynamic ||
+ tensor->allocation_type == kTfLiteArenaRwPersistent) {
if (tensor->type != kTfLiteString) {
size_t bytesRequired;
TfLiteStatus status = BytesRequired(tensor->type, new_size->data,