diff options
author | 2018-06-07 19:31:38 -0700 | |
---|---|---|
committer | 2018-06-07 19:34:14 -0700 | |
commit | 99e6a86480bfb518dea59b4b25f7c9549b227587 (patch) | |
tree | fa4176f76f672b86874135dcd6e8e560067aa65f /tensorflow/contrib/lite/kernels/elementwise.cc | |
parent | a9ddfe50eee83b2f18293241ab96f0a1e2b4b05b (diff) |
Implement Log operator.
PiperOrigin-RevId: 199735191
Diffstat (limited to 'tensorflow/contrib/lite/kernels/elementwise.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/elementwise.cc | 23 |
1 files changed, 19 insertions, 4 deletions
diff --git a/tensorflow/contrib/lite/kernels/elementwise.cc b/tensorflow/contrib/lite/kernels/elementwise.cc index 0bd5046950..98c21ce9d3 100644 --- a/tensorflow/contrib/lite/kernels/elementwise.cc +++ b/tensorflow/contrib/lite/kernels/elementwise.cc @@ -23,7 +23,7 @@ namespace ops { namespace builtin { namespace elementwise { -TfLiteStatus SinPrepare(TfLiteContext* context, TfLiteNode* node) { +TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); const TfLiteTensor* input = GetInput(context, node, 0); @@ -35,7 +35,8 @@ TfLiteStatus SinPrepare(TfLiteContext* context, TfLiteNode* node) { TfLiteIntArrayCopy(input->dims)); } -TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) { +inline TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, + float float_func(float)) { const TfLiteTensor* input = GetInput(context, node, 0); TfLiteTensor* output = GetOutput(context, node, 0); switch (input->type) { @@ -44,7 +45,7 @@ TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) { const float* in = GetTensorData<float>(input); const float* in_end = in + elements; float* out = output->data.f; - for (; in < in_end; in++, out++) *out = std::sin(*in); + for (; in < in_end; in++, out++) *out = float_func(*in); return kTfLiteOk; } default: { @@ -55,14 +56,28 @@ TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) { } } +TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) { + return Eval(context, node, std::sin); +} + +TfLiteStatus LogEval(TfLiteContext* context, TfLiteNode* node) { + return Eval(context, node, std::log); +} + } // namespace elementwise TfLiteRegistration* Register_SIN() { - static TfLiteRegistration r = {nullptr, nullptr, elementwise::SinPrepare, + static TfLiteRegistration r = {nullptr, nullptr, elementwise::GenericPrepare, elementwise::SinEval}; return &r; } +TfLiteRegistration* Register_LOG() { + static TfLiteRegistration r = {nullptr, nullptr, elementwise::GenericPrepare, + elementwise::LogEval}; + return &r; +} + } // namespace builtin } // namespace ops } // namespace tflite |