diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-01-30 15:04:14 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-01-30 15:08:05 -0800 |
commit | 1209fc3aaa629e135823f11da0a3316e8a7e8e78 (patch) | |
tree | 709f853aca5bdb4b5197e799c165f75219e36f54 /tensorflow/contrib/lite/kernels/strided_slice.cc | |
parent | 0181aa321c8656cb574506476ca5f749061c0e87 (diff) |
Optimize memory allocation of TFLite StridedSlice Op for constant tensors.
PiperOrigin-RevId: 183899156
Diffstat (limited to 'tensorflow/contrib/lite/kernels/strided_slice.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/strided_slice.cc | 183 |
1 files changed, 89 insertions, 94 deletions
diff --git a/tensorflow/contrib/lite/kernels/strided_slice.cc b/tensorflow/contrib/lite/kernels/strided_slice.cc index c510ee3b9f..c4ffdf79d3 100644 --- a/tensorflow/contrib/lite/kernels/strided_slice.cc +++ b/tensorflow/contrib/lite/kernels/strided_slice.cc @@ -57,63 +57,6 @@ struct StridedSliceContext { int dims; }; -TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - TF_LITE_ENSURE_EQ(context, NumInputs(node), 4); - TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); - - StridedSliceContext op_context(context, node); - - // Ensure validity of input tensor and its dimension - TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.begin), 1); - TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.end), 1); - TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.strides), 1); - TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type); - // Only INT32 begin/end/strides are supported - // TODO(soroosh) add support for INT64 - TF_LITE_ENSURE_EQ(context, op_context.begin->type, kTfLiteInt32); - TF_LITE_ENSURE_EQ(context, op_context.end->type, kTfLiteInt32); - TF_LITE_ENSURE_EQ(context, op_context.strides->type, kTfLiteInt32); - TF_LITE_ENSURE_MSG(context, op_context.dims <= 4, - "StridedSlice op only supports 1D-4D input arrays."); - - // TODO(soroosh): add the following missing functionalities - TF_LITE_ENSURE_MSG(context, op_context.params->ellipsis_mask == 0, - "ellipsis_mask is not implemented yet."); - TF_LITE_ENSURE_MSG(context, op_context.params->new_axis_mask == 0, - "new_axis_mask is not implemented yet."); - - // TODO(soroosh): optimize for constant tensors to do allocation in Prepare - op_context.output->allocation_type = kTfLiteDynamic; - return kTfLiteOk; -} // namespace strided_slice - -// TODO(soroosh): consolidate with BytesRequired in interpreter.h -TfLiteStatus BytesRequired(TfLiteContext* context, TfLiteType type, - const int* dims, int dims_size, size_t* bytes) { - // TODO(aselle): Check for overflow here using overflow.h in TensorFlow - // MultiplyWithoutOverflow. - TF_LITE_ENSURE(context, bytes != nullptr); - size_t count = 1; - for (int k = 0; k < dims_size; k++) count *= dims[k]; - switch (type) { - case kTfLiteFloat32: - *bytes = sizeof(float) * count; - break; - case kTfLiteInt32: - *bytes = sizeof(int32_t) * count; - break; - case kTfLiteUInt8: - *bytes = sizeof(uint8_t) * count; - break; - case kTfLiteInt64: - *bytes = sizeof(int64_t) * count; - break; - default: - return kTfLiteError; - } - return kTfLiteOk; -} - // Reverse order of bits in the mask to match the expected order in kernel inline int ReverseMaskBits(int mask, int num_dimensions) { int out = 0; @@ -144,43 +87,44 @@ inline int32_t ClampedIndex(int32_t index, int dim, bool pos_stride) { std::min(std::max(index, -dim), dim - 1), dim)); } -template <KernelType kernel_type> -TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - StridedSliceContext op_context(context, node); +inline int32_t GetBeginValueAtIndex(StridedSliceContext* op_context, int idx) { + const int dim = op_context->input->dims->data[idx]; + const bool pos_stride = GetTensorData<int32_t>(op_context->strides)[idx] > 0; + return op_context->params->begin_mask & (1 << idx) + ? pos_stride ? 0 : dim - 1 + : ClampedIndex(GetTensorData<int32_t>(op_context->begin)[idx], dim, + pos_stride); +} - std::vector<int> starts; - std::vector<int> stops; - std::vector<int> strides; +inline int32_t GetEndValueAtIndex(StridedSliceContext* op_context, int idx) { + const int dim = op_context->input->dims->data[idx]; + const bool pos_stride = GetTensorData<int32_t>(op_context->strides)[idx] > 0; + return op_context->params->end_mask & (1 << idx) + ? pos_stride ? dim : -1 + : ClampedIndex(GetTensorData<int32_t>(op_context->end)[idx], dim, + pos_stride); +} + +// Processes the indexing tensors (begin, end and strides) to resize the +// output tensor. This function is callable from both Prepare() and Eval() as +// long as the caller ensures the indexing tensors are present. +TfLiteStatus ResizeOutputTensor(TfLiteContext* context, + StridedSliceContext* op_context) { std::vector<int> output_shape_vector; - for (int idx = op_context.dims - 1; idx >= 0; --idx) { - int dim = op_context.input->dims->data[idx]; - int32_t stride = GetTensorData<int32_t>(op_context.strides)[idx]; + for (int idx = op_context->dims - 1; idx >= 0; --idx) { + int32_t stride = GetTensorData<int32_t>(op_context->strides)[idx]; TF_LITE_ENSURE_MSG(context, stride != 0, "stride value has to be non-zero"); - bool pos_stride = stride > 0; - - int32_t begin = - op_context.params->begin_mask & (1 << idx) - ? pos_stride ? 0 : dim - 1 - : ClampedIndex(GetTensorData<int32_t>(op_context.begin)[idx], dim, - pos_stride); - int32_t end = - op_context.params->end_mask & (1 << idx) - ? pos_stride ? dim : -1 - : ClampedIndex(GetTensorData<int32_t>(op_context.end)[idx], dim, - pos_stride); + + int32_t begin = GetBeginValueAtIndex(op_context, idx); + int32_t end = GetEndValueAtIndex(op_context, idx); // This is valid for both positive and negative strides int32_t dim_shape = ceil((end - begin) / static_cast<float>(stride)); dim_shape = dim_shape < 0 ? 0 : dim_shape; - - if (!(op_context.params->shrink_axis_mask & (1 << idx))) { + if (!(op_context->params->shrink_axis_mask & (1 << idx))) { output_shape_vector.push_back(dim_shape); } - - starts.emplace_back(begin); - stops.emplace_back(end); - strides.emplace_back(stride); } TfLiteIntArray* output_shape = @@ -189,22 +133,73 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { std::reverse_copy(output_shape_vector.begin(), output_shape_vector.end(), output_shape->data); + TF_LITE_ENSURE_STATUS( + context->ResizeTensor(context, op_context->output, output_shape)); + + return kTfLiteOk; +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 4); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + StridedSliceContext op_context(context, node); + + // Ensure validity of input tensor and its dimension + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.begin), 1); + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.end), 1); + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.strides), 1); + TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type); + // Only INT32 begin/end/strides are supported + // TODO(soroosh) add support for INT64 + TF_LITE_ENSURE_EQ(context, op_context.begin->type, kTfLiteInt32); + TF_LITE_ENSURE_EQ(context, op_context.end->type, kTfLiteInt32); + TF_LITE_ENSURE_EQ(context, op_context.strides->type, kTfLiteInt32); + TF_LITE_ENSURE_MSG(context, op_context.dims <= 4, + "StridedSlice op only supports 1D-4D input arrays."); + + // TODO(soroosh): add the following missing functionalities + TF_LITE_ENSURE_MSG(context, op_context.params->ellipsis_mask == 0, + "ellipsis_mask is not implemented yet."); + TF_LITE_ENSURE_MSG(context, op_context.params->new_axis_mask == 0, + "new_axis_mask is not implemented yet."); + + // Postpone allocation of output if any of the indexing tensors is not + // constant + if (!(IsConstantTensor(op_context.begin) && + IsConstantTensor(op_context.end) && + IsConstantTensor(op_context.strides))) { + SetTensorToDynamic(op_context.output); + return kTfLiteOk; + } + return ResizeOutputTensor(context, &op_context); +} + +template <KernelType kernel_type> +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + StridedSliceContext op_context(context, node); + + if (IsDynamicTensor(op_context.output)) { + TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context)); + TfLiteTensorRealloc(op_context.output->bytes, op_context.output); + } + + std::vector<int32_t> starts; + std::vector<int32_t> stops; + std::vector<int32_t> strides; + + for (int idx = op_context.dims - 1; idx >= 0; --idx) { + starts.emplace_back(GetBeginValueAtIndex(&op_context, idx)); + stops.emplace_back(GetEndValueAtIndex(&op_context, idx)); + strides.emplace_back(GetTensorData<int32_t>(op_context.strides)[idx]); + } + for (int i = op_context.dims; i < kMaxDim; i++) { starts.emplace_back(0); stops.emplace_back(1); strides.emplace_back(1); } - TF_LITE_ENSURE_STATUS( - context->ResizeTensor(context, op_context.output, output_shape)); - - size_t required_bytes; - TF_LITE_ENSURE_OK( - context, - BytesRequired(context, op_context.output->type, output_shape->data, - output_shape->size, &required_bytes)); - TfLiteTensorRealloc(required_bytes, op_context.output); - op_context.params->begin_mask = ReverseMaskBits(op_context.params->begin_mask, op_context.dims); op_context.params->end_mask = |