diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-01-26 09:03:49 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-01-26 09:07:46 -0800 |
commit | c4ace4e2abf6f19f34357e53ba4aebce5113af01 (patch) | |
tree | d2420f960052a17de64ba0a365ba49129266a627 /tensorflow/contrib/lite/kernels/kernel_util.cc | |
parent | 7fc61bfb50aac4e2d0ff9dab9d99a6001aa5cccf (diff) |
Kernel utils to support broadcast add and mul.
PiperOrigin-RevId: 183397494
Diffstat (limited to 'tensorflow/contrib/lite/kernels/kernel_util.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/kernel_util.cc | 26 |
1 files changed, 26 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/kernels/kernel_util.cc b/tensorflow/contrib/lite/kernels/kernel_util.cc index b0546c00cf..955e8c5764 100644 --- a/tensorflow/contrib/lite/kernels/kernel_util.cc +++ b/tensorflow/contrib/lite/kernels/kernel_util.cc @@ -13,8 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/lite/kernels/kernel_util.h" + #include <algorithm> #include <cmath> +#include <memory> + #include "tensorflow/contrib/lite/kernels/internal/round.h" namespace tflite { @@ -84,4 +87,27 @@ void CalculateActivationRangeFloat(TfLiteFusedActivation activation, } } +bool HaveSameShapes(TfLiteTensor* input1, TfLiteTensor* input2) { + return TfLiteIntArrayEqual(input1->dims, input2->dims); +} + +TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context, + TfLiteTensor* input1, + TfLiteTensor* input2, + TfLiteIntArray** output_shape) { + int64_t dims1 = NumDimensions(input1); + int64_t dims2 = NumDimensions(input2); + int64_t out_dims = std::max(dims1, dims2); + std::unique_ptr<TfLiteIntArray, void (*)(TfLiteIntArray*)> shape( + TfLiteIntArrayCreate(out_dims), TfLiteIntArrayFree); + for (int i = 0; i < out_dims; ++i) { + int64_t d1 = i >= dims1 ? 1 : SizeOfDimension(input1, dims1 - i - 1); + int64_t d2 = i >= dims2 ? 1 : SizeOfDimension(input2, dims2 - i - 1); + TF_LITE_ENSURE(context, d1 == d2 || d1 == 1 || d2 == 1); + shape->data[out_dims - i - 1] = std::max(d1, d2); + } + *output_shape = shape.release(); + return kTfLiteOk; +} + } // namespace tflite |