aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/lite/arena_planner.cc5
-rw-r--r--tensorflow/contrib/lite/kernels/conv.cc14
2 files changed, 15 insertions, 4 deletions
diff --git a/tensorflow/contrib/lite/arena_planner.cc b/tensorflow/contrib/lite/arena_planner.cc
index 87b17c338e..8e47e2375e 100644
--- a/tensorflow/contrib/lite/arena_planner.cc
+++ b/tensorflow/contrib/lite/arena_planner.cc
@@ -128,6 +128,11 @@ TfLiteStatus ArenaPlanner::PlanAllocations() {
}
TfLiteStatus ArenaPlanner::ExecuteAllocations(int first_node, int last_node) {
+ // Grow the size of `allocs_` if necessary. This allows allocating temporary
+ // tensors in op's `prepare` function.
+ TF_LITE_ENSURE(context_, graph_info_->num_tensors() >= allocs_.size());
+ allocs_.resize(graph_info_->num_tensors());
+
TF_LITE_ENSURE_STATUS(CalculateAllocations(first_node, last_node));
TF_LITE_ENSURE_STATUS(Commit());
diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc
index 66d2c04bba..495910aab6 100644
--- a/tensorflow/contrib/lite/kernels/conv.cc
+++ b/tensorflow/contrib/lite/kernels/conv.cc
@@ -51,11 +51,13 @@ enum KernelType {
kCblasOptimized,
};
+const int kTensorNotAllocated = -1;
+
struct OpData {
// IDs are the arbitrary identifiers used by TF Lite to identify and access
// memory buffers.
- int im2col_id;
- int hwcn_weights_id;
+ int im2col_id = kTensorNotAllocated;
+ int hwcn_weights_id = kTensorNotAllocated;
TfLitePaddingValues padding;
// The scaling factor from input to output (aka the 'real multiplier') can
@@ -80,8 +82,6 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
// Instead, we allocate a new object to use as scratch space for im2col, and
// to carry information from Prepare() to Eval().
auto* data = new OpData;
- context->AddTensors(context, 1, &data->im2col_id);
- context->AddTensors(context, 1, &data->hwcn_weights_id);
gemm_support::IncrementUsageCounter(context);
return data;
}
@@ -219,10 +219,16 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
int temporaries_count = 0;
if (data->need_im2col) {
data->im2col_index = temporaries_count;
+ if (data->im2col_id == kTensorNotAllocated) {
+ context->AddTensors(context, 1, &data->im2col_id);
+ }
++temporaries_count;
}
if (data->need_hwcn_weights) {
data->hwcn_weights_index = temporaries_count;
+ if (data->hwcn_weights_id == kTensorNotAllocated) {
+ context->AddTensors(context, 1, &data->hwcn_weights_id);
+ }
++temporaries_count;
}