Diffstat (limited to 'tensorflow/contrib/lite/kernels/activations.cc')
 tensorflow/contrib/lite/kernels/activations.cc | 58 +++++++++++++++++++++++---
 1 file changed, 55 insertions(+), 3 deletions(-)
diff --git a/tensorflow/contrib/lite/kernels/activations.cc b/tensorflow/contrib/lite/kernels/activations.cc
index 817266a471..d6d62580e2 100644
--- a/tensorflow/contrib/lite/kernels/activations.cc
+++ b/tensorflow/contrib/lite/kernels/activations.cc
@@ -40,6 +40,11 @@ struct OpData {
int diff_min = 0;
};
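+// Quantized LogSoftmax needs two parameters beyond the base OpData: a
+// fixed-point divisor and shift used to scale results back out of the
+// exp() domain.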
+struct LogSoftmaxOpData : public OpData {
+ int32_t reverse_scaling_divisor = 0;
+ int32_t reverse_scaling_right_shift = 0;
+};
+
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
// This is a builtin op, so we don't use the contents in 'buffer', if any.
// Instead, we allocate a new object to carry information from Prepare() to
@@ -47,10 +52,19 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
return new OpData;
}
+void* LogSoftmaxInit(TfLiteContext* context, const char* buffer,
+ size_t length) {
+ return new LogSoftmaxOpData;
+}
+
void Free(TfLiteContext* context, void* buffer) {
delete reinterpret_cast<OpData*>(buffer);
}
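+// OpData has no virtual destructor, so a LogSoftmaxOpData must be deleted
+// through its own pointer type rather than via the generic Free().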
+void LogSoftmaxFree(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<LogSoftmaxOpData*>(buffer);
+}
+
TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
@@ -205,6 +219,34 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteIntArrayCopy(input->dims));
}
+TfLiteStatus LogSoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
+ LogSoftmaxOpData* data = reinterpret_cast<LogSoftmaxOpData*>(node->user_data);
+
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+ const TfLiteTensor* input = GetInput(context, node, 0);
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ TF_LITE_ENSURE_EQ(context, input->type, output->type);
+
+ if (input->type == kTfLiteUInt8) {
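+    // The uint8 kernel assumes a fixed output quantization: zero point 255
+    // and scale 16/256, so dequantized outputs span roughly [-16, 0],
+    // matching the non-positive range of log-softmax.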
+ TF_LITE_ENSURE_EQ(context, output->params.zero_point, 255);
+ TF_LITE_ENSURE_EQ(context, output->params.scale, 16.0 / 256);
+
+ static const double kBeta = 1.0;
+ static const int kScaledDiffIntegerBits = 5;
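+    // Derive fixed-point multipliers and shifts for the exp(x - max) stage
+    // and for the final reverse scaling into the output domain.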
+ tflite::PreprocessLogSoftmaxScalingExp(
+ kBeta, input->params.scale, kScaledDiffIntegerBits,
+ &data->input_multiplier, &data->input_left_shift,
+ &data->reverse_scaling_divisor, &data->reverse_scaling_right_shift);
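+    // Flip the sign so the stored value follows the right-shift convention
+    // expected by optimized_ops::LogSoftmax.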
+ data->reverse_scaling_right_shift *= -1;
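+    // diff_min bounds how far an input may fall below the row maximum
+    // before it saturates to the smallest representable output.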
+ data->diff_min = -1.0 * tflite::CalculateInputRadius(
+ kScaledDiffIntegerBits, data->input_left_shift);
+ }
+
+ return context->ResizeTensor(context, output,
+ TfLiteIntArrayCopy(input->dims));
+}
+
TfLiteStatus PreluPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
@@ -509,6 +551,8 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
+ const LogSoftmaxOpData* data =
+ reinterpret_cast<LogSoftmaxOpData*>(node->user_data);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
switch (input->type) {
@@ -517,6 +561,14 @@ TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
GetTensorData<float>(input), GetTensorShape(input),
GetTensorData<float>(output), GetTensorShape(output));
return kTfLiteOk;
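+      // Quantized path: evaluates log-softmax entirely in fixed point,
+      // using the parameters precomputed in LogSoftmaxPrepare().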
+ case kTfLiteUInt8:
+ optimized_ops::LogSoftmax(
+ GetTensorData<uint8_t>(input), GetTensorShape(input),
+ data->input_multiplier, data->input_left_shift,
+ data->reverse_scaling_divisor, data->reverse_scaling_right_shift,
+ data->diff_min, GetTensorData<uint8_t>(output),
+ GetTensorShape(output));
+ return kTfLiteOk;
default:
context->ReportError(context,
                     "Only float32 and uint8 are supported currently, got %d",
input->type);
@@ -590,9 +642,9 @@ TfLiteRegistration* Register_SOFTMAX() {
}
TfLiteRegistration* Register_LOG_SOFTMAX() {
- static TfLiteRegistration r = {activations::Init, activations::Free,
- activations::GenericPrepare,
- activations::LogSoftmaxEval};
+ static TfLiteRegistration r = {
+ activations::LogSoftmaxInit, activations::LogSoftmaxFree,
+ activations::LogSoftmaxPrepare, activations::LogSoftmaxEval};
return &r;
}
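
For context: the uint8 path added above is a fixed-point approximation of the
float log-softmax computed by the kTfLiteFloat32 case,
log_softmax(x_i) = (x_i - max_j x_j) - log(sum_j exp(x_j - max_j x_j)).
The sketch below is a minimal, self-contained float reference for that
formula; LogSoftmaxReference and its row/column layout are illustrative
names and assumptions, not the optimized_ops::LogSoftmax from the diff.

// Minimal float reference for row-wise log-softmax; illustrative only,
// not TFLite's optimized implementation.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Treats `input` as a (rows x cols) matrix and writes log(softmax(row))
// for each row into `output`.
void LogSoftmaxReference(const std::vector<float>& input, int rows, int cols,
                         std::vector<float>* output) {
  output->resize(input.size());
  for (int r = 0; r < rows; ++r) {
    const float* in = &input[r * cols];
    float* out = &(*output)[r * cols];
    // Subtract the row max before exponentiating for numerical stability;
    // the quantized kernel's diff_min clamp plays the analogous role.
    const float max = *std::max_element(in, in + cols);
    float sum = 0.f;
    for (int c = 0; c < cols; ++c) sum += std::exp(in[c] - max);
    const float log_sum = std::log(sum);
    for (int c = 0; c < cols; ++c) out[c] = in[c] - max - log_sum;
  }
}

int main() {
  std::vector<float> in = {1.f, 2.f, 3.f, 4.f};
  std::vector<float> out;
  LogSoftmaxReference(in, /*rows=*/1, /*cols=*/4, &out);
  for (float v : out) std::printf("%f ", v);  // every value is <= 0
  std::printf("\n");
  return 0;
}

This also motivates the output quantization asserted in LogSoftmaxPrepare:
with zero point 255 and scale 16/256, a quantized value q dequantizes to
(q - 255) * 16/256, so the representable outputs cover roughly [-16, 0],
exactly the non-positive range that log-softmax produces.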