aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/kernel_util.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-01-26 09:03:49 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-26 09:07:46 -0800
commitc4ace4e2abf6f19f34357e53ba4aebce5113af01 (patch)
treed2420f960052a17de64ba0a365ba49129266a627 /tensorflow/contrib/lite/kernels/kernel_util.cc
parent7fc61bfb50aac4e2d0ff9dab9d99a6001aa5cccf (diff)
Kernel utils to support broadcast add and mul.
PiperOrigin-RevId: 183397494
Diffstat (limited to 'tensorflow/contrib/lite/kernels/kernel_util.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/kernel_util.cc26
1 files changed, 26 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/kernels/kernel_util.cc b/tensorflow/contrib/lite/kernels/kernel_util.cc
index b0546c00cf..955e8c5764 100644
--- a/tensorflow/contrib/lite/kernels/kernel_util.cc
+++ b/tensorflow/contrib/lite/kernels/kernel_util.cc
@@ -13,8 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+
#include <algorithm>
#include <cmath>
+#include <memory>
+
#include "tensorflow/contrib/lite/kernels/internal/round.h"
namespace tflite {
@@ -84,4 +87,27 @@ void CalculateActivationRangeFloat(TfLiteFusedActivation activation,
}
}
+bool HaveSameShapes(TfLiteTensor* input1, TfLiteTensor* input2) {
+ return TfLiteIntArrayEqual(input1->dims, input2->dims);
+}
+
+TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context,
+ TfLiteTensor* input1,
+ TfLiteTensor* input2,
+ TfLiteIntArray** output_shape) {
+ int64_t dims1 = NumDimensions(input1);
+ int64_t dims2 = NumDimensions(input2);
+ int64_t out_dims = std::max(dims1, dims2);
+ std::unique_ptr<TfLiteIntArray, void (*)(TfLiteIntArray*)> shape(
+ TfLiteIntArrayCreate(out_dims), TfLiteIntArrayFree);
+ for (int i = 0; i < out_dims; ++i) {
+ int64_t d1 = i >= dims1 ? 1 : SizeOfDimension(input1, dims1 - i - 1);
+ int64_t d2 = i >= dims2 ? 1 : SizeOfDimension(input2, dims2 - i - 1);
+ TF_LITE_ENSURE(context, d1 == d2 || d1 == 1 || d2 == 1);
+ shape->data[out_dims - i - 1] = std::max(d1, d2);
+ }
+ *output_shape = shape.release();
+ return kTfLiteOk;
+}
+
} // namespace tflite