diff options
author | Jared Duke <jdduke@google.com> | 2018-08-14 15:20:17 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-14 15:26:24 -0700 |
commit | b1c9e5d03eb68623e8d2fc5d29bb408a48c1c2da (patch) | |
tree | 10ac2358bd12fdedf0279e072a57ccce9e394cfc /tensorflow/contrib/lite/kernels/activations.cc | |
parent | f8c946ceb9fcacd93c2640c65b1a5b74a38002f8 (diff) |
Add quantized support to LOG_SOFTMAX
PiperOrigin-RevId: 208723709
Diffstat (limited to 'tensorflow/contrib/lite/kernels/activations.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/activations.cc | 58 |
1 files changed, 55 insertions, 3 deletions
diff --git a/tensorflow/contrib/lite/kernels/activations.cc b/tensorflow/contrib/lite/kernels/activations.cc index 817266a471..d6d62580e2 100644 --- a/tensorflow/contrib/lite/kernels/activations.cc +++ b/tensorflow/contrib/lite/kernels/activations.cc @@ -40,6 +40,11 @@ struct OpData { int diff_min = 0; }; +struct LogSoftmaxOpData : public OpData { + int32_t reverse_scaling_divisor = 0; + int32_t reverse_scaling_right_shift = 0; +}; + void* Init(TfLiteContext* context, const char* buffer, size_t length) { // This is a builtin op, so we don't use the contents in 'buffer', if any. // Instead, we allocate a new object to carry information from Prepare() to @@ -47,10 +52,19 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) { return new OpData; } +void* LogSoftmaxInit(TfLiteContext* context, const char* buffer, + size_t length) { + return new LogSoftmaxOpData; +} + void Free(TfLiteContext* context, void* buffer) { delete reinterpret_cast<OpData*>(buffer); } +void LogSoftmaxFree(TfLiteContext* context, void* buffer) { + delete reinterpret_cast<LogSoftmaxOpData*>(buffer); +} + TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); @@ -205,6 +219,34 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) { TfLiteIntArrayCopy(input->dims)); } +TfLiteStatus LogSoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) { + LogSoftmaxOpData* data = reinterpret_cast<LogSoftmaxOpData*>(node->user_data); + + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + const TfLiteTensor* input = GetInput(context, node, 0); + TfLiteTensor* output = GetOutput(context, node, 0); + TF_LITE_ENSURE_EQ(context, input->type, output->type); + + if (input->type == kTfLiteUInt8) { + TF_LITE_ENSURE_EQ(context, output->params.zero_point, 255); + TF_LITE_ENSURE_EQ(context, output->params.scale, 16.0 / 256); + + static const double kBeta = 1.0; + static const int kScaledDiffIntegerBits = 5; + tflite::PreprocessLogSoftmaxScalingExp( + kBeta, input->params.scale, kScaledDiffIntegerBits, + &data->input_multiplier, &data->input_left_shift, + &data->reverse_scaling_divisor, &data->reverse_scaling_right_shift); + data->reverse_scaling_right_shift *= -1; + data->diff_min = -1.0 * tflite::CalculateInputRadius( + kScaledDiffIntegerBits, data->input_left_shift); + } + + return context->ResizeTensor(context, output, + TfLiteIntArrayCopy(input->dims)); +} + TfLiteStatus PreluPrepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); @@ -509,6 +551,8 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { } TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) { + const LogSoftmaxOpData* data = + reinterpret_cast<LogSoftmaxOpData*>(node->user_data); const TfLiteTensor* input = GetInput(context, node, 0); TfLiteTensor* output = GetOutput(context, node, 0); switch (input->type) { @@ -517,6 +561,14 @@ TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) { GetTensorData<float>(input), GetTensorShape(input), GetTensorData<float>(output), GetTensorShape(output)); return kTfLiteOk; + case kTfLiteUInt8: + optimized_ops::LogSoftmax( + GetTensorData<uint8_t>(input), GetTensorShape(input), + data->input_multiplier, data->input_left_shift, + data->reverse_scaling_divisor, data->reverse_scaling_right_shift, + data->diff_min, GetTensorData<uint8_t>(output), + GetTensorShape(output)); + return kTfLiteOk; default: context->ReportError(context, "Only float32 supported currently., got %d", input->type); @@ -590,9 +642,9 @@ TfLiteRegistration* Register_SOFTMAX() { } TfLiteRegistration* Register_LOG_SOFTMAX() { - static TfLiteRegistration r = {activations::Init, activations::Free, - activations::GenericPrepare, - activations::LogSoftmaxEval}; + static TfLiteRegistration r = { + activations::LogSoftmaxInit, activations::LogSoftmaxFree, + activations::LogSoftmaxPrepare, activations::LogSoftmaxEval}; return &r; } |