diff options
-rw-r--r-- | tensorflow/contrib/lite/arena_planner.cc | 5 | ||||
-rw-r--r-- | tensorflow/contrib/lite/kernels/conv.cc | 14 |
2 files changed, 15 insertions, 4 deletions
diff --git a/tensorflow/contrib/lite/arena_planner.cc b/tensorflow/contrib/lite/arena_planner.cc index 87b17c338e..8e47e2375e 100644 --- a/tensorflow/contrib/lite/arena_planner.cc +++ b/tensorflow/contrib/lite/arena_planner.cc @@ -128,6 +128,11 @@ TfLiteStatus ArenaPlanner::PlanAllocations() { } TfLiteStatus ArenaPlanner::ExecuteAllocations(int first_node, int last_node) { + // Grow the size of `allocs_` if necessary. This allows allocating temporary + // tensors in op's `prepare` function. + TF_LITE_ENSURE(context_, graph_info_->num_tensors() >= allocs_.size()); + allocs_.resize(graph_info_->num_tensors()); + TF_LITE_ENSURE_STATUS(CalculateAllocations(first_node, last_node)); TF_LITE_ENSURE_STATUS(Commit()); diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc index 66d2c04bba..495910aab6 100644 --- a/tensorflow/contrib/lite/kernels/conv.cc +++ b/tensorflow/contrib/lite/kernels/conv.cc @@ -51,11 +51,13 @@ enum KernelType { kCblasOptimized, }; +const int kTensorNotAllocated = -1; + struct OpData { // IDs are the arbitrary identifiers used by TF Lite to identify and access // memory buffers. - int im2col_id; - int hwcn_weights_id; + int im2col_id = kTensorNotAllocated; + int hwcn_weights_id = kTensorNotAllocated; TfLitePaddingValues padding; // The scaling factor from input to output (aka the 'real multiplier') can @@ -80,8 +82,6 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) { // Instead, we allocate a new object to use as scratch space for im2col, and // to carry information from Prepare() to Eval(). auto* data = new OpData; - context->AddTensors(context, 1, &data->im2col_id); - context->AddTensors(context, 1, &data->hwcn_weights_id); gemm_support::IncrementUsageCounter(context); return data; } @@ -219,10 +219,16 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { int temporaries_count = 0; if (data->need_im2col) { data->im2col_index = temporaries_count; + if (data->im2col_id == kTensorNotAllocated) { + context->AddTensors(context, 1, &data->im2col_id); + } ++temporaries_count; } if (data->need_hwcn_weights) { data->hwcn_weights_index = temporaries_count; + if (data->hwcn_weights_id == kTensorNotAllocated) { + context->AddTensors(context, 1, &data->hwcn_weights_id); + } ++temporaries_count; } |