aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/elementwise.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-07 19:31:38 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-07 19:34:14 -0700
commit99e6a86480bfb518dea59b4b25f7c9549b227587 (patch)
treefa4176f76f672b86874135dcd6e8e560067aa65f /tensorflow/contrib/lite/kernels/elementwise.cc
parenta9ddfe50eee83b2f18293241ab96f0a1e2b4b05b (diff)
Implement Log operator.
PiperOrigin-RevId: 199735191
Diffstat (limited to 'tensorflow/contrib/lite/kernels/elementwise.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/elementwise.cc23
1 files changed, 19 insertions, 4 deletions
diff --git a/tensorflow/contrib/lite/kernels/elementwise.cc b/tensorflow/contrib/lite/kernels/elementwise.cc
index 0bd5046950..98c21ce9d3 100644
--- a/tensorflow/contrib/lite/kernels/elementwise.cc
+++ b/tensorflow/contrib/lite/kernels/elementwise.cc
@@ -23,7 +23,7 @@ namespace ops {
namespace builtin {
namespace elementwise {
-TfLiteStatus SinPrepare(TfLiteContext* context, TfLiteNode* node) {
+TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, 0);
@@ -35,7 +35,8 @@ TfLiteStatus SinPrepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteIntArrayCopy(input->dims));
}
-TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) {
+inline TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node,
+ float float_func(float)) {
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
switch (input->type) {
@@ -44,7 +45,7 @@ TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) {
const float* in = GetTensorData<float>(input);
const float* in_end = in + elements;
float* out = output->data.f;
- for (; in < in_end; in++, out++) *out = std::sin(*in);
+ for (; in < in_end; in++, out++) *out = float_func(*in);
return kTfLiteOk;
}
default: {
@@ -55,14 +56,28 @@ TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) {
}
}
+TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) {
+ return Eval(context, node, std::sin);
+}
+
+TfLiteStatus LogEval(TfLiteContext* context, TfLiteNode* node) {
+ return Eval(context, node, std::log);
+}
+
} // namespace elementwise
TfLiteRegistration* Register_SIN() {
- static TfLiteRegistration r = {nullptr, nullptr, elementwise::SinPrepare,
+ static TfLiteRegistration r = {nullptr, nullptr, elementwise::GenericPrepare,
elementwise::SinEval};
return &r;
}
+TfLiteRegistration* Register_LOG() {
+ static TfLiteRegistration r = {nullptr, nullptr, elementwise::GenericPrepare,
+ elementwise::LogEval};
+ return &r;
+}
+
} // namespace builtin
} // namespace ops
} // namespace tflite