aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/unpack.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/kernels/unpack.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/unpack.cc130
1 files changed, 130 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/kernels/unpack.cc b/tensorflow/contrib/lite/kernels/unpack.cc
new file mode 100644
index 0000000000..9ff06f8331
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/unpack.cc
@@ -0,0 +1,130 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace unpack {
+namespace {
+
+constexpr int kInputTensor = 0;
+
+// Op data for unpack op.
+struct OpData {
+ int num;
+ int axis;
+};
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* data = new OpData;
+ data->axis = 0;
+ return data;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<OpData*>(buffer);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ const OpData* data = reinterpret_cast<OpData*>(node->builtin_data);
+
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), data->num);
+
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TF_LITE_ENSURE(context, NumDimensions(input) <= 4);
+ TF_LITE_ENSURE(context, NumDimensions(input) > 1);
+ TF_LITE_ENSURE(context, NumDimensions(input) > data->axis);
+ // TODO(renjieliu): Support negative axis.
+ TF_LITE_ENSURE(context, data->axis >= 0);
+ if (input->type != kTfLiteInt32 && input->type != kTfLiteFloat32) {
+ context->ReportError(context,
+ "Currently pack only supports int32 and float32.");
+ return kTfLiteError;
+ }
+
+ const TfLiteIntArray* input_shape = input->dims;
+ // Num should be equal to the shape[axis].
+ // Resize outputs. rank will be R - 1.
+ TfLiteIntArray* output_shape = TfLiteIntArrayCreate(NumDimensions(input) - 1);
+ int o = 0;
+ for (int index = 0; index < NumDimensions(input); ++index) {
+ if (index != data->axis) {
+ output_shape->data[o++] = input_shape->data[index];
+ }
+ }
+
+ TF_LITE_ENSURE_EQ(context, data->num, input_shape->data[data->axis]);
+ for (int i = 0; i < data->num; ++i) {
+ TfLiteIntArray* copied_output_shape = TfLiteIntArrayCopy(output_shape);
+ TfLiteTensor* output = GetOutput(context, node, i);
+ TF_LITE_ENSURE_EQ(context, output->type, input->type);
+ TF_LITE_ENSURE_OK(
+ context, context->ResizeTensor(context, output, copied_output_shape));
+ }
+
+ TfLiteIntArrayFree(output_shape);
+ return kTfLiteOk;
+}
+
+template <typename T>
+void UnpackImpl(TfLiteContext* context, TfLiteNode* node,
+ const TfLiteTensor* input, int output_count, int axis) {
+ VectorOfTensors<T> all_outputs(*context, *node->outputs);
+ reference_ops::Unpack<T>(axis, GetTensorData<T>(input), GetTensorDims(input),
+ NumDimensions(input), output_count,
+ all_outputs.data(), **all_outputs.dims());
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const OpData* data = reinterpret_cast<OpData*>(node->builtin_data);
+
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ switch (input->type) {
+ case kTfLiteFloat32: {
+ UnpackImpl<float>(context, node, input, data->num, data->axis);
+ break;
+ }
+ case kTfLiteInt32: {
+ UnpackImpl<int32_t>(context, node, input, data->num, data->axis);
+ break;
+ }
+ default: {
+ context->ReportError(context,
+ "Currently pack only supports int32 and float32.");
+ return kTfLiteError;
+ }
+ }
+
+ return kTfLiteOk;
+}
+} // namespace
+} // namespace unpack
+
+TfLiteRegistration* Register_UNPACK() {
+ static TfLiteRegistration r = {unpack::Init, unpack::Free, unpack::Prepare,
+ unpack::Eval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite