diff options
author | 2018-08-20 12:23:46 -0700 | |
---|---|---|
committer | 2018-08-20 12:28:35 -0700 | |
commit | d0377209b9a39a3e00483609e16f3d8efe72b3f9 (patch) | |
tree | 816822c145db303e2d80ff5607f83cb40cc1b1be /tensorflow/contrib/lite/kernels/pack.cc | |
parent | 1a6d7f5acd50ed23c38f14e11e563f771a596656 (diff) |
Added uint8 support for Pack.
PiperOrigin-RevId: 209463575
Diffstat (limited to 'tensorflow/contrib/lite/kernels/pack.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/pack.cc | 47 |
1 files changed, 24 insertions, 23 deletions
diff --git a/tensorflow/contrib/lite/kernels/pack.cc b/tensorflow/contrib/lite/kernels/pack.cc index bb3416f6a6..cc326a7d51 100644 --- a/tensorflow/contrib/lite/kernels/pack.cc +++ b/tensorflow/contrib/lite/kernels/pack.cc @@ -27,24 +27,9 @@ namespace { constexpr int kOutputTensor = 0; -// Op data for pack op. -struct OpData { - int values_count; - int axis; -}; - -void* Init(TfLiteContext* context, const char* buffer, size_t length) { - auto* data = new OpData; - data->axis = 0; - return data; -} - -void Free(TfLiteContext* context, void* buffer) { - delete reinterpret_cast<OpData*>(buffer); -} - TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - const OpData* data = reinterpret_cast<OpData*>(node->builtin_data); + const TfLitePackParams* data = + reinterpret_cast<TfLitePackParams*>(node->builtin_data); TF_LITE_ENSURE_EQ(context, NumInputs(node), data->values_count); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); @@ -54,9 +39,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE(context, NumDimensions(input0) >= data->axis); // TODO(renjieliu): Support negative axis. TF_LITE_ENSURE(context, data->axis >= 0); - if (input0->type != kTfLiteInt32 && input0->type != kTfLiteFloat32) { + if (input0->type != kTfLiteInt32 && input0->type != kTfLiteFloat32 && + input0->type != kTfLiteUInt8 && input0->type != kTfLiteInt16) { context->ReportError(context, - "Currently pack only supports int32 and float32."); + "Currently pack only supports " + "float32/uint8/int16/int32."); return kTfLiteError; } // Make sure all inputs have the same shape and type. @@ -82,6 +69,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* output = GetOutput(context, node, kOutputTensor); TF_LITE_ENSURE_EQ(context, output->type, input0->type); + // Guarantee input/output quantization params match as we do not support + // packing quantized tensors. + for (int i = 0; i < data->values_count; i++) { + const TfLiteTensor* input = GetInput(context, node, i); + TF_LITE_ENSURE_EQ(context, input->params.zero_point, + output->params.zero_point); + TF_LITE_ENSURE_EQ(context, input->params.scale, output->params.scale); + } + return context->ResizeTensor(context, output, output_shape); } @@ -95,7 +91,8 @@ void PackImpl(TfLiteContext* context, TfLiteNode* node, TfLiteTensor* output, } TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - const OpData* data = reinterpret_cast<OpData*>(node->builtin_data); + const TfLitePackParams* data = + reinterpret_cast<TfLitePackParams*>(node->builtin_data); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); switch (output->type) { @@ -103,13 +100,18 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { PackImpl<float>(context, node, output, data->values_count, data->axis); break; } + case kTfLiteUInt8: { + PackImpl<uint8_t>(context, node, output, data->values_count, data->axis); + break; + } case kTfLiteInt32: { PackImpl<int32_t>(context, node, output, data->values_count, data->axis); break; } default: { context->ReportError(context, - "Currently pack only supports int32 and float32."); + "Currently pack only supports " + "float32/uint8/int32."); return kTfLiteError; } } @@ -121,8 +123,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace pack TfLiteRegistration* Register_PACK() { - static TfLiteRegistration r = {pack::Init, pack::Free, pack::Prepare, - pack::Eval}; + static TfLiteRegistration r = {nullptr, nullptr, pack::Prepare, pack::Eval}; return &r; } |