path: root/tensorflow/contrib/lite/kernels/strided_slice.cc
author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-01-30 15:04:14 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-01-30 15:08:05 -0800
commit    1209fc3aaa629e135823f11da0a3316e8a7e8e78 (patch)
tree      709f853aca5bdb4b5197e799c165f75219e36f54 /tensorflow/contrib/lite/kernels/strided_slice.cc
parent    0181aa321c8656cb574506476ca5f749061c0e87 (diff)
Optimize memory allocation of TFLite StridedSlice Op for constant tensors.
PiperOrigin-RevId: 183899156
Diffstat (limited to 'tensorflow/contrib/lite/kernels/strided_slice.cc')
-rw-r--r--  tensorflow/contrib/lite/kernels/strided_slice.cc | 183
1 file changed, 89 insertions(+), 94 deletions(-)
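
A note on the mechanism behind the one-line summary: before this change, Prepare() unconditionally marked the output kTfLiteDynamic, so Eval() resized and reallocated it on every invocation. After it, Prepare() resizes the output up front whenever the begin/end/strides tensors are constant (baked into the model), and only genuinely dynamic cases still pay the per-Eval resize. A minimal sketch of the pattern, assuming the standard helpers from tensorflow/contrib/lite/kernels/kernel_util.h (GetInput, GetOutput, IsConstantTensor, SetTensorToDynamic); ResizeOutput is a hypothetical stand-in for the kernel's own resize logic:

    TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
      // Any input whose *values* (not just shape) determine the output shape.
      const TfLiteTensor* begin = GetInput(context, node, 1);
      TfLiteTensor* output = GetOutput(context, node, 0);
      if (!IsConstantTensor(begin)) {
        // Shape depends on runtime data: defer allocation to Eval().
        SetTensorToDynamic(output);
        return kTfLiteOk;
      }
      // Shape is known now: resize once so the buffer comes from the
      // preplanned memory arena.
      return ResizeOutput(context, node);
    }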
diff --git a/tensorflow/contrib/lite/kernels/strided_slice.cc b/tensorflow/contrib/lite/kernels/strided_slice.cc
index c510ee3b9f..c4ffdf79d3 100644
--- a/tensorflow/contrib/lite/kernels/strided_slice.cc
+++ b/tensorflow/contrib/lite/kernels/strided_slice.cc
@@ -57,63 +57,6 @@ struct StridedSliceContext {
int dims;
};
-TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
- TF_LITE_ENSURE_EQ(context, NumInputs(node), 4);
- TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-
- StridedSliceContext op_context(context, node);
-
- // Ensure validity of input tensor and its dimension
- TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.begin), 1);
- TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.end), 1);
- TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.strides), 1);
- TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);
- // Only INT32 begin/end/strides are supported
- // TODO(soroosh) add support for INT64
- TF_LITE_ENSURE_EQ(context, op_context.begin->type, kTfLiteInt32);
- TF_LITE_ENSURE_EQ(context, op_context.end->type, kTfLiteInt32);
- TF_LITE_ENSURE_EQ(context, op_context.strides->type, kTfLiteInt32);
- TF_LITE_ENSURE_MSG(context, op_context.dims <= 4,
- "StridedSlice op only supports 1D-4D input arrays.");
-
- // TODO(soroosh): add the following missing functionalities
- TF_LITE_ENSURE_MSG(context, op_context.params->ellipsis_mask == 0,
- "ellipsis_mask is not implemented yet.");
- TF_LITE_ENSURE_MSG(context, op_context.params->new_axis_mask == 0,
- "new_axis_mask is not implemented yet.");
-
- // TODO(soroosh): optimize for constant tensors to do allocation in Prepare
- op_context.output->allocation_type = kTfLiteDynamic;
- return kTfLiteOk;
-} // namespace strided_slice
-
-// TODO(soroosh): consolidate with BytesRequired in interpreter.h
-TfLiteStatus BytesRequired(TfLiteContext* context, TfLiteType type,
- const int* dims, int dims_size, size_t* bytes) {
- // TODO(aselle): Check for overflow here using overflow.h in TensorFlow
- // MultiplyWithoutOverflow.
- TF_LITE_ENSURE(context, bytes != nullptr);
- size_t count = 1;
- for (int k = 0; k < dims_size; k++) count *= dims[k];
- switch (type) {
- case kTfLiteFloat32:
- *bytes = sizeof(float) * count;
- break;
- case kTfLiteInt32:
- *bytes = sizeof(int32_t) * count;
- break;
- case kTfLiteUInt8:
- *bytes = sizeof(uint8_t) * count;
- break;
- case kTfLiteInt64:
- *bytes = sizeof(int64_t) * count;
- break;
- default:
- return kTfLiteError;
- }
- return kTfLiteOk;
-}
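
For reference, the deleted helper just multiplied the element count by the element size for the four supported types; context->ResizeTensor() now does the equivalent bookkeeping inside the runtime, so the local copy (and its consolidation TODO) can go. A worked example of what it computed:

    // bytes for a kTfLiteFloat32 tensor with dims = {2, 3, 4}:
    //   count = 2 * 3 * 4          = 24 elements
    //   bytes = 24 * sizeof(float) = 96
    // As the TODO above notes, the multiply had no overflow check; moving
    // the computation into the runtime's ResizeTensor() retires this copy.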
-
// Reverse order of bits in the mask to match the expected order in kernel
inline int ReverseMaskBits(int mask, int num_dimensions) {
int out = 0;
@@ -144,43 +87,44 @@ inline int32_t ClampedIndex(int32_t index, int dim, bool pos_stride) {
std::min(std::max(index, -dim), dim - 1), dim));
}
-template <KernelType kernel_type>
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- StridedSliceContext op_context(context, node);
+inline int32_t GetBeginValueAtIndex(StridedSliceContext* op_context, int idx) {
+ const int dim = op_context->input->dims->data[idx];
+ const bool pos_stride = GetTensorData<int32_t>(op_context->strides)[idx] > 0;
+ return op_context->params->begin_mask & (1 << idx)
+ ? pos_stride ? 0 : dim - 1
+ : ClampedIndex(GetTensorData<int32_t>(op_context->begin)[idx], dim,
+ pos_stride);
+}
- std::vector<int> starts;
- std::vector<int> stops;
- std::vector<int> strides;
+inline int32_t GetEndValueAtIndex(StridedSliceContext* op_context, int idx) {
+ const int dim = op_context->input->dims->data[idx];
+ const bool pos_stride = GetTensorData<int32_t>(op_context->strides)[idx] > 0;
+ return op_context->params->end_mask & (1 << idx)
+ ? pos_stride ? dim : -1
+ : ClampedIndex(GetTensorData<int32_t>(op_context->end)[idx], dim,
+ pos_stride);
+}
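
A worked example of the two helpers above, with illustrative values (dim = 10):

    //   begin_mask bit set,  stride = +2  ->  begin = 0    (axis start)
    //   begin_mask bit set,  stride = -1  ->  begin = 9    (dim - 1)
    //   end_mask   bit set,  stride = +2  ->  end   = 10   (dim)
    //   end_mask   bit set,  stride = -1  ->  end   = -1   (one before the start)
    //   bit clear, begin = -3             ->  ClampedIndex(-3, 10, ...):
    //     the index is clamped into [-dim, dim - 1], presumably with
    //     negatives then counting from the end (begin = 7), the usual
    //     strided-slice convention; ClampedIndex()'s full body lies
    //     outside this hunk.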
+
+// Processes the indexing tensors (begin, end and strides) to resize the
+// output tensor. This function is callable from both Prepare() and Eval() as
+// long as the caller ensures the indexing tensors are present.
+TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
+ StridedSliceContext* op_context) {
std::vector<int> output_shape_vector;
- for (int idx = op_context.dims - 1; idx >= 0; --idx) {
- int dim = op_context.input->dims->data[idx];
- int32_t stride = GetTensorData<int32_t>(op_context.strides)[idx];
+ for (int idx = op_context->dims - 1; idx >= 0; --idx) {
+ int32_t stride = GetTensorData<int32_t>(op_context->strides)[idx];
TF_LITE_ENSURE_MSG(context, stride != 0, "stride value has to be non-zero");
- bool pos_stride = stride > 0;
-
- int32_t begin =
- op_context.params->begin_mask & (1 << idx)
- ? pos_stride ? 0 : dim - 1
- : ClampedIndex(GetTensorData<int32_t>(op_context.begin)[idx], dim,
- pos_stride);
- int32_t end =
- op_context.params->end_mask & (1 << idx)
- ? pos_stride ? dim : -1
- : ClampedIndex(GetTensorData<int32_t>(op_context.end)[idx], dim,
- pos_stride);
+
+ int32_t begin = GetBeginValueAtIndex(op_context, idx);
+ int32_t end = GetEndValueAtIndex(op_context, idx);
// This is valid for both positive and negative strides
int32_t dim_shape = ceil((end - begin) / static_cast<float>(stride));
dim_shape = dim_shape < 0 ? 0 : dim_shape;
-
- if (!(op_context.params->shrink_axis_mask & (1 << idx))) {
+ if (!(op_context->params->shrink_axis_mask & (1 << idx))) {
output_shape_vector.push_back(dim_shape);
}
-
- starts.emplace_back(begin);
- stops.emplace_back(end);
- strides.emplace_back(stride);
}
TfLiteIntArray* output_shape =
@@ -189,22 +133,73 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
std::reverse_copy(output_shape_vector.begin(), output_shape_vector.end(),
output_shape->data);
+ TF_LITE_ENSURE_STATUS(
+ context->ResizeTensor(context, op_context->output, output_shape));
+
+ return kTfLiteOk;
+}
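
A worked example of the per-axis shape computation above, dim_shape = ceil((end - begin) / float(stride)) clamped at zero:

    //   begin = 1, end = 7, stride = +2  ->  ceil( 6 /  2.0) = 3   picks {1, 3, 5}
    //   begin = 7, end = 1, stride = -2  ->  ceil(-6 / -2.0) = 3   picks {7, 5, 3}
    //   begin = 5, end = 2, stride = +1  ->  ceil(-3.0) = -3  ->  0 (empty slice)
    // One ceil() form covers both stride signs; axes whose shrink_axis_mask
    // bit is set contribute no entry to the output shape at all.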
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 4);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ StridedSliceContext op_context(context, node);
+
+ // Ensure validity of input tensor and its dimension
+ TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.begin), 1);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.end), 1);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.strides), 1);
+ TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);
+ // Only INT32 begin/end/strides are supported
+ // TODO(soroosh) add support for INT64
+ TF_LITE_ENSURE_EQ(context, op_context.begin->type, kTfLiteInt32);
+ TF_LITE_ENSURE_EQ(context, op_context.end->type, kTfLiteInt32);
+ TF_LITE_ENSURE_EQ(context, op_context.strides->type, kTfLiteInt32);
+ TF_LITE_ENSURE_MSG(context, op_context.dims <= 4,
+ "StridedSlice op only supports 1D-4D input arrays.");
+
+ // TODO(soroosh): add the following missing functionalities
+ TF_LITE_ENSURE_MSG(context, op_context.params->ellipsis_mask == 0,
+ "ellipsis_mask is not implemented yet.");
+ TF_LITE_ENSURE_MSG(context, op_context.params->new_axis_mask == 0,
+ "new_axis_mask is not implemented yet.");
+
+ // Postpone allocation of output if any of the indexing tensors is not
+ // constant
+ if (!(IsConstantTensor(op_context.begin) &&
+ IsConstantTensor(op_context.end) &&
+ IsConstantTensor(op_context.strides))) {
+ SetTensorToDynamic(op_context.output);
+ return kTfLiteOk;
+ }
+ return ResizeOutputTensor(context, &op_context);
+}
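
The constant-tensor gate above is the heart of the change. Roughly paraphrasing the helpers from tensorflow/contrib/lite/kernels/kernel_util.h (not verbatim):

    // IsConstantTensor(t):   true when t's data is read-only and baked into
    //                        the model file (allocation_type == kTfLiteMmapRo),
    //                        so its values are already known at Prepare() time.
    // SetTensorToDynamic(t): marks t kTfLiteDynamic, telling the arena planner
    //                        to skip it; the kernel then sizes and allocates it
    //                        during Eval() instead.
    // Only when begin, end and strides are all constant can the output shape
    // be computed here, letting the buffer come out of the preplanned arena.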
+
+template <KernelType kernel_type>
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ StridedSliceContext op_context(context, node);
+
+ if (IsDynamicTensor(op_context.output)) {
+ TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
+ TfLiteTensorRealloc(op_context.output->bytes, op_context.output);
+ }
+
+ std::vector<int32_t> starts;
+ std::vector<int32_t> stops;
+ std::vector<int32_t> strides;
+
+ for (int idx = op_context.dims - 1; idx >= 0; --idx) {
+ starts.emplace_back(GetBeginValueAtIndex(&op_context, idx));
+ stops.emplace_back(GetEndValueAtIndex(&op_context, idx));
+ strides.emplace_back(GetTensorData<int32_t>(op_context.strides)[idx]);
+ }
+
for (int i = op_context.dims; i < kMaxDim; i++) {
starts.emplace_back(0);
stops.emplace_back(1);
strides.emplace_back(1);
}
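
A note on the padding loop above: it extends starts/stops/strides out to kMaxDim entries (4, matching the 1D-4D limit enforced in Prepare()) using the identity slice {start = 0, stop = 1, stride = 1}, so the kernel below can treat every input as a 4-D array regardless of its actual rank.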
- TF_LITE_ENSURE_STATUS(
- context->ResizeTensor(context, op_context.output, output_shape));
-
- size_t required_bytes;
- TF_LITE_ENSURE_OK(
- context,
- BytesRequired(context, op_context.output->type, output_shape->data,
- output_shape->size, &required_bytes));
- TfLiteTensorRealloc(required_bytes, op_context.output);
-
op_context.params->begin_mask =
ReverseMaskBits(op_context.params->begin_mask, op_context.dims);
op_context.params->end_mask =