aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/pack.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-20 12:23:46 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-20 12:28:35 -0700
commitd0377209b9a39a3e00483609e16f3d8efe72b3f9 (patch)
tree816822c145db303e2d80ff5607f83cb40cc1b1be /tensorflow/contrib/lite/kernels/pack.cc
parent1a6d7f5acd50ed23c38f14e11e563f771a596656 (diff)
Added uint8 support for Pack.
PiperOrigin-RevId: 209463575
Diffstat (limited to 'tensorflow/contrib/lite/kernels/pack.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/pack.cc47
1 files changed, 24 insertions, 23 deletions
diff --git a/tensorflow/contrib/lite/kernels/pack.cc b/tensorflow/contrib/lite/kernels/pack.cc
index bb3416f6a6..cc326a7d51 100644
--- a/tensorflow/contrib/lite/kernels/pack.cc
+++ b/tensorflow/contrib/lite/kernels/pack.cc
@@ -27,24 +27,9 @@ namespace {
constexpr int kOutputTensor = 0;
-// Op data for pack op.
-struct OpData {
- int values_count;
- int axis;
-};
-
-void* Init(TfLiteContext* context, const char* buffer, size_t length) {
- auto* data = new OpData;
- data->axis = 0;
- return data;
-}
-
-void Free(TfLiteContext* context, void* buffer) {
- delete reinterpret_cast<OpData*>(buffer);
-}
-
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
- const OpData* data = reinterpret_cast<OpData*>(node->builtin_data);
+ const TfLitePackParams* data =
+ reinterpret_cast<TfLitePackParams*>(node->builtin_data);
TF_LITE_ENSURE_EQ(context, NumInputs(node), data->values_count);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
@@ -54,9 +39,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, NumDimensions(input0) >= data->axis);
// TODO(renjieliu): Support negative axis.
TF_LITE_ENSURE(context, data->axis >= 0);
- if (input0->type != kTfLiteInt32 && input0->type != kTfLiteFloat32) {
+ if (input0->type != kTfLiteInt32 && input0->type != kTfLiteFloat32 &&
+ input0->type != kTfLiteUInt8 && input0->type != kTfLiteInt16) {
context->ReportError(context,
- "Currently pack only supports int32 and float32.");
+ "Currently pack only supports "
+ "float32/uint8/int16/int32.");
return kTfLiteError;
}
// Make sure all inputs have the same shape and type.
@@ -82,6 +69,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_EQ(context, output->type, input0->type);
+ // Guarantee input/output quantization params match as we do not support
+ // packing quantized tensors.
+ for (int i = 0; i < data->values_count; i++) {
+ const TfLiteTensor* input = GetInput(context, node, i);
+ TF_LITE_ENSURE_EQ(context, input->params.zero_point,
+ output->params.zero_point);
+ TF_LITE_ENSURE_EQ(context, input->params.scale, output->params.scale);
+ }
+
return context->ResizeTensor(context, output, output_shape);
}
@@ -95,7 +91,8 @@ void PackImpl(TfLiteContext* context, TfLiteNode* node, TfLiteTensor* output,
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- const OpData* data = reinterpret_cast<OpData*>(node->builtin_data);
+ const TfLitePackParams* data =
+ reinterpret_cast<TfLitePackParams*>(node->builtin_data);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
switch (output->type) {
@@ -103,13 +100,18 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
PackImpl<float>(context, node, output, data->values_count, data->axis);
break;
}
+ case kTfLiteUInt8: {
+ PackImpl<uint8_t>(context, node, output, data->values_count, data->axis);
+ break;
+ }
case kTfLiteInt32: {
PackImpl<int32_t>(context, node, output, data->values_count, data->axis);
break;
}
default: {
context->ReportError(context,
- "Currently pack only supports int32 and float32.");
+ "Currently pack only supports "
+ "float32/uint8/int32.");
return kTfLiteError;
}
}
@@ -121,8 +123,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace pack
TfLiteRegistration* Register_PACK() {
- static TfLiteRegistration r = {pack::Init, pack::Free, pack::Prepare,
- pack::Eval};
+ static TfLiteRegistration r = {nullptr, nullptr, pack::Prepare, pack::Eval};
return &r;
}