diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-07-30 02:41:59 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-30 02:45:14 -0700 |
commit | 333f9c03950a1b6afb8a902b2dc3d883be490b86 (patch) | |
tree | 1c759f72f699df5078f085a517334ce8da8f1fec /tensorflow/contrib/lite/kernels/logical.cc | |
parent | 9e0b05bbc4bb88d1b34fb2147429dc4ad7bd25cd (diff) |
Implementation of logical_or.
PiperOrigin-RevId: 206549781
Diffstat (limited to 'tensorflow/contrib/lite/kernels/logical.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/logical.cc | 121 |
1 file changed, 121 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/kernels/logical.cc b/tensorflow/contrib/lite/kernels/logical.cc new file mode 100644 index 0000000000..3dc39bf79a --- /dev/null +++ b/tensorflow/contrib/lite/kernels/logical.cc @@ -0,0 +1,121 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace logical { +namespace { + +// Input/output tensor index. +constexpr int kInputTensor1 = 0; +constexpr int kInputTensor2 = 1; +constexpr int kOutputTensor = 0; + +// Op data for logical op. +struct OpData { + bool requires_broadcast; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + auto* data = new OpData; + data->requires_broadcast = false; + return data; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast<OpData*>(buffer); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + // Reinterprete the opaque data provided by user. 
+ OpData* data = reinterpret_cast<OpData*>(node->user_data); + + const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + TF_LITE_ENSURE_EQ(context, input1->type, input2->type); + + const TfLiteType type = input1->type; + if (type != kTfLiteBool) { + context->ReportError(context, "Logical ops only support bool type."); + return kTfLiteError; + } + output->type = type; + + data->requires_broadcast = !HaveSameShapes(input1, input2); + + TfLiteIntArray* output_size = nullptr; + if (data->requires_broadcast) { + TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast( + context, input1, input2, &output_size)); + } else { + output_size = TfLiteIntArrayCopy(input1->dims); + } + + return context->ResizeTensor(context, output, output_size); +} + +TfLiteStatus LogicalImpl(TfLiteContext* context, TfLiteNode* node, + const std::function<bool(bool, bool)>& func) { + OpData* data = reinterpret_cast<OpData*>(node->user_data); + + const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + if (data->requires_broadcast) { + reference_ops::BroadcastLogical( + GetTensorData<bool>(input1), GetTensorDims(input1), + GetTensorData<bool>(input2), GetTensorDims(input2), + GetTensorData<bool>(output), GetTensorDims(output), func); + } else { + reference_ops::Logical(GetTensorData<bool>(input1), GetTensorDims(input1), + GetTensorData<bool>(input2), GetTensorDims(input2), + GetTensorData<bool>(output), GetTensorDims(output), + func); + } + + return kTfLiteOk; +} + +TfLiteStatus LogicalOrEval(TfLiteContext* context, TfLiteNode* node) { + const auto logical_or_func = std::logical_or<bool>(); + return LogicalImpl(context, node, logical_or_func); +} + +} // namespace +} // namespace logical + 
+TfLiteRegistration* Register_LOGICAL_OR() { + // Init, Free, Prepare, Eval are satisfying the Interface required by + // TfLiteRegistration. + static TfLiteRegistration r = {logical::Init, logical::Free, logical::Prepare, + logical::LogicalOrEval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite |