Diffstat (limited to 'tensorflow/contrib/lite/kernels')
-rw-r--r--  tensorflow/contrib/lite/kernels/BUILD | 196
-rw-r--r--  tensorflow/contrib/lite/kernels/activations.cc | 54
-rw-r--r--  tensorflow/contrib/lite/kernels/activations_test.cc | 23
-rw-r--r--  tensorflow/contrib/lite/kernels/add.cc | 85
-rw-r--r--  tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc | 57
-rw-r--r--  tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc | 1
-rw-r--r--  tensorflow/contrib/lite/kernels/concatenation.cc | 1
-rw-r--r--  tensorflow/contrib/lite/kernels/conv.cc | 1
-rw-r--r--  tensorflow/contrib/lite/kernels/depthwise_conv.cc | 1
-rw-r--r--  tensorflow/contrib/lite/kernels/div.cc | 62
-rw-r--r--  tensorflow/contrib/lite/kernels/div_test.cc | 61
-rw-r--r--  tensorflow/contrib/lite/kernels/eigen_support.cc | 11
-rw-r--r--  tensorflow/contrib/lite/kernels/embedding_lookup.cc | 4
-rw-r--r--  tensorflow/contrib/lite/kernels/embedding_lookup_test.cc | 36
-rw-r--r--  tensorflow/contrib/lite/kernels/fake_quant.cc | 11
-rw-r--r--  tensorflow/contrib/lite/kernels/fully_connected.cc | 11
-rw-r--r--  tensorflow/contrib/lite/kernels/fully_connected_test.cc | 12
-rw-r--r--  tensorflow/contrib/lite/kernels/hashtable_lookup.cc | 1
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/BUILD | 22
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/common.h | 133
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/kernel_utils.cc | 23
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h | 239
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc | 71
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h | 4
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h | 797
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h | 6
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/quantization_util.h | 10
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h | 234
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc | 7
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h | 10
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h | 1039
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/tensor_utils.h | 4
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc | 16
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/types.h | 120
-rw-r--r--  tensorflow/contrib/lite/kernels/lsh_projection.cc | 1
-rw-r--r--  tensorflow/contrib/lite/kernels/lstm.cc | 8
-rw-r--r--  tensorflow/contrib/lite/kernels/lstm_test.cc | 3
-rw-r--r--  tensorflow/contrib/lite/kernels/pack.cc | 131
-rw-r--r--  tensorflow/contrib/lite/kernels/pack_test.cc | 120
-rw-r--r--  tensorflow/contrib/lite/kernels/pooling.cc | 1
-rw-r--r--  tensorflow/contrib/lite/kernels/pow_test.cc | 28
-rw-r--r--  tensorflow/contrib/lite/kernels/reduce.cc | 115
-rw-r--r--  tensorflow/contrib/lite/kernels/reduce_test.cc | 388
-rw-r--r--  tensorflow/contrib/lite/kernels/register.cc | 8
-rw-r--r--  tensorflow/contrib/lite/kernels/reshape.cc | 69
-rw-r--r--  tensorflow/contrib/lite/kernels/sparse_to_dense.cc | 1
-rw-r--r--  tensorflow/contrib/lite/kernels/sub.cc | 98
-rw-r--r--  tensorflow/contrib/lite/kernels/sub_test.cc | 58
-rw-r--r--  tensorflow/contrib/lite/kernels/svdf.cc | 3
-rw-r--r--  tensorflow/contrib/lite/kernels/transpose_conv.cc | 112
-rw-r--r--  tensorflow/contrib/lite/kernels/transpose_conv_test.cc | 121
-rw-r--r--  tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc | 4
-rw-r--r--  tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc | 1
53 files changed, 3040 insertions(+), 1593 deletions(-)
diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD
index edce73989c..c224132cae 100644
--- a/tensorflow/contrib/lite/kernels/BUILD
+++ b/tensorflow/contrib/lite/kernels/BUILD
@@ -12,7 +12,10 @@ tf_cc_test(
name = "optional_tensor_test",
size = "small",
srcs = ["optional_tensor_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
@@ -55,6 +58,7 @@ cc_library(
}),
deps = [
":op_macros",
+ "//tensorflow/contrib/lite:arena_planner",
"//tensorflow/contrib/lite:context",
"//tensorflow/contrib/lite/kernels/internal:optimized",
],
@@ -112,7 +116,10 @@ tf_cc_test(
name = "kernel_util_test",
size = "small",
srcs = ["kernel_util_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":kernel_util",
"//tensorflow/contrib/lite/testing:util",
@@ -124,6 +131,7 @@ tf_cc_test(
name = "test_util_test",
size = "small",
srcs = ["test_util_test.cc"],
+ tags = ["no_oss"],
deps = [
":test_util",
"//tensorflow/contrib/lite/testing:util",
@@ -168,6 +176,7 @@ cc_library(
"mfcc.cc",
"mul.cc",
"neg.cc",
+ "pack.cc",
"pad.cc",
"pooling.cc",
"pow.cc",
@@ -232,7 +241,10 @@ tf_cc_test(
name = "audio_spectrogram_test",
size = "small",
srcs = ["audio_spectrogram_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
@@ -246,7 +258,10 @@ tf_cc_test(
name = "mfcc_test",
size = "small",
srcs = ["mfcc_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
@@ -260,7 +275,10 @@ tf_cc_test(
name = "detection_postprocess_test",
size = "small",
srcs = ["detection_postprocess_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
@@ -301,6 +319,7 @@ tf_cc_test(
size = "small",
srcs = ["arg_min_max_test.cc"],
tags = [
+ "no_oss",
"tflite_not_portable_ios",
],
deps = [
@@ -315,7 +334,10 @@ tf_cc_test(
name = "div_test",
size = "small",
srcs = ["div_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
@@ -328,7 +350,10 @@ tf_cc_test(
name = "sub_test",
size = "small",
srcs = ["sub_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
@@ -341,7 +366,10 @@ tf_cc_test(
name = "transpose_test",
size = "small",
srcs = ["transpose_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
@@ -356,7 +384,10 @@ tf_cc_test(
name = "space_to_batch_nd_test",
size = "small",
srcs = ["space_to_batch_nd_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
@@ -369,7 +400,10 @@ tf_cc_test(
name = "batch_to_space_nd_test",
size = "small",
srcs = ["batch_to_space_nd_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
@@ -382,7 +416,10 @@ tf_cc_test(
name = "cast_test",
size = "small",
srcs = ["cast_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
@@ -435,7 +472,10 @@ tf_cc_test(
name = "dequantize_test",
size = "small",
srcs = ["dequantize_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
@@ -462,7 +502,10 @@ tf_cc_test(
name = "bidirectional_sequence_lstm_test",
size = "small",
srcs = ["bidirectional_sequence_lstm_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
@@ -475,7 +518,10 @@ tf_cc_test(
name = "floor_test",
size = "small",
srcs = ["floor_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
@@ -488,7 +534,10 @@ tf_cc_test(
name = "elementwise_test",
size = "small",
srcs = ["elementwise_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
@@ -501,7 +550,10 @@ tf_cc_test(
name = "unidirectional_sequence_lstm_test",
size = "small",
srcs = ["unidirectional_sequence_lstm_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
@@ -515,6 +567,7 @@ tf_cc_test(
size = "small",
srcs = ["bidirectional_sequence_rnn_test.cc"],
tags = [
+ "no_oss",
"tflite_not_portable",
],
deps = [
@@ -529,7 +582,10 @@ tf_cc_test(
name = "unidirectional_sequence_rnn_test",
size = "small",
srcs = ["unidirectional_sequence_rnn_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
@@ -555,7 +611,10 @@ tf_cc_test(
name = "exp_test",
size = "small",
srcs = ["exp_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
@@ -568,7 +627,10 @@ tf_cc_test(
name = "fake_quant_test",
size = "small",
srcs = ["fake_quant_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
@@ -581,7 +643,10 @@ tf_cc_test(
name = "maximum_minimum_test",
size = "small",
srcs = ["maximum_minimum_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
@@ -594,7 +659,10 @@ tf_cc_test(
name = "reduce_test",
size = "small",
srcs = ["reduce_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
@@ -620,7 +688,10 @@ tf_cc_test(
name = "pad_test",
size = "small",
srcs = ["pad_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
@@ -646,7 +717,10 @@ tf_cc_test(
name = "gather_test",
size = "small",
srcs = ["gather_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:builtin_op_data",
@@ -660,7 +734,10 @@ tf_cc_test(
name = "topk_v2_test",
size = "small",
srcs = ["topk_v2_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:builtin_op_data",
@@ -781,7 +858,10 @@ tf_cc_test(
name = "log_softmax_test",
size = "small",
srcs = ["log_softmax_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
@@ -862,7 +942,10 @@ tf_cc_test(
name = "split_test",
size = "small",
srcs = ["split_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
@@ -875,7 +958,10 @@ tf_cc_test(
name = "squeeze_test",
size = "small",
srcs = ["squeeze_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
@@ -888,7 +974,10 @@ tf_cc_test(
name = "strided_slice_test",
size = "small",
srcs = ["strided_slice_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
@@ -901,7 +990,10 @@ tf_cc_test(
name = "tile_test",
size = "small",
srcs = ["tile_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:builtin_op_data",
@@ -918,6 +1010,7 @@ tf_cc_test(
"comparisons_test.cc",
],
tags = [
+ "no_oss",
"tflite_not_portable_ios",
],
deps = [
@@ -932,7 +1025,10 @@ tf_cc_test(
name = "neg_test",
size = "small",
srcs = ["neg_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
@@ -948,6 +1044,7 @@ tf_cc_test(
"select_test.cc",
],
tags = [
+ "no_oss",
"tflite_not_portable_ios",
],
deps = [
@@ -965,6 +1062,7 @@ tf_cc_test(
"slice_test.cc",
],
tags = [
+ "no_oss",
"tflite_not_portable_ios",
],
deps = [
@@ -979,12 +1077,14 @@ tf_cc_test(
name = "transpose_conv_test",
size = "small",
srcs = ["transpose_conv_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite/kernels:test_util",
- "@com_google_absl//absl/memory",
"@com_google_googletest//:gtest",
],
)
@@ -993,7 +1093,10 @@ tf_cc_test(
name = "expand_dims_test",
size = "small",
srcs = ["expand_dims_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:builtin_op_data",
@@ -1007,7 +1110,10 @@ tf_cc_test(
name = "sparse_to_dense_test",
size = "small",
srcs = ["sparse_to_dense_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:builtin_op_data",
@@ -1021,7 +1127,10 @@ tf_cc_test(
name = "shape_test",
size = "small",
srcs = ["shape_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:builtin_op_data",
@@ -1035,6 +1144,23 @@ tf_cc_test(
name = "pow_test",
size = "small",
srcs = ["pow_test.cc"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "pack_test",
+ size = "small",
+ srcs = ["pack_test.cc"],
tags = ["tflite_not_portable_ios"],
deps = [
":builtin_ops",
diff --git a/tensorflow/contrib/lite/kernels/activations.cc b/tensorflow/contrib/lite/kernels/activations.cc
index 99f81c4a8a..6e13b8c667 100644
--- a/tensorflow/contrib/lite/kernels/activations.cc
+++ b/tensorflow/contrib/lite/kernels/activations.cc
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <unistd.h>
#include <cassert>
#include <cmath>
#include <cstdio>
@@ -186,8 +185,8 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, 0);
TF_LITE_ENSURE_EQ(context, input->type, output->type);
- TF_LITE_ENSURE(context,
- NumDimensions(input) == 2 || NumDimensions(input) == 4);
+ const int num_dims = NumDimensions(input);
+ TF_LITE_ENSURE(context, num_dims == 1 || num_dims == 2 || num_dims == 4);
if (input->type == kTfLiteUInt8) {
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
@@ -365,13 +364,9 @@ TfLiteStatus SigmoidEval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
-// Takes a 2D tensor and perform softmax along the second dimension.
-void Softmax2DFloat(const TfLiteTensor* input, TfLiteTensor* output,
- TfLiteSoftmaxParams* params) {
- const int batch_size = input->dims->data[0];
- const int input_size = input->dims->data[1];
- float* in = input->data.f;
- float* out = output->data.f;
+// Performs softmax along the input of size (input_size * batch_size).
+void Softmax(const float* in, const int input_size, const int batch_size,
+ const float beta, float* out) {
TF_LITE_ASSERT(input_size > 0);
// For each batch
@@ -385,7 +380,7 @@ void Softmax2DFloat(const TfLiteTensor* input, TfLiteTensor* output,
// Compute the normalized sum of exps.
float exp_sum = 0.0;
for (int i = 0; i < input_size; i++) {
- out[i] = std::exp((in[i] - max_coeff) * params->beta);
+ out[i] = std::exp((in[i] - max_coeff) * beta);
exp_sum += out[i];
}
@@ -401,6 +396,33 @@ void Softmax2DFloat(const TfLiteTensor* input, TfLiteTensor* output,
}
}
+// Takes a 1D tensor and performs softmax along it.
+void Softmax1DFloat(const TfLiteTensor* input, TfLiteTensor* output,
+ TfLiteSoftmaxParams* params) {
+ const int input_size = input->dims->data[0];
+ Softmax(input->data.f, input_size, 1, params->beta, output->data.f);
+}
+
+// Takes a 2D tensor and performs softmax along the last dimension.
+void Softmax2DFloat(const TfLiteTensor* input, TfLiteTensor* output,
+ TfLiteSoftmaxParams* params) {
+ const int batch_size = input->dims->data[0];
+ const int input_size = input->dims->data[1];
+ Softmax(input->data.f, input_size, batch_size, params->beta, output->data.f);
+}
+
+void Softmax1DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
+ TfLiteSoftmaxParams* params, OpData* data) {
+ // TODO(ahentz): this is arguably a dirty trick. Since the implementation
+ // always traverses the last dimension of a 4D tensor, we will pretend our 1D
+ // tensor is 4D in a special way. We will convert a (Y) shape into a (1,
+ // 1, 1, Y) shape.
+ const int input_size = input->dims->data[0];
+ optimized_ops::Softmax(
+ GetTensorData<uint8_t>(input), GetTensorShape({1, 1, 1, input_size}),
+ data->input_multiplier, data->input_left_shift, data->diff_min,
+ GetTensorData<uint8_t>(output), GetTensorShape({1, 1, 1, input_size}));
+}
void Softmax2DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
TfLiteSoftmaxParams* params, OpData* data) {
// TODO(ahentz): this is arguably a dirty trick. Since the implementation
@@ -443,6 +465,10 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
// dimensions.
switch (input->type) {
case kTfLiteFloat32: {
+ if (NumDimensions(input) == 1) {
+ Softmax1DFloat(input, output, params);
+ return kTfLiteOk;
+ }
if (NumDimensions(input) == 2) {
Softmax2DFloat(input, output, params);
return kTfLiteOk;
@@ -452,11 +478,15 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
context->ReportError(
- context, "Only 2D and 4D tensors supported currently, got %dD.",
+ context, "Only 1D, 2D and 4D tensors supported currently, got %dD.",
NumDimensions(input));
return kTfLiteError;
}
case kTfLiteUInt8: {
+ if (NumDimensions(input) == 1) {
+ Softmax1DQuantized(input, output, params, data);
+ return kTfLiteOk;
+ }
if (NumDimensions(input) == 2) {
Softmax2DQuantized(input, output, params, data);
return kTfLiteOk;
diff --git a/tensorflow/contrib/lite/kernels/activations_test.cc b/tensorflow/contrib/lite/kernels/activations_test.cc
index 587e1303da..083cdf78d7 100644
--- a/tensorflow/contrib/lite/kernels/activations_test.cc
+++ b/tensorflow/contrib/lite/kernels/activations_test.cc
@@ -339,6 +339,29 @@ TEST(QuantizedActivationsOpTest, Softmax4D) {
kQuantizedTolerance)));
}
+TEST(FloatActivationsOpTest, Softmax1D) {
+ FloatActivationsOpModel m(0.1,
+ /*input=*/{TensorType_FLOAT32, {8}});
+ m.SetInput({0, -6, 2, 4, 3, -2, 10, 1});
+ m.Invoke();
+ EXPECT_THAT(
+ m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {.09752, .05352, .11911, .14548, .13164, .07984, .26509, .10778})));
+}
+
+TEST(QuantizedActivationsOpTest, Softmax1D) {
+ QuantizedActivationsOpModel m(0.1,
+ /*input=*/{TensorType_UINT8, {8}, -10, 10});
+ m.SetInput<uint8_t>({0, -6, 2, 4, 3, -2, 10, 1});
+ m.Invoke();
+ EXPECT_THAT(
+ m.GetDequantizedOutput<uint8_t>(),
+ ElementsAreArray(ArrayFloatNear({0.09766, 0.05469, 0.12109, 0.14453,
+ 0.13281, 0.07813, 0.26563, 0.10938},
+ kQuantizedTolerance)));
+}
+
TEST(FloatActivationsOpTest, Softmax2D) {
FloatActivationsOpModel m(0.1,
/*input=*/{TensorType_FLOAT32, {2, 4}});
diff --git a/tensorflow/contrib/lite/kernels/add.cc b/tensorflow/contrib/lite/kernels/add.cc
index f44d531cbf..af9b5c7013 100644
--- a/tensorflow/contrib/lite/kernels/add.cc
+++ b/tensorflow/contrib/lite/kernels/add.cc
@@ -110,15 +110,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
QuantizeMultiplierSmallerThanOneExp(
real_input1_multiplier, &data->input1_multiplier, &data->input1_shift);
- data->input1_shift *= -1;
QuantizeMultiplierSmallerThanOneExp(
real_input2_multiplier, &data->input2_multiplier, &data->input2_shift);
- data->input2_shift *= -1;
QuantizeMultiplierSmallerThanOneExp(
real_output_multiplier, &data->output_multiplier, &data->output_shift);
- data->output_shift *= -1;
CalculateActivationRangeUint8(params->activation, output,
&data->output_activation_min,
@@ -152,14 +149,14 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
CheckedLog2(output->params.scale, &output_scale_log2_rounded);
TF_LITE_ENSURE(context, output_scale_is_pot);
- data->input1_shift = output_scale_log2_rounded - input1_scale_log2_rounded;
- data->input2_shift = output_scale_log2_rounded - input2_scale_log2_rounded;
+ data->input1_shift = input1_scale_log2_rounded - output_scale_log2_rounded;
+ data->input2_shift = input2_scale_log2_rounded - output_scale_log2_rounded;
// Shifting of one input is supported. The graph quantization should ensure
// that the other input matches the output.
TF_LITE_ENSURE(context, data->input1_shift == 0 || data->input2_shift == 0);
- TF_LITE_ENSURE(context, data->input1_shift >= 0);
- TF_LITE_ENSURE(context, data->input2_shift >= 0);
+ TF_LITE_ENSURE(context, data->input1_shift <= 0);
+ TF_LITE_ENSURE(context, data->input2_shift <= 0);
CalculateActivationRangeQuantized(context, params->activation, output,
&data->output_activation_min,
@@ -173,24 +170,27 @@ template <KernelType kernel_type>
void EvalAdd(TfLiteContext* context, TfLiteNode* node, TfLiteAddParams* params,
const OpData* data, const TfLiteTensor* input1,
const TfLiteTensor* input2, TfLiteTensor* output) {
-#define TF_LITE_ADD(type, opname, data_type) \
- data_type output_activation_min, output_activation_max; \
- CalculateActivationRange(params->activation, &output_activation_min, \
- &output_activation_max); \
- type::opname(GetTensorData<data_type>(input1), GetTensorDims(input1), \
- GetTensorData<data_type>(input2), GetTensorDims(input2), \
- output_activation_min, output_activation_max, \
- GetTensorData<data_type>(output), GetTensorDims(output))
+#define TF_LITE_ADD(type, opname, data_type) \
+ data_type output_activation_min, output_activation_max; \
+ CalculateActivationRange(params->activation, &output_activation_min, \
+ &output_activation_max); \
+ tflite::ArithmeticParams op_params; \
+ SetActivationParams(output_activation_min, output_activation_max, \
+ &op_params); \
+ type::opname(op_params, GetTensorShape(input1), \
+ GetTensorData<data_type>(input1), GetTensorShape(input2), \
+ GetTensorData<data_type>(input2), GetTensorShape(output), \
+ GetTensorData<data_type>(output))
if (output->type == kTfLiteInt32) {
if (kernel_type == kReference) {
if (data->requires_broadcast) {
- TF_LITE_ADD(reference_ops, BroadcastAdd, int32_t);
+ TF_LITE_ADD(reference_ops, BroadcastAdd4DSlow, int32_t);
} else {
TF_LITE_ADD(reference_ops, Add, int32_t);
}
} else {
if (data->requires_broadcast) {
- TF_LITE_ADD(optimized_ops, BroadcastAdd, int32_t);
+ TF_LITE_ADD(optimized_ops, BroadcastAdd4DSlow, int32_t);
} else {
TF_LITE_ADD(optimized_ops, Add, int32_t);
}
@@ -198,13 +198,13 @@ void EvalAdd(TfLiteContext* context, TfLiteNode* node, TfLiteAddParams* params,
} else if (output->type == kTfLiteFloat32) {
if (kernel_type == kReference) {
if (data->requires_broadcast) {
- TF_LITE_ADD(reference_ops, BroadcastAdd, float);
+ TF_LITE_ADD(reference_ops, BroadcastAdd4DSlow, float);
} else {
TF_LITE_ADD(reference_ops, Add, float);
}
} else {
if (data->requires_broadcast) {
- TF_LITE_ADD(optimized_ops, BroadcastAdd, float);
+ TF_LITE_ADD(optimized_ops, BroadcastAdd4DSlow, float);
} else {
TF_LITE_ADD(optimized_ops, Add, float);
}
@@ -220,30 +220,43 @@ TfLiteStatus EvalAddQuantized(TfLiteContext* context, TfLiteNode* node,
const TfLiteTensor* input2,
TfLiteTensor* output) {
if (output->type == kTfLiteUInt8) {
-#define TF_LITE_ADD(type, opname) \
- type::opname( \
- data->left_shift, GetTensorData<uint8_t>(input1), GetTensorDims(input1), \
- data->input1_offset, data->input1_multiplier, data->input1_shift, \
- GetTensorData<uint8_t>(input2), GetTensorDims(input2), \
- data->input2_offset, data->input2_multiplier, data->input2_shift, \
- data->output_offset, data->output_multiplier, data->output_shift, \
- data->output_activation_min, data->output_activation_max, \
- GetTensorData<uint8_t>(output), GetTensorDims(output));
+#define TF_LITE_ADD(type, opname) \
+ tflite::ArithmeticParams op_params; \
+ op_params.left_shift = data->left_shift; \
+ op_params.input1_offset = data->input1_offset; \
+ op_params.input1_multiplier = data->input1_multiplier; \
+ op_params.input1_shift = data->input1_shift; \
+ op_params.input2_offset = data->input2_offset; \
+ op_params.input2_multiplier = data->input2_multiplier; \
+ op_params.input2_shift = data->input2_shift; \
+ op_params.output_offset = data->output_offset; \
+ op_params.output_multiplier = data->output_multiplier; \
+ op_params.output_shift = data->output_shift; \
+ SetActivationParams(data->output_activation_min, \
+ data->output_activation_max, &op_params); \
+ type::opname(op_params, GetTensorShape(input1), \
+ GetTensorData<uint8_t>(input1), GetTensorShape(input2), \
+ GetTensorData<uint8_t>(input2), GetTensorShape(output), \
+ GetTensorData<uint8_t>(output))
// The quantized version of Add doesn't support activations, so we
// always use BroadcastAdd.
if (kernel_type == kReference) {
- TF_LITE_ADD(reference_ops, BroadcastAdd);
+ TF_LITE_ADD(reference_ops, BroadcastAdd4DSlow);
} else {
- TF_LITE_ADD(optimized_ops, BroadcastAdd);
+ TF_LITE_ADD(optimized_ops, BroadcastAdd4DSlow);
}
#undef TF_LITE_ADD
} else if (output->type == kTfLiteInt16) {
-#define TF_LITE_ADD(type, opname) \
- type::opname(GetTensorData<int16_t>(input1), GetTensorDims(input1), \
- data->input1_shift, GetTensorData<int16_t>(input2), \
- GetTensorDims(input2), data->input2_shift, \
- data->output_activation_min, data->output_activation_max, \
- GetTensorData<int16_t>(output), GetTensorDims(output));
+#define TF_LITE_ADD(type, opname) \
+ tflite::ArithmeticParams op_params; \
+ op_params.input1_shift = data->input1_shift; \
+ op_params.input2_shift = data->input2_shift; \
+ SetActivationParams(data->output_activation_min, \
+ data->output_activation_max, &op_params); \
+ type::opname(op_params, GetTensorShape(input1), \
+ GetTensorData<int16_t>(input1), GetTensorShape(input2), \
+ GetTensorData<int16_t>(input2), GetTensorShape(output), \
+ GetTensorData<int16_t>(output))
// The quantized version of Add doesn't support activations, so we
// always use BroadcastAdd.
if (kernel_type == kReference) {
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
index 3425288f02..a11a59aa05 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <unistd.h>
#include <cassert>
#include <cmath>
#include <cstdio>
@@ -276,27 +275,33 @@ TfLiteStatus CheckLstmTensorDimensions(
TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TfLiteNode* node, int n_input,
int n_output, int n_cell) {
- CheckLstmTensorDimensions(
- context, node, n_input, n_output, n_cell, kFwInputToInputWeightsTensor,
- kFwInputToForgetWeightsTensor, kFwInputToCellWeightsTensor,
- kFwInputToOutputWeightsTensor, kFwRecurrentToInputWeightsTensor,
- kFwRecurrentToForgetWeightsTensor, kFwRecurrentToCellWeightsTensor,
- kFwRecurrentToOutputWeightsTensor, kFwCellToInputWeightsTensor,
- kFwCellToForgetWeightsTensor, kFwCellToOutputWeightsTensor,
- kFwInputGateBiasTensor, kFwForgetGateBiasTensor, kFwCellGateBiasTensor,
- kFwOutputGateBiasTensor, kFwProjectionWeightsTensor,
- kFwProjectionBiasTensor);
-
- CheckLstmTensorDimensions(
- context, node, n_input, n_output, n_cell, kBwInputToInputWeightsTensor,
- kBwInputToForgetWeightsTensor, kBwInputToCellWeightsTensor,
- kBwInputToOutputWeightsTensor, kBwRecurrentToInputWeightsTensor,
- kBwRecurrentToForgetWeightsTensor, kBwRecurrentToCellWeightsTensor,
- kBwRecurrentToOutputWeightsTensor, kBwCellToInputWeightsTensor,
- kBwCellToForgetWeightsTensor, kBwCellToOutputWeightsTensor,
- kBwInputGateBiasTensor, kBwForgetGateBiasTensor, kBwCellGateBiasTensor,
- kBwOutputGateBiasTensor, kBwProjectionWeightsTensor,
- kBwProjectionBiasTensor);
+ TF_LITE_ENSURE_OK(
+ context,
+ CheckLstmTensorDimensions(
+ context, node, n_input, n_output, n_cell,
+ kFwInputToInputWeightsTensor, kFwInputToForgetWeightsTensor,
+ kFwInputToCellWeightsTensor, kFwInputToOutputWeightsTensor,
+ kFwRecurrentToInputWeightsTensor, kFwRecurrentToForgetWeightsTensor,
+ kFwRecurrentToCellWeightsTensor, kFwRecurrentToOutputWeightsTensor,
+ kFwCellToInputWeightsTensor, kFwCellToForgetWeightsTensor,
+ kFwCellToOutputWeightsTensor, kFwInputGateBiasTensor,
+ kFwForgetGateBiasTensor, kFwCellGateBiasTensor,
+ kFwOutputGateBiasTensor, kFwProjectionWeightsTensor,
+ kFwProjectionBiasTensor));
+
+ TF_LITE_ENSURE_OK(
+ context,
+ CheckLstmTensorDimensions(
+ context, node, n_input, n_output, n_cell,
+ kBwInputToInputWeightsTensor, kBwInputToForgetWeightsTensor,
+ kBwInputToCellWeightsTensor, kBwInputToOutputWeightsTensor,
+ kBwRecurrentToInputWeightsTensor, kBwRecurrentToForgetWeightsTensor,
+ kBwRecurrentToCellWeightsTensor, kBwRecurrentToOutputWeightsTensor,
+ kBwCellToInputWeightsTensor, kBwCellToForgetWeightsTensor,
+ kBwCellToOutputWeightsTensor, kBwInputGateBiasTensor,
+ kBwForgetGateBiasTensor, kBwCellGateBiasTensor,
+ kBwOutputGateBiasTensor, kBwProjectionWeightsTensor,
+ kBwProjectionBiasTensor));
// Check if Forward and Backward tensors match along required dimensions.
return kTfLiteOk;
@@ -334,7 +339,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const int n_fw_output = fw_recurrent_to_output_weights->dims->data[1];
// Check that input tensor dimensions matches with each other.
- CheckInputTensorDimensions(context, node, n_input, n_fw_output, n_fw_cell);
+ TF_LITE_ENSURE_OK(
+ context, CheckInputTensorDimensions(context, node, n_input, n_fw_output,
+ n_fw_cell));
// Get the pointer to output, state and scratch buffer tensors.
TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
@@ -404,7 +411,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const int n_bw_output = bw_recurrent_to_output_weights->dims->data[1];
// Check that input tensor dimensions matches with each other.
- CheckInputTensorDimensions(context, node, n_input, n_bw_output, n_bw_cell);
+ TF_LITE_ENSURE_OK(
+ context, CheckInputTensorDimensions(context, node, n_input, n_bw_output,
+ n_bw_cell));
// Get the pointer to output, output_state and cell_state buffer tensors.
TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
index aa24c1f34c..517309a226 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <unistd.h>
#include <cassert>
#include <cmath>
#include <cstdlib>
diff --git a/tensorflow/contrib/lite/kernels/concatenation.cc b/tensorflow/contrib/lite/kernels/concatenation.cc
index 45ea8d0049..ad211e9c67 100644
--- a/tensorflow/contrib/lite/kernels/concatenation.cc
+++ b/tensorflow/contrib/lite/kernels/concatenation.cc
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <unistd.h>
#include <cassert>
#include <cmath>
#include <cstdio>
diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc
index a4fe9e5550..6f174763df 100644
--- a/tensorflow/contrib/lite/kernels/conv.cc
+++ b/tensorflow/contrib/lite/kernels/conv.cc
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <unistd.h>
#include <algorithm>
#include <cassert>
#include <cmath>
diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv.cc b/tensorflow/contrib/lite/kernels/depthwise_conv.cc
index 16e5f1d065..21518156b8 100644
--- a/tensorflow/contrib/lite/kernels/depthwise_conv.cc
+++ b/tensorflow/contrib/lite/kernels/depthwise_conv.cc
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <unistd.h>
#include <cassert>
#include <cmath>
#include <cstdio>
diff --git a/tensorflow/contrib/lite/kernels/div.cc b/tensorflow/contrib/lite/kernels/div.cc
index bc5c3783fd..d7420ddd8e 100644
--- a/tensorflow/contrib/lite/kernels/div.cc
+++ b/tensorflow/contrib/lite/kernels/div.cc
@@ -78,29 +78,44 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
template <KernelType kernel_type>
-void EvalFloat(TfLiteContext* context, TfLiteNode* node,
- TfLiteDivParams* params, const OpData* data,
- const TfLiteTensor* input1, const TfLiteTensor* input2,
- TfLiteTensor* output) {
- float output_activation_min, output_activation_max;
- CalculateActivationRange(params->activation, &output_activation_min,
- &output_activation_max);
-#define TF_LITE_DIV(type, opname) \
- type::opname(GetTensorData<float>(input1), GetTensorDims(input1), \
- GetTensorData<float>(input2), GetTensorDims(input2), \
- output_activation_min, output_activation_max, \
- GetTensorData<float>(output), GetTensorDims(output))
- if (kernel_type == kReference) {
- if (data->requires_broadcast) {
- TF_LITE_DIV(reference_ops, BroadcastDiv);
+void EvalDiv(TfLiteContext* context, TfLiteNode* node, TfLiteDivParams* params,
+ const OpData* data, const TfLiteTensor* input1,
+ const TfLiteTensor* input2, TfLiteTensor* output) {
+#define TF_LITE_DIV(type, opname, data_type) \
+ data_type output_activation_min, output_activation_max; \
+ CalculateActivationRange(params->activation, &output_activation_min, \
+ &output_activation_max); \
+ type::opname(GetTensorData<data_type>(input1), GetTensorDims(input1), \
+ GetTensorData<data_type>(input2), GetTensorDims(input2), \
+ output_activation_min, output_activation_max, \
+ GetTensorData<data_type>(output), GetTensorDims(output))
+ if (output->type == kTfLiteInt32) {
+ if (kernel_type == kReference) {
+ if (data->requires_broadcast) {
+ TF_LITE_DIV(reference_ops, BroadcastDiv, int32_t);
+ } else {
+ TF_LITE_DIV(reference_ops, Div, int32_t);
+ }
} else {
- TF_LITE_DIV(reference_ops, Div);
+ if (data->requires_broadcast) {
+ TF_LITE_DIV(optimized_ops, BroadcastDiv, int32_t);
+ } else {
+ TF_LITE_DIV(optimized_ops, Div, int32_t);
+ }
}
- } else {
- if (data->requires_broadcast) {
- TF_LITE_DIV(optimized_ops, BroadcastDiv);
+ } else if (output->type == kTfLiteFloat32) {
+ if (kernel_type == kReference) {
+ if (data->requires_broadcast) {
+ TF_LITE_DIV(reference_ops, BroadcastDiv, float);
+ } else {
+ TF_LITE_DIV(reference_ops, Div, float);
+ }
} else {
- TF_LITE_DIV(optimized_ops, Div);
+ if (data->requires_broadcast) {
+ TF_LITE_DIV(optimized_ops, BroadcastDiv, float);
+ } else {
+ TF_LITE_DIV(optimized_ops, Div, float);
+ }
}
}
#undef TF_LITE_DIV
@@ -115,11 +130,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- if (output->type == kTfLiteFloat32) {
- EvalFloat<kernel_type>(context, node, params, data, input1, input2, output);
+ if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) {
+ EvalDiv<kernel_type>(context, node, params, data, input1, input2, output);
} else {
context->ReportError(
- context, "Div only supports FLOAT32 and quantized UINT8 now, got %d.",
+ context,
+ "Div only supports FLOAT32, INT32 and quantized UINT8 now, got %d.",
output->type);
return kTfLiteError;
}
diff --git a/tensorflow/contrib/lite/kernels/div_test.cc b/tensorflow/contrib/lite/kernels/div_test.cc
index 276b8289fb..97aa2fe04e 100644
--- a/tensorflow/contrib/lite/kernels/div_test.cc
+++ b/tensorflow/contrib/lite/kernels/div_test.cc
@@ -52,6 +52,13 @@ class FloatDivOpModel : public BaseDivOpModel {
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
};
+class IntegerDivOpModel : public BaseDivOpModel {
+ public:
+ using BaseDivOpModel::BaseDivOpModel;
+
+ std::vector<int32_t> GetOutput() { return ExtractVector<int32_t>(output_); }
+};
+
TEST(FloatDivOpTest, NoActivation) {
FloatDivOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}},
{TensorType_FLOAT32, {1, 2, 2, 1}},
@@ -75,7 +82,7 @@ TEST(FloatDivOpTest, ActivationRELU_N1_TO_1) {
}
TEST(FloatDivOpTest, VariousInputShapes) {
- std::vector<std::initializer_list<int>> test_shapes = {
+ std::vector<std::vector<int>> test_shapes = {
{6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
for (int i = 0; i < test_shapes.size(); ++i) {
FloatDivOpModel m({TensorType_FLOAT32, test_shapes[i]},
@@ -92,7 +99,7 @@ TEST(FloatDivOpTest, VariousInputShapes) {
}
TEST(FloatDivOpTest, WithBroadcast) {
- std::vector<std::initializer_list<int>> test_shapes = {
+ std::vector<std::vector<int>> test_shapes = {
{6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
for (int i = 0; i < test_shapes.size(); ++i) {
FloatDivOpModel m({TensorType_FLOAT32, test_shapes[i]},
@@ -108,6 +115,56 @@ TEST(FloatDivOpTest, WithBroadcast) {
}
}
+TEST(IntegerDivOpTest, NoActivation) {
+ IntegerDivOpModel m({TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {1, 2, 2, 1}}, {TensorType_INT32, {}},
+ ActivationFunctionType_NONE);
+ m.PopulateTensor<int32_t>(m.input1(), {-2, 2, -15, 8});
+ m.PopulateTensor<int32_t>(m.input2(), {5, -2, -3, 5});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, -1, 5, 1}));
+}
+
+TEST(IntegerDivOpTest, ActivationRELU_N1_TO_1) {
+ IntegerDivOpModel m({TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {1, 2, 2, 1}}, {TensorType_INT32, {}},
+ ActivationFunctionType_RELU_N1_TO_1);
+ m.PopulateTensor<int32_t>(m.input1(), {-2, 2, -12, 8});
+ m.PopulateTensor<int32_t>(m.input2(), {1, 2, -15, 5});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1, 1, 0, 1}));
+}
+
+TEST(IntegerDivOpTest, VariousInputShapes) {
+ std::vector<std::vector<int>> test_shapes = {
+ {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
+ for (int i = 0; i < test_shapes.size(); ++i) {
+ IntegerDivOpModel m({TensorType_INT32, test_shapes[i]},
+ {TensorType_INT32, test_shapes[i]},
+ {TensorType_INT32, {}}, ActivationFunctionType_NONE);
+ m.PopulateTensor<int32_t>(m.input1(), {-20, 2, 3, 8, 11, -20});
+ m.PopulateTensor<int32_t>(m.input2(), {1, 2, 6, 5, -11, -1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({-20, 1, 0, 1, -1, 20}))
+ << "With shape number " << i;
+ }
+}
+
+TEST(IntegerDivOpTest, WithBroadcast) {
+ std::vector<std::vector<int>> test_shapes = {
+ {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
+ for (int i = 0; i < test_shapes.size(); ++i) {
+ IntegerDivOpModel m({TensorType_INT32, test_shapes[i]},
+ {TensorType_INT32, {}}, // always a scalar
+ {TensorType_INT32, {}}, ActivationFunctionType_NONE);
+ m.PopulateTensor<int32_t>(m.input1(), {-20, 21, 7, 8, 11, -123});
+ m.PopulateTensor<int32_t>(m.input2(), {3});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({-6, 7, 2, 2, 3, -41}))
+ << "With shape number " << i;
+ }
+}
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/eigen_support.cc b/tensorflow/contrib/lite/kernels/eigen_support.cc
index 4f0d020793..e542ad0765 100644
--- a/tensorflow/contrib/lite/kernels/eigen_support.cc
+++ b/tensorflow/contrib/lite/kernels/eigen_support.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include <utility>
+#include "tensorflow/contrib/lite/arena_planner.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
@@ -23,6 +24,16 @@ namespace tflite {
namespace eigen_support {
namespace {
+#ifndef EIGEN_DONT_ALIGN
+// Eigen may require buffers to be aligned to 16, 32 or 64 bytes depending on
+// the hardware architecture and build configuration.
+// If the static assertion fails, try increasing `kDefaultTensorAlignment`
+// in `arena_planner.h` to 32 or 64.
+static_assert(
+ kDefaultTensorAlignment % EIGEN_MAX_ALIGN_BYTES == 0,
+ "kDefaultArenaAlignment doesn't comply with Eigen alignment requirement.");
+#endif // EIGEN_DONT_ALIGN
+
// We have a single global threadpool for all convolution operations. This means
// that inferences started from different threads may block each other, but
// since the underlying resource of CPU cores should be consumed by the
diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup.cc b/tensorflow/contrib/lite/kernels/embedding_lookup.cc
index 0ba170a4da..b2dff87e62 100644
--- a/tensorflow/contrib/lite/kernels/embedding_lookup.cc
+++ b/tensorflow/contrib/lite/kernels/embedding_lookup.cc
@@ -29,7 +29,6 @@ limitations under the License.
// When indices are out of bound, the ops will not succeed.
//
-#include <unistd.h>
#include <cassert>
#include <cmath>
#include <cstdio>
@@ -112,8 +111,9 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
// TODO(alanchiao): refactor scalar multiply into separate function
// for ease of adding a neon equivalent if ever necessary.
for (int j = 0; j < col_size; j++) {
+ const int8_t* value_ptr = reinterpret_cast<int8_t*>(value->data.uint8);
output->data.f[j + i * col_size] =
- value->data.uint8[j + idx * col_size] * scaling_factor;
+ value_ptr[j + idx * col_size] * scaling_factor;
}
}
}
diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc b/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc
index 04657fd863..4a88d168c6 100644
--- a/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc
+++ b/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc
@@ -107,9 +107,9 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple2DTest) {
HybridEmbeddingLookupOpModel m({3}, {3, 8});
m.SetInput({1, 0, 2});
m.SetWeight({
- 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
- 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
- 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
+ 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
+ 1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
+ 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
});
m.Invoke();
@@ -117,9 +117,9 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple2DTest) {
EXPECT_THAT(m.GetOutput(),
ElementsAreArray(ArrayFloatNear(
{
- 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
- 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
- 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
+ 1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
+ 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
+ 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
},
7.41e-03)));
}
@@ -128,9 +128,9 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple3DTest) {
HybridEmbeddingLookupOpModel m({3}, {3, 2, 4});
m.SetInput({1, 0, 2});
m.SetWeight({
- 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
- 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
- 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
+ 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
+ 1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
+ 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
});
m.Invoke();
@@ -138,9 +138,9 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple3DTest) {
EXPECT_THAT(m.GetOutput(),
ElementsAreArray(ArrayFloatNear(
{
- 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
- 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
- 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
+ 1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
+ 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
+ 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
},
7.41e-03)));
}
@@ -149,9 +149,9 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple4DTest) {
HybridEmbeddingLookupOpModel m({3}, {3, 2, 2, 2});
m.SetInput({1, 0, 2});
m.SetWeight({
- 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
- 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
- 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
+ 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
+ 1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
+ 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
});
m.Invoke();
@@ -159,9 +159,9 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple4DTest) {
EXPECT_THAT(m.GetOutput(),
ElementsAreArray(ArrayFloatNear(
{
- 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
- 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
- 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
+ 1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
+ 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
+ 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
},
7.41e-03)));
}
diff --git a/tensorflow/contrib/lite/kernels/fake_quant.cc b/tensorflow/contrib/lite/kernels/fake_quant.cc
index f8927a0799..0ef1a50b30 100644
--- a/tensorflow/contrib/lite/kernels/fake_quant.cc
+++ b/tensorflow/contrib/lite/kernels/fake_quant.cc
@@ -44,6 +44,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+ const auto* params =
+ reinterpret_cast<TfLiteFakeQuantParams*>(node->builtin_data);
+
+ if (params->narrow_range) {
+ context->ReportError(
+ context,
+ "narrow_range FakeQuant is not currently supported at runtime. "
+ "narrow_range is only meant to be applied to weights, not activations");
+ return kTfLiteError;
+ }
+
OpContext op_context(context, node);
TfLiteIntArray* output_dims = TfLiteIntArrayCopy(op_context.input->dims);
op_context.output->type = op_context.input->type;
diff --git a/tensorflow/contrib/lite/kernels/fully_connected.cc b/tensorflow/contrib/lite/kernels/fully_connected.cc
index 6c9a845bd1..bc370608c0 100644
--- a/tensorflow/contrib/lite/kernels/fully_connected.cc
+++ b/tensorflow/contrib/lite/kernels/fully_connected.cc
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <unistd.h>
#include <cassert>
#include <cmath>
#include <cstdio>
@@ -71,7 +70,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
// Instead, we allocate a new object to carry information from Prepare() to
// Eval().
gemm_support::IncrementUsageCounter(context);
- auto* op_data = new OpData;
+ auto* op_data = new OpData();
context->AddTensors(context, 1, &op_data->input_quantized_index);
return op_data;
}
@@ -152,10 +151,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
}
- // Resize output to the same as input (except the last dimension which is
- // determined by the number of units).
- TfLiteIntArray* output_size_array = TfLiteIntArrayCopy(input->dims);
- output_size_array->data[input->dims->size - 1] = num_units;
+ // Resize output.
+ TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2);
+ output_size_array->data[0] = batch_size;
+ output_size_array->data[1] = num_units;
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, output, output_size_array));
return kTfLiteOk;
diff --git a/tensorflow/contrib/lite/kernels/fully_connected_test.cc b/tensorflow/contrib/lite/kernels/fully_connected_test.cc
index a6b6b2f497..ec94905697 100644
--- a/tensorflow/contrib/lite/kernels/fully_connected_test.cc
+++ b/tensorflow/contrib/lite/kernels/fully_connected_test.cc
@@ -207,7 +207,6 @@ class FloatFullyConnectedOpModel : public BaseFullyConnectedOpModel {
}
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
- std::vector<int> GetOutputSize() { return GetTensorShape(output_); }
};
class QuantizedFullyConnectedOpModel : public BaseFullyConnectedOpModel {
@@ -299,7 +298,6 @@ class HybridFullyConnectedOpModel : public SingleOpModel {
void SetInput(const std::vector<float>& f) { PopulateTensor(input_, f); }
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
- std::vector<int> GetOutputSize() { return GetTensorShape(output_); }
int input_size() { return input_size_; }
int num_units() { return units_; }
@@ -374,7 +372,6 @@ TEST_P(FloatFullyConnectedOpTest, SimpleTest) {
m.Invoke();
- EXPECT_THAT(m.GetOutputSize(), ElementsAre(2, 3));
EXPECT_THAT(m.GetOutput(), ElementsAre(24, 25, 26, 58, 59, 60));
}
@@ -393,7 +390,6 @@ TEST_P(FloatFullyConnectedOpTest, SimpleTest2) {
m.Invoke();
- EXPECT_THAT(m.GetOutputSize(), ElementsAre(2, 1));
EXPECT_THAT(m.GetOutput(), ElementsAre(11, 9));
}
@@ -580,10 +576,11 @@ TEST(HybridFullyConnectedOpTest, SimpleTestQuantized) {
TEST_P(FloatFullyConnectedOpTest, SimpleTest4DInput) {
// Note that it is not required that the first dimension be the number of
- // batches. All we care is that the input size is the last dimension.
+ // batches. All we care about is that the input can be evenly distributed in
+ // batches. In this case, we need the input to have multiples of '2'.
FloatFullyConnectedOpModel m(GetRegistration(),
/*units=*/3, /*batches=*/2,
- /*input=*/{TensorType_FLOAT32, {1, 2, 1, 10}});
+ /*input=*/{TensorType_FLOAT32, {4, 1, 5, 1}});
m.SetWeights({
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1
@@ -598,7 +595,6 @@ TEST_P(FloatFullyConnectedOpTest, SimpleTest4DInput) {
m.Invoke();
- EXPECT_THAT(m.GetOutputSize(), ElementsAre(1, 2, 1, 3));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({
24, 25, 26, // first batch
58, 59, 60, // second batch
@@ -608,7 +604,7 @@ TEST_P(FloatFullyConnectedOpTest, SimpleTest4DInput) {
TEST_P(QuantizedFullyConnectedOpTest, SimpleTest4dInputQuantized) {
QuantizedFullyConnectedOpModel m(
GetRegistration(), /*units=*/3, /*batches=*/2,
- /*input=*/{TensorType_UINT8, {1, 2, 1, 10}, -63.5, 64},
+ /*input=*/{TensorType_UINT8, {4, 1, 5, 1}, -63.5, 64},
/*output=*/{TensorType_UINT8, {}, -127, 128});
// input_product_scale < output_scale was not true.
diff --git a/tensorflow/contrib/lite/kernels/hashtable_lookup.cc b/tensorflow/contrib/lite/kernels/hashtable_lookup.cc
index 41211d41aa..f37c66acb3 100644
--- a/tensorflow/contrib/lite/kernels/hashtable_lookup.cc
+++ b/tensorflow/contrib/lite/kernels/hashtable_lookup.cc
@@ -31,7 +31,6 @@ limitations under the License.
// Each item indicates whether the corresponding lookup has a returned value.
// 0 for missing key, 1 for found key.
-#include <unistd.h>
#include <cassert>
#include <cmath>
#include <cstdio>
diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD
index 7962fcbc9d..3a855fe3dd 100644
--- a/tensorflow/contrib/lite/kernels/internal/BUILD
+++ b/tensorflow/contrib/lite/kernels/internal/BUILD
@@ -232,6 +232,7 @@ cc_library(
cc_test(
name = "tensor_test",
srcs = ["tensor_test.cc"],
+ tags = ["no_oss"],
deps = [
":reference",
"@com_google_googletest//:gtest",
@@ -260,6 +261,7 @@ cc_library(
cc_test(
name = "quantization_util_test",
srcs = ["quantization_util_test.cc"],
+ tags = ["no_oss"],
deps = [
":quantization_util",
"@com_google_googletest//:gtest",
@@ -505,7 +507,10 @@ cc_test(
"//conditions:default": [],
}),
linkstatic = 1,
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":tensor_utils",
"//tensorflow/contrib/lite:builtin_op_data",
@@ -517,6 +522,7 @@ cc_test(
cc_test(
name = "depthwiseconv_float_test",
srcs = ["depthwiseconv_float_test.cc"],
+ tags = ["no_oss"],
deps = [
":optimized_base",
":reference_base",
@@ -529,6 +535,7 @@ cc_test(
cc_test(
name = "depthwiseconv_quantized_test",
srcs = ["depthwiseconv_quantized_test.cc"],
+ tags = ["no_oss"],
deps = [
":optimized_base",
":reference_base",
@@ -541,7 +548,10 @@ cc_test(
cc_test(
name = "resize_bilinear_test",
srcs = ["resize_bilinear_test.cc"],
- tags = ["tflite_not_portable"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable",
+ ],
deps = [
":optimized_base",
":reference_base",
@@ -557,6 +567,7 @@ cc_test(
srcs = [
"softmax_quantized_test.cc",
],
+ tags = ["no_oss"],
deps = [
":optimized_base",
":quantization_util",
@@ -572,7 +583,10 @@ cc_test(
srcs = [
"logsoftmax_quantized_test.cc",
],
- tags = ["tflite_not_portable"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable",
+ ],
deps = [
":optimized_base",
":quantization_util",
@@ -585,6 +599,7 @@ cc_test(
cc_test(
name = "log_quantized_test",
srcs = ["log_quantized_test.cc"],
+ tags = ["no_oss"],
deps = [
":optimized_base",
":reference_base",
@@ -611,6 +626,7 @@ cc_library(
cc_test(
name = "batch_to_space_nd_test",
srcs = ["batch_to_space_nd_test.cc"],
+ tags = ["no_oss"],
deps = [
":optimized_base",
"@com_google_googletest//:gtest_main",
diff --git a/tensorflow/contrib/lite/kernels/internal/common.h b/tensorflow/contrib/lite/kernels/internal/common.h
index b86ca49c11..310a8980e6 100644
--- a/tensorflow/contrib/lite/kernels/internal/common.h
+++ b/tensorflow/contrib/lite/kernels/internal/common.h
@@ -127,6 +127,139 @@ int CountLeadingZeros(T integer_input) {
return leading_zeros;
}
+// DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING
+// BROADCASTING.
+//
+// NdArrayDesc<N> describes the shape and memory layout of an N-dimensional
+// rectangular array of numbers.
+//
+// NdArrayDesc<N> is basically identical to Dims<N> defined in types.h.
+// However, as Dims<N> is to be deprecated, this class exists as an adaptor
+// to enable simple unoptimized implementations of element-wise broadcasting
+// operations.
+template <int N>
+struct NdArrayDesc {
+ // The "extent" of each dimension. Indices along dimension d must be in the
+ // half-open interval [0, extents[d]).
+ int extents[N];
+
+ // The number of *elements* (not bytes) between consecutive indices of each
+ // dimension.
+ int strides[N];
+};
+
+// DO NOT USE THIS FUNCTION FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING
+// BROADCASTING.
+//
+// Same as Offset(), except takes as NdArrayDesc<N> instead of Dims<N>.
+inline int SubscriptToIndex(const NdArrayDesc<4>& desc, int i0, int i1, int i2,
+ int i3) {
+ TFLITE_DCHECK(i0 >= 0 && i0 < desc.extents[0]);
+ TFLITE_DCHECK(i1 >= 0 && i1 < desc.extents[1]);
+ TFLITE_DCHECK(i2 >= 0 && i2 < desc.extents[2]);
+ TFLITE_DCHECK(i3 >= 0 && i3 < desc.extents[3]);
+ return i0 * desc.strides[0] + i1 * desc.strides[1] + i2 * desc.strides[2] +
+ i3 * desc.strides[3];
+}
+
+// Given the dimensions of the operands for an element-wise binary broadcast,
+// adjusts them so that they can be directly iterated over with simple loops.
+// Returns the adjusted dims as instances of NdArrayDesc in 'desc0_out' and
+// 'desc1_out'. 'desc0_out' and 'desc1_out' cannot be nullptr.
+//
+// This function assumes that the two input shapes are compatible up to
+// broadcasting and the shorter one has already been prepended with 1s to be the
+// same length. E.g., if shape0 is (1, 16, 16, 64) and shape1 is (1, 64),
+// shape1 must already have been prepended to be (1, 1, 1, 64). Recall that
+// Dims<N> refers to shapes in reverse order. In this case, input0_dims will be
+// (64, 16, 16, 1) and input1_dims will be (64, 1, 1, 1).
+//
+// When two shapes are compatible up to broadcasting, for each dimension d,
+// the input extents are either equal, or one of them is 1.
+//
+// This function performs the following for each dimension d:
+// - If the extents are equal, then do nothing since the loop that walks over
+// both of the input arrays is correct.
+// - Otherwise, one (and only one) of the extents must be 1. Say extent0 is 1
+// and extent1 is e1. Then set extent0 to e1 and stride0 *to 0*. This allows
+// array0 to be referenced *at any index* in dimension d and still access the
+// same slice.
+template <int N>
+inline void NdArrayDescsForElementwiseBroadcast(const Dims<N>& input0_dims,
+ const Dims<N>& input1_dims,
+ NdArrayDesc<N>* desc0_out,
+ NdArrayDesc<N>* desc1_out) {
+ TFLITE_DCHECK(desc0_out != nullptr);
+ TFLITE_DCHECK(desc1_out != nullptr);
+
+ // Copy dims to desc.
+ for (int i = 0; i < N; ++i) {
+ desc0_out->extents[i] = input0_dims.sizes[i];
+ desc0_out->strides[i] = input0_dims.strides[i];
+ desc1_out->extents[i] = input1_dims.sizes[i];
+ desc1_out->strides[i] = input1_dims.strides[i];
+ }
+
+ // Walk over each dimension. If the extents are equal do nothing.
+ // Otherwise, set the desc with extent 1 to have extent equal to the other and
+ // stride 0.
+ for (int i = 0; i < N; ++i) {
+ const int extent0 = ArraySize(input0_dims, i);
+ const int extent1 = ArraySize(input1_dims, i);
+ if (extent0 != extent1) {
+ if (extent0 == 1) {
+ desc0_out->strides[i] = 0;
+ desc0_out->extents[i] = extent1;
+ } else {
+ TFLITE_DCHECK_EQ(extent1, 1);
+ desc1_out->strides[i] = 0;
+ desc1_out->extents[i] = extent0;
+ }
+ }
+ }
+}
+
+template <int N>
+inline void NdArrayDescsForElementwiseBroadcast(
+ const RuntimeShape& input0_shape, const RuntimeShape& input1_shape,
+ NdArrayDesc<N>* desc0_out, NdArrayDesc<N>* desc1_out) {
+ TFLITE_DCHECK(desc0_out != nullptr);
+ TFLITE_DCHECK(desc1_out != nullptr);
+
+ auto extended_input0_shape = RuntimeShape::ExtendedShape(N, input0_shape);
+ auto extended_input1_shape = RuntimeShape::ExtendedShape(N, input1_shape);
+
+ // Copy dims to desc, calculating strides.
+ int desc0_stride = 1;
+ int desc1_stride = 1;
+ for (int i = N - 1; i >= 0; --i) {
+ desc0_out->extents[i] = extended_input0_shape.Dims(i);
+ desc0_out->strides[i] = desc0_stride;
+ desc0_stride *= extended_input0_shape.Dims(i);
+ desc1_out->extents[i] = extended_input1_shape.Dims(i);
+ desc1_out->strides[i] = desc1_stride;
+ desc1_stride *= extended_input1_shape.Dims(i);
+ }
+
+ // Walk over each dimension. If the extents are equal do nothing.
+ // Otherwise, set the desc with extent 1 to have extent equal to the other and
+ // stride 0.
+ for (int i = 0; i < N; ++i) {
+ const int extent0 = extended_input0_shape.Dims(i);
+ const int extent1 = extended_input1_shape.Dims(i);
+ if (extent0 != extent1) {
+ if (extent0 == 1) {
+ desc0_out->strides[i] = 0;
+ desc0_out->extents[i] = extent1;
+ } else {
+ TFLITE_DCHECK_EQ(extent1, 1);
+ desc1_out->strides[i] = 0;
+ desc1_out->extents[i] = extent0;
+ }
+ }
+ }
+}
+
} // namespace tflite
#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMMON_H_
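The block comment above describes how these descriptors are meant to be consumed. As a minimal sketch (not part of this change, and using the Dims<4> helpers already available in this header), a naive element-wise broadcast add driven by the descriptors looks roughly like this:

// Sketch only: naive 4-D broadcast add using NdArrayDesc<4>.
// Assumes output_dims already describes the broadcast output shape.
template <typename T>
void NaiveBroadcastAddSketch(const T* input1_data, const Dims<4>& input1_dims,
                             const T* input2_data, const Dims<4>& input2_dims,
                             T* output_data, const Dims<4>& output_dims) {
  NdArrayDesc<4> desc1;
  NdArrayDesc<4> desc2;
  NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
  for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
    for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
      for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
        for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
          // Any dimension where one input had extent 1 now has stride 0, so
          // the same slice of that input is reused along that dimension.
          output_data[Offset(output_dims, c, x, y, b)] =
              input1_data[SubscriptToIndex(desc1, c, x, y, b)] +
              input2_data[SubscriptToIndex(desc2, c, x, y, b)];
        }
      }
    }
  }
}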
diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
index a0e382edb6..200f2f1515 100644
--- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
@@ -255,14 +255,6 @@ void LstmStep(
output_state_ptr);
}
-// TODO(alanchiao): move this to tensor_utils.
-void VectorMultiply(const int8_t* vector, const int v_size, const float scale,
- float* result) {
- for (int i = 0; i < v_size; ++i) {
- *result++ = scale * *vector++;
- }
-}
-
void LstmStep(
const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr,
float input_to_input_weights_scale,
@@ -415,8 +407,9 @@ void LstmStep(
// For each batch and cell: update input gate.
if (!use_cifg) {
if (use_peephole && !is_cell_state_all_zeros) {
- VectorMultiply(cell_to_input_weights_ptr, n_cell,
- cell_to_input_weights_scale, recovered_cell_weights);
+ tensor_utils::VectorScalarMultiply(cell_to_input_weights_ptr, n_cell,
+ cell_to_input_weights_scale,
+ recovered_cell_weights);
tensor_utils::VectorBatchVectorCwiseProductAccumulate(
recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
input_gate_scratch);
@@ -427,8 +420,9 @@ void LstmStep(
// For each batch and cell: update forget gate.
if (use_peephole && !is_cell_state_all_zeros) {
- VectorMultiply(cell_to_forget_weights_ptr, n_cell,
- cell_to_forget_weights_scale, recovered_cell_weights);
+ tensor_utils::VectorScalarMultiply(cell_to_forget_weights_ptr, n_cell,
+ cell_to_forget_weights_scale,
+ recovered_cell_weights);
tensor_utils::VectorBatchVectorCwiseProductAccumulate(
recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
forget_gate_scratch);
@@ -459,8 +453,9 @@ void LstmStep(
tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
// For each batch and cell: update the output gate.
if (use_peephole && !is_cell_state_all_zeros) {
- VectorMultiply(cell_to_output_weights_ptr, n_cell,
- cell_to_output_weights_scale, recovered_cell_weights);
+ tensor_utils::VectorScalarMultiply(cell_to_output_weights_ptr, n_cell,
+ cell_to_output_weights_scale,
+ recovered_cell_weights);
tensor_utils::VectorBatchVectorCwiseProductAccumulate(
recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
output_gate_scratch);
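Each of the three peephole updates above now follows the same two-step pattern. A hedged sketch of what one such pair of calls computes (the gate name is illustrative, and the scratch buffers are assumed to be preallocated as in LstmStep):

// recovered_cell_weights[i] = cell_to_gate_weights[i] * weights_scale
tensor_utils::VectorScalarMultiply(cell_to_gate_weights_ptr, n_cell,
                                   cell_to_gate_weights_scale,
                                   recovered_cell_weights);
// For each batch b and cell i:
//   gate_scratch[b * n_cell + i] +=
//       recovered_cell_weights[i] * cell_state[b * n_cell + i]
tensor_utils::VectorBatchVectorCwiseProductAccumulate(
    recovered_cell_weights, n_cell, cell_state_ptr, n_batch, gate_scratch);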
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h
index 6db41d7961..d5503073a7 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h
@@ -55,6 +55,245 @@ inline void Relu(const float* input_data, const Dims<4>& input_dims,
DimsToShape(output_dims));
}
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void Add(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+
+ tflite::ArithmeticParams op_params;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+ Add(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <FusedActivationFunctionType Ac>
+inline void Add(int left_shift, const uint8* input1_data,
+ const Dims<4>& input1_dims, int32 input1_offset,
+ int32 input1_multiplier, int input1_shift,
+ const uint8* input2_data, const Dims<4>& input2_dims,
+ int32 input2_offset, int32 input2_multiplier, int input2_shift,
+ int32 output_offset, int32 output_multiplier, int output_shift,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ constexpr int kReverseShift = -1;
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+
+ tflite::ArithmeticParams op_params;
+ op_params.left_shift = left_shift;
+ op_params.input1_offset = input1_offset;
+ op_params.input1_multiplier = input1_multiplier;
+ op_params.input1_shift = kReverseShift * input1_shift;
+ op_params.input2_offset = input2_offset;
+ op_params.input2_multiplier = input2_multiplier;
+ op_params.input2_shift = kReverseShift * input2_shift;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+ Add(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <FusedActivationFunctionType Ac>
+void Add(const int32* input1_data, const Dims<4>& input1_dims,
+ const int32* input2_data, const Dims<4>& input2_dims,
+ int32* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("Add/int32");
+ TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
+
+ tflite::ArithmeticParams op_params;
+ op_params.quantized_activation_min = std::numeric_limits<int32>::min();
+ op_params.quantized_activation_max = std::numeric_limits<int32>::max();
+ Add(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <typename T>
+void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T output_activation_min, T output_activation_max,
+ T* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+ BroadcastAdd4DSlow(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <FusedActivationFunctionType Ac>
+inline void BroadcastAdd(int left_shift, const uint8* input1_data,
+ const Dims<4>& input1_dims, int32 input1_offset,
+ int32 input1_multiplier, int input1_shift,
+ const uint8* input2_data, const Dims<4>& input2_dims,
+ int32 input2_offset, int32 input2_multiplier,
+ int input2_shift, int32 output_offset,
+ int32 output_multiplier, int output_shift,
+ int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ constexpr int kReverseShift = -1;
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+
+ tflite::ArithmeticParams op_params;
+ op_params.left_shift = left_shift;
+ op_params.input1_offset = input1_offset;
+ op_params.input1_multiplier = input1_multiplier;
+ op_params.input1_shift = kReverseShift * input1_shift;
+ op_params.input2_offset = input2_offset;
+ op_params.input2_multiplier = input2_multiplier;
+ op_params.input2_shift = kReverseShift * input2_shift;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+ BroadcastAdd4DSlow(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <FusedActivationFunctionType Ac>
+inline void BroadcastAddFivefold(
+ int y0, int y1, int y2, int y3, int y4, int left_shift,
+ const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset,
+ int32 input1_multiplier, int input1_shift, const uint8* input2_data,
+ const Dims<4>& input2_dims, int32 input2_offset, int32 input2_multiplier,
+ int input2_shift, int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ constexpr int kReverseShift = -1;
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ tflite::ArithmeticParams op_params;
+ op_params.broadcast_category =
+ tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;
+ op_params.left_shift = left_shift;
+ op_params.input1_offset = input1_offset;
+ op_params.input1_multiplier = input1_multiplier;
+ op_params.input1_shift = kReverseShift * input1_shift;
+ op_params.input2_offset = input2_offset;
+ op_params.input2_multiplier = input2_multiplier;
+ op_params.input2_shift = kReverseShift * input2_shift;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+ op_params.broadcast_shape[4] = y0;
+ op_params.broadcast_shape[3] = y1;
+ op_params.broadcast_shape[2] = y2;
+ op_params.broadcast_shape[1] = y3;
+ op_params.broadcast_shape[0] = y4;
+ BroadcastAddFivefold(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac, typename T>
+void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T* output_data, const Dims<4>& output_dims) {
+ T output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+
+ BroadcastAdd(input1_data, input1_dims, input2_data, input2_dims,
+ output_activation_min, output_activation_max, output_data,
+ output_dims);
+}
+
+template <FusedActivationFunctionType Ac>
+inline void Add(const int16* input1_data, const Dims<4>& input1_dims,
+ int input1_shift, const int16* input2_data,
+ const Dims<4>& input2_dims, int input2_shift,
+ int16 output_activation_min, int16 output_activation_max,
+ int16* output_data, const Dims<4>& output_dims) {
+ constexpr int kReverseShift = -1;
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, -32768);
+ TFLITE_DCHECK_EQ(output_activation_max, 32767);
+ }
+
+ tflite::ArithmeticParams op_params;
+ op_params.input1_shift = kReverseShift * input1_shift;
+ op_params.input2_shift = kReverseShift * input2_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+ Add(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+inline void Sub(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(FusedActivationFunctionType::kNone,
+ &output_activation_min, &output_activation_max);
+ tflite::ArithmeticParams op_params;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+ Sub(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <typename T>
+void Sub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data,
+ const Dims<4>& input2_dims, T* output_data,
+ const Dims<4>& output_dims) {
+ T output_activation_min, output_activation_max;
+ GetActivationMinMax(FusedActivationFunctionType::kNone,
+ &output_activation_min, &output_activation_max);
+ tflite::ArithmeticParams op_params;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+ Sub(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
inline void AveragePool(const float* input_data, const Dims<4>& input_dims,
int stride_width, int stride_height, int pad_width,
int pad_height, int kwidth, int kheight,
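One convention worth noting in these wrappers is the shift sign: the legacy signatures pass quantization shifts as positive right-shift amounts, while ArithmeticParams stores them negated (hence the kReverseShift = -1 factor above), and the rewritten kernels later in this change divide with RoundingDivideByPOT(x, -params.output_shift). A minimal illustration of the equivalence, with made-up values:

// Legacy call site: output_shift == 3 meant "round-divide the result by 2^3".
constexpr int kReverseShift = -1;
const int legacy_output_shift = 3;
tflite::ArithmeticParams op_params;
op_params.output_shift = kReverseShift * legacy_output_shift;  // stored as -3
// New-style kernel: RoundingDivideByPOT(x, -op_params.output_shift) again
// divides by 2^3, so both formulations apply the same rescaling.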
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
index 8c57c987d7..420bc68b43 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
@@ -342,6 +342,77 @@ void NeonClipVector(const float* vector, int v_size, float abs_limit,
}
}
+void NeonVectorScalarMultiply(const int8_t* vector, const int v_size,
+ const float scale, float* result) {
+ // Here the assumption is that each buffer is 4-byte aligned.
+ const int kWeightsPerUint32 = 4;
+ TFLITE_CHECK_EQ((intptr_t)(&vector[0]) & (kWeightsPerUint32 - 1), 0);
+ // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main
+ // vectorized loop and must process the remainder sequentially.
+ // postamble_start is the index at which that sequential handling begins.
+ const int kWeightsPerNeonLane = 16;
+ const int postamble_start = v_size - (v_size & (kWeightsPerNeonLane - 1));
+
+ // Create a vector of 4 floats with the scale value.
+ const float32x4_t scale_f32x4 = vdupq_n_f32(scale);
+ int v = 0;
+ for (; v < postamble_start; v += kWeightsPerNeonLane) {
+ // Load int8 values, sixteen at a time.
+ const int8x16_t v_i8x16 = vld1q_s8(vector + v);
+ // Split it into two components of size eight.
+ const int8x8_t v0_i8x8 = vget_low_s8(v_i8x16);
+ const int8x8_t v1_i8x8 = vget_high_s8(v_i8x16);
+ // Convert both components to int16 first.
+ const int16x8_t v0_i16x8 = vmovl_s8(v0_i8x8);
+ const int16x8_t v1_i16x8 = vmovl_s8(v1_i8x8);
+ // Split each of them into two components.
+ const int16x4_t v0_i16x4 = vget_low_s16(v0_i16x8);
+ const int16x4_t v1_i16x4 = vget_high_s16(v0_i16x8);
+ const int16x4_t v2_i16x4 = vget_low_s16(v1_i16x8);
+ const int16x4_t v3_i16x4 = vget_high_s16(v1_i16x8);
+ // Convert these to int32 and then to float.
+ float32x4_t v0_f32x4 = vcvtq_f32_s32(vmovl_s16(v0_i16x4));
+ float32x4_t v1_f32x4 = vcvtq_f32_s32(vmovl_s16(v1_i16x4));
+ float32x4_t v2_f32x4 = vcvtq_f32_s32(vmovl_s16(v2_i16x4));
+ float32x4_t v3_f32x4 = vcvtq_f32_s32(vmovl_s16(v3_i16x4));
+ // Vector multiply four floats at a time.
+ v0_f32x4 = vmulq_f32(v0_f32x4, scale_f32x4);
+ v1_f32x4 = vmulq_f32(v1_f32x4, scale_f32x4);
+ v2_f32x4 = vmulq_f32(v2_f32x4, scale_f32x4);
+ v3_f32x4 = vmulq_f32(v3_f32x4, scale_f32x4);
+ // Store the results.
+ vst1q_f32(result + v, v0_f32x4);
+ vst1q_f32(result + v + 4, v1_f32x4);
+ vst1q_f32(result + v + 8, v2_f32x4);
+ vst1q_f32(result + v + 12, v3_f32x4);
+ }
+
+ if (v_size - postamble_start >= (kWeightsPerNeonLane >> 1)) {
+ // Load eight int8 values if at least eight values remain.
+ const int8x8_t v_i8x8 = vld1_s8(vector + v);
+ // Convert them to int16 first.
+ const int16x8_t v_i16x8 = vmovl_s8(v_i8x8);
+ // Split it into two components.
+ const int16x4_t v0_i16x4 = vget_low_s16(v_i16x8);
+ const int16x4_t v1_i16x4 = vget_high_s16(v_i16x8);
+ // Convert the components to float.
+ float32x4_t v0_f32x4 = vcvtq_f32_s32(vmovl_s16(v0_i16x4));
+ float32x4_t v1_f32x4 = vcvtq_f32_s32(vmovl_s16(v1_i16x4));
+ // Vector multiply four floats at a time.
+ v0_f32x4 = vmulq_f32(v0_f32x4, scale_f32x4);
+ v1_f32x4 = vmulq_f32(v1_f32x4, scale_f32x4);
+ // Store the results.
+ vst1q_f32(result + v, v0_f32x4);
+ vst1q_f32(result + v + 4, v1_f32x4);
+ v += (kWeightsPerNeonLane >> 1);
+ }
+
+ // Postamble loop.
+ for (; v < v_size; v++) {
+ result[v] = scale * vector[v];
+ }
+}
+
void NeonSymmetricQuantizeFloats(const float* values, const int size,
int8_t* quantized_values, float* min,
float* max, float* scaling_factor) {
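The scalar postamble above doubles as the semantic definition of this kernel: every output element is simply scale * vector[i], widened to float. A portable reference sketch against which the NEON path can be checked (name and placement are illustrative, not the actual portable implementation):

// Portable reference for NeonVectorScalarMultiply; useful for unit testing.
void VectorScalarMultiplyReferenceSketch(const int8_t* vector, int v_size,
                                         float scale, float* result) {
  for (int i = 0; i < v_size; ++i) {
    result[i] = scale * static_cast<float>(vector[i]);
  }
}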
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h
index 7a5a8fc541..45c9f65b64 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h
@@ -105,6 +105,10 @@ bool IsZeroVector(const float* vector, int v_size) {
return NEON_OR_PORTABLE(IsZeroVector, vector, v_size);
}
+void VectorScalarMultiply(const int8_t* vector, int v_size, float scale,
+ float* result) {
+ NEON_OR_PORTABLE(VectorScalarMultiply, vector, v_size, scale, result);
+}
void ClipVector(const float* vector, int v_size, float abs_limit,
float* result) {
NEON_OR_PORTABLE(ClipVector, vector, v_size, abs_limit, result);
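NEON_OR_PORTABLE selects the NEON implementation when NEON is usable and the portable fallback otherwise; the real macro is defined elsewhere in these kernels. A purely hypothetical sketch of that dispatch pattern, for orientation only:

// Hypothetical illustration only; the actual NEON_OR_PORTABLE macro differs.
#ifdef USE_NEON
#define SKETCH_NEON_OR_PORTABLE(funcname, ...) Neon##funcname(__VA_ARGS__)
#else
#define SKETCH_NEON_OR_PORTABLE(funcname, ...) Portable##funcname(__VA_ARGS__)
#endif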
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
index c857fdf699..78567d52ea 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
@@ -42,10 +42,12 @@ namespace optimized_ops {
// Unoptimized reference ops:
using reference_ops::ArgMax;
using reference_ops::ArgMinMax;
+using reference_ops::BroadcastAdd4DSlow;
using reference_ops::BroadcastGreater;
using reference_ops::BroadcastGreaterEqual;
using reference_ops::BroadcastLess;
using reference_ops::BroadcastLessEqual;
+using reference_ops::BroadcastSub4DSlow;
using reference_ops::Concatenation;
using reference_ops::DepthConcatenation;
using reference_ops::Dequantize;
@@ -217,98 +219,6 @@ SaturatingRoundingMultiplyByPOTParam(
SaturatingRoundingMultiplyByPOTParam(a.raw(), exponent));
}
-// DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING ELEMENT-WISE
-// BROADCASTING.
-//
-// NdArrayDesc<N> describes the shape and memory layout of an N-dimensional
-// rectangular array of numbers.
-//
-// NdArrayDesc<N> is basically identical to Dims<N> defined in types.h.
-// However, as Dims<N> is to be deprecated, this class exists as an adaptor
-// to enable simple unoptimized implementations of element-wise broadcasting
-// operations.
-template <int N>
-struct NdArrayDesc {
- // The "extent" of each dimension. Indices along dimension d must be in the
- // half-open interval [0, extents[d]).
- int extents[N];
-
- // The number of *elements* (not bytes) between consecutive indices of each
- // dimension.
- int strides[N];
-};
-
-// DO NOT USE THIS FUNCTION FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING
-// ELEMENT-WISE BROADCASTING.
-//
-// Same as Offset(), except takes as NdArrayDesc<N> instead of Dims<N>.
-inline int SubscriptToIndex(const NdArrayDesc<4>& desc, int i0, int i1, int i2,
- int i3) {
- TFLITE_DCHECK(i0 >= 0 && i0 < desc.extents[0]);
- TFLITE_DCHECK(i1 >= 0 && i1 < desc.extents[1]);
- TFLITE_DCHECK(i2 >= 0 && i2 < desc.extents[2]);
- TFLITE_DCHECK(i3 >= 0 && i3 < desc.extents[3]);
- return i0 * desc.strides[0] + i1 * desc.strides[1] + i2 * desc.strides[2] +
- i3 * desc.strides[3];
-}
-
-// Given the dimensions of the operands for an element-wise binary broadcast,
-// adjusts them so that they can be directly iterated over with simple loops.
-// Returns the adjusted dims as instances of NdArrayDesc in 'desc0_out' and
-// 'desc1_out'. 'desc0_out' and 'desc1_out' cannot be nullptr.
-//
-// This function assumes that the two input shapes are compatible up to
-// broadcasting and the shorter one has already been prepended with 1s to be the
-// same length. E.g., if shape0 is (1, 16, 16, 64) and shape1 is (1, 64),
-// shape1 must already have been prepended to be (1, 1, 1, 64). Recall that
-// Dims<N> refer to shapes in reverse order. In this case, input0_dims will be
-// (64, 16, 16, 1) and input1_dims will be (64, 1, 1, 1).
-//
-// When two shapes are compatible up to broadcasting, for each dimension d,
-// the input extents are either equal, or one of them is 1.
-//
-// This function performs the following for each dimension d:
-// - If the extents are equal, then do nothing since the loop that walks over
-// both of the input arrays is correct.
-// - Otherwise, one (and only one) of the extents must be 1. Say extent0 is 1
-// and extent1 is e1. Then set extent0 to e1 and stride0 *to 0*. This allows
-// array0 to be referenced *at any index* in dimension d and still access the
-// same slice.
-template <int N>
-inline void NdArrayDescsForElementwiseBroadcast(const Dims<N>& input0_dims,
- const Dims<N>& input1_dims,
- NdArrayDesc<N>* desc0_out,
- NdArrayDesc<N>* desc1_out) {
- TFLITE_DCHECK(desc0_out != nullptr);
- TFLITE_DCHECK(desc1_out != nullptr);
-
- // Copy dims to desc.
- for (int i = 0; i < N; ++i) {
- desc0_out->extents[i] = input0_dims.sizes[i];
- desc0_out->strides[i] = input0_dims.strides[i];
- desc1_out->extents[i] = input1_dims.sizes[i];
- desc1_out->strides[i] = input1_dims.strides[i];
- }
-
- // Walk over each dimension. If the extents are equal do nothing.
- // Otherwise, set the desc with extent 1 to have extent equal to the other and
- // stride 0.
- for (int i = 0; i < N; ++i) {
- const int extent0 = ArraySize(input0_dims, i);
- const int extent1 = ArraySize(input1_dims, i);
- if (extent0 != extent1) {
- if (extent0 == 1) {
- desc0_out->strides[i] = 0;
- desc0_out->extents[i] = extent1;
- } else {
- TFLITE_DCHECK_EQ(extent1, 1);
- desc1_out->strides[i] = 0;
- desc1_out->extents[i] = extent0;
- }
- }
- }
-}
-
inline bool AreSameDims(const Dims<4>& dims1, const Dims<4>& dims2) {
for (int i = 0; i < 4; i++) {
if (dims1.sizes[i] != dims2.sizes[i]) {
@@ -2478,20 +2388,17 @@ inline void L2Normalization(const uint8* input_data,
}
}
-inline void Add(const float* input1_data, const Dims<4>& input1_dims,
- const float* input2_data, const Dims<4>& input2_dims,
- float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims) {
+inline void Add(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const float* input1_data,
+ const RuntimeShape& input2_shape, const float* input2_data,
+ const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("Add");
- TFLITE_DCHECK(IsPackedWithoutStrides(input1_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(input2_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
int i = 0;
- const int size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
+ const int size = MatchingFlatSize(input1_shape, input2_shape, output_shape);
#ifdef USE_NEON
- const auto activation_min = vdupq_n_f32(output_activation_min);
- const auto activation_max = vdupq_n_f32(output_activation_max);
+ const auto activation_min = vdupq_n_f32(params.float_activation_min);
+ const auto activation_max = vdupq_n_f32(params.float_activation_max);
for (; i <= size - 16; i += 16) {
auto a10 = vld1q_f32(input1_data + i);
auto a11 = vld1q_f32(input1_data + i + 4);
@@ -2530,29 +2437,26 @@ inline void Add(const float* input1_data, const Dims<4>& input1_dims,
for (; i < size; i++) {
auto x = input1_data[i] + input2_data[i];
- output_data[i] = ActivationFunctionWithMinMax(x, output_activation_min,
- output_activation_max);
+ output_data[i] = ActivationFunctionWithMinMax(
+ x, params.float_activation_min, params.float_activation_max);
}
}
// Element-wise add that can often be used for inner loop of broadcast add as
// well as the non-broadcast add.
-inline void AddElementwise(int size, int left_shift, const uint8* input1_data,
- int32 input1_offset, int32 input1_multiplier,
- int input1_shift, const uint8* input2_data,
- int32 input2_offset, int32 input2_multiplier,
- int input2_shift, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_data) {
+inline void AddElementwise(int size, const ArithmeticParams& params,
+ const uint8* input1_data, const uint8* input2_data,
+ uint8* output_data) {
int i = 0;
- TFLITE_DCHECK_GT(input1_offset, -256);
- TFLITE_DCHECK_GT(input2_offset, -256);
- TFLITE_DCHECK_LT(input1_offset, 256);
- TFLITE_DCHECK_LT(input2_offset, 256);
+ TFLITE_DCHECK_GT(params.input1_offset, -256);
+ TFLITE_DCHECK_GT(params.input2_offset, -256);
+ TFLITE_DCHECK_LT(params.input1_offset, 256);
+ TFLITE_DCHECK_LT(params.input2_offset, 256);
#ifdef USE_NEON
- const auto output_activation_min_vector = vdup_n_u8(output_activation_min);
- const auto output_activation_max_vector = vdup_n_u8(output_activation_max);
+ const auto output_activation_min_vector =
+ vdup_n_u8(params.quantized_activation_min);
+ const auto output_activation_max_vector =
+ vdup_n_u8(params.quantized_activation_max);
for (; i <= size - 8; i += 8) {
const auto input1_val_original = vld1_u8(input1_data + i);
const auto input2_val_original = vld1_u8(input2_data + i);
@@ -2561,9 +2465,9 @@ inline void AddElementwise(int size, int left_shift, const uint8* input1_data,
const auto input2_val_s16 =
vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
const auto input1_val =
- vaddq_s16(input1_val_s16, vdupq_n_s16(input1_offset));
+ vaddq_s16(input1_val_s16, vdupq_n_s16(params.input1_offset));
const auto input2_val =
- vaddq_s16(input2_val_s16, vdupq_n_s16(input2_offset));
+ vaddq_s16(input2_val_s16, vdupq_n_s16(params.input2_offset));
const auto input1_val_high = vget_high_s16(input1_val);
const auto input1_val_low = vget_low_s16(input1_val);
const auto input2_val_high = vget_high_s16(input2_val);
@@ -2572,32 +2476,32 @@ inline void AddElementwise(int size, int left_shift, const uint8* input1_data,
auto x12 = vmovl_s16(input1_val_high);
auto x21 = vmovl_s16(input2_val_low);
auto x22 = vmovl_s16(input2_val_high);
- const auto left_shift_dup = vdupq_n_s32(left_shift);
+ const auto left_shift_dup = vdupq_n_s32(params.left_shift);
x11 = vshlq_s32(x11, left_shift_dup);
x12 = vshlq_s32(x12, left_shift_dup);
x21 = vshlq_s32(x21, left_shift_dup);
x22 = vshlq_s32(x22, left_shift_dup);
- x11 = vqrdmulhq_n_s32(x11, input1_multiplier);
- x12 = vqrdmulhq_n_s32(x12, input1_multiplier);
- x21 = vqrdmulhq_n_s32(x21, input2_multiplier);
- x22 = vqrdmulhq_n_s32(x22, input2_multiplier);
- const auto input1_shift_dup = vdupq_n_s32(-input1_shift);
- const auto input2_shift_dup = vdupq_n_s32(-input2_shift);
+ x11 = vqrdmulhq_n_s32(x11, params.input1_multiplier);
+ x12 = vqrdmulhq_n_s32(x12, params.input1_multiplier);
+ x21 = vqrdmulhq_n_s32(x21, params.input2_multiplier);
+ x22 = vqrdmulhq_n_s32(x22, params.input2_multiplier);
+ const auto input1_shift_dup = vdupq_n_s32(params.input1_shift);
+ const auto input2_shift_dup = vdupq_n_s32(params.input2_shift);
x11 = vshlq_s32(x11, input1_shift_dup);
x12 = vshlq_s32(x12, input1_shift_dup);
x21 = vshlq_s32(x21, input2_shift_dup);
x22 = vshlq_s32(x22, input2_shift_dup);
auto s1 = vaddq_s32(x11, x21);
auto s2 = vaddq_s32(x12, x22);
- s1 = vqrdmulhq_n_s32(s1, output_multiplier);
- s2 = vqrdmulhq_n_s32(s2, output_multiplier);
+ s1 = vqrdmulhq_n_s32(s1, params.output_multiplier);
+ s2 = vqrdmulhq_n_s32(s2, params.output_multiplier);
using gemmlowp::RoundingDivideByPOT;
- s1 = RoundingDivideByPOT(s1, output_shift);
- s2 = RoundingDivideByPOT(s2, output_shift);
+ s1 = RoundingDivideByPOT(s1, -params.output_shift);
+ s2 = RoundingDivideByPOT(s2, -params.output_shift);
const auto s1_narrowed = vmovn_s32(s1);
const auto s2_narrowed = vmovn_s32(s2);
const auto s = vaddq_s16(vcombine_s16(s1_narrowed, s2_narrowed),
- vdupq_n_s16(output_offset));
+ vdupq_n_s16(params.output_offset));
const auto clamped =
vmax_u8(output_activation_min_vector,
vmin_u8(output_activation_max_vector, vqmovun_s16(s)));
@@ -2606,101 +2510,74 @@ inline void AddElementwise(int size, int left_shift, const uint8* input1_data,
#endif // NEON
for (; i < size; ++i) {
- const int32 input1_val = input1_offset + input1_data[i];
- const int32 input2_val = input2_offset + input2_data[i];
- const int32 shifted_input1_val = input1_val * (1 << left_shift);
- const int32 shifted_input2_val = input2_val * (1 << left_shift);
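+ // Requantization pipeline: add the zero-point offsets, rescale both inputs
+ // into a common higher-precision domain (left_shift plus the per-input
+ // multiplier/shift), add them, rescale to the output domain, add the output
+ // offset, and clamp to the activation range.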
+ const int32 input1_val = params.input1_offset + input1_data[i];
+ const int32 input2_val = params.input2_offset + input2_data[i];
+ const int32 shifted_input1_val = input1_val * (1 << params.left_shift);
+ const int32 shifted_input2_val = input2_val * (1 << params.left_shift);
const int32 scaled_input1_val =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
- shifted_input1_val, input1_multiplier,
- kReverseShift * input1_shift);
+ shifted_input1_val, params.input1_multiplier, params.input1_shift);
const int32 scaled_input2_val =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
- shifted_input2_val, input2_multiplier,
- kReverseShift * input2_shift);
+ shifted_input2_val, params.input2_multiplier, params.input2_shift);
const int32 raw_sum = scaled_input1_val + scaled_input2_val;
const int32 raw_output =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
- raw_sum, output_multiplier, kReverseShift * output_shift) +
- output_offset;
- const int32 clamped_output = std::min(
- output_activation_max, std::max(output_activation_min, raw_output));
+ raw_sum, params.output_multiplier, params.output_shift) +
+ params.output_offset;
+ const int32 clamped_output =
+ std::min(params.quantized_activation_max,
+ std::max(params.quantized_activation_min, raw_output));
output_data[i] = static_cast<uint8>(clamped_output);
}
}
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void Add(const float* input1_data, const Dims<4>& input1_dims,
- const float* input2_data, const Dims<4>& input2_dims,
- float* output_data, const Dims<4>& output_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
-
- Add(input1_data, input1_dims, input2_data, input2_dims, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-template <FusedActivationFunctionType Ac>
-inline void Add(int left_shift, const uint8* input1_data,
- const Dims<4>& input1_dims, int32 input1_offset,
- int32 input1_multiplier, int input1_shift,
- const uint8* input2_data, const Dims<4>& input2_dims,
- int32 input2_offset, int32 input2_multiplier, int input2_shift,
- int32 output_offset, int32 output_multiplier, int output_shift,
- int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, 0);
- TFLITE_DCHECK_EQ(output_activation_max, 255);
- }
+inline void Add(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const uint8* input1_data,
+ const RuntimeShape& input2_shape, const uint8* input2_data,
+ const RuntimeShape& output_shape, uint8* output_data) {
+ TFLITE_DCHECK_LE(params.quantized_activation_min,
+ params.quantized_activation_max);
gemmlowp::ScopedProfilingLabel label("Add/8bit");
- const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
- TFLITE_DCHECK(IsPackedWithoutStrides(input1_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(input2_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
-
- TFLITE_DCHECK_GT(input1_offset, -256);
- TFLITE_DCHECK_GT(input2_offset, -256);
- TFLITE_DCHECK_LT(input1_offset, 256);
- TFLITE_DCHECK_LT(input2_offset, 256);
- AddElementwise(flat_size, left_shift, input1_data, input1_offset,
- input1_multiplier, input1_shift, input2_data, input2_offset,
- input2_multiplier, input2_shift, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_data);
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
+
+ TFLITE_DCHECK_GT(params.input1_offset, -256);
+ TFLITE_DCHECK_GT(params.input2_offset, -256);
+ TFLITE_DCHECK_LT(params.input1_offset, 256);
+ TFLITE_DCHECK_LT(params.input2_offset, 256);
+ AddElementwise(flat_size, params, input1_data, input2_data, output_data);
}
-inline void Add(const int16* input1_data, const Dims<4>& input1_dims,
- int input1_shift, const int16* input2_data,
- const Dims<4>& input2_dims, int input2_shift,
- int16 output_activation_min, int16 output_activation_max,
- int16* output_data, const Dims<4>& output_dims) {
+inline void Add(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const int16* input1_data,
+ const RuntimeShape& input2_shape, const int16* input2_data,
+ const RuntimeShape& output_shape, int16* output_data) {
gemmlowp::ScopedProfilingLabel label("Add/Int16");
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ TFLITE_DCHECK_LE(params.quantized_activation_min,
+ params.quantized_activation_max);
- const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims);
+ const int input1_shift = params.input1_shift;
+ const int flat_size =
+ MatchingFlatSize(output_shape, input1_shape, input2_shape);
+ const int16 output_activation_min = params.quantized_activation_min;
+ const int16 output_activation_max = params.quantized_activation_max;
- TFLITE_DCHECK(input1_shift == 0 || input2_shift == 0);
- TFLITE_DCHECK_GE(input1_shift, 0);
- TFLITE_DCHECK_GE(input2_shift, 0);
+ TFLITE_DCHECK(input1_shift == 0 || params.input2_shift == 0);
+ TFLITE_DCHECK_LE(input1_shift, 0);
+ TFLITE_DCHECK_LE(params.input2_shift, 0);
const int16* not_shift_input = input1_shift == 0 ? input1_data : input2_data;
const int16* shift_input = input1_shift == 0 ? input2_data : input1_data;
- const int input_shift = input1_shift == 0 ? input2_shift : input1_shift;
+ const int input_right_shift =
+ input1_shift == 0 ? -params.input2_shift : -input1_shift;
for (int i = 0; i < flat_size; i++) {
// F0 uses 0 integer bits, range [-1, 1].
using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
F0 input_ready_scaled = F0::FromRaw(not_shift_input[i]);
- F0 scaled_input =
- F0::FromRaw(gemmlowp::RoundingDivideByPOT(shift_input[i], input_shift));
+ F0 scaled_input = F0::FromRaw(
+ gemmlowp::RoundingDivideByPOT(shift_input[i], input_right_shift));
F0 result = gemmlowp::SaturatingAdd(scaled_input, input_ready_scaled);
const int16 raw_output = result.raw();
const int16 clamped_output = std::min(
@@ -2709,195 +2586,59 @@ inline void Add(const int16* input1_data, const Dims<4>& input1_dims,
}
}
-inline void Add(const int32* input1_data, const Dims<4>& input1_dims,
- const int32* input2_data, const Dims<4>& input2_dims,
- int32 output_activation_min, int32 output_activation_max,
- int32* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("Add/int32");
-
- const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
- for (int i = 0; i < flat_size; ++i) {
- output_data[i] = ActivationFunctionWithMinMax(
- input1_data[i] + input2_data[i], output_activation_min,
- output_activation_max);
- }
-}
-
-template <FusedActivationFunctionType Ac>
-inline void Add(const int16* input1_data, const Dims<4>& input1_dims,
- int input1_shift, const int16* input2_data,
- const Dims<4>& input2_dims, int input2_shift,
- int16 output_activation_min, int16 output_activation_max,
- int16* output_data, const Dims<4>& output_dims) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, -32768);
- TFLITE_DCHECK_EQ(output_activation_max, 32767);
- }
-
- Add(input1_data, input1_dims, input1_shift, input2_data, input2_dims,
- input2_shift, output_activation_min, output_activation_max, output_data,
- output_dims);
-}
-
-template <FusedActivationFunctionType Ac>
-void Add(const int32* input1_data, const Dims<4>& input1_dims,
- const int32* input2_data, const Dims<4>& input2_dims,
- int32* output_data, const Dims<4>& output_dims) {
+inline void Add(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const int32* input1_data,
+ const RuntimeShape& input2_shape, const int32* input2_data,
+ const RuntimeShape& output_shape, int32* output_data) {
gemmlowp::ScopedProfilingLabel label("Add/int32");
- TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
- auto input1_map = MapAsVector(input1_data, input1_dims);
- auto input2_map = MapAsVector(input2_data, input2_dims);
- auto output_map = MapAsVector(output_data, output_dims);
- if (AreSameDims(input1_dims, input2_dims)) {
+ auto input1_map = MapAsVector(input1_data, input1_shape);
+ auto input2_map = MapAsVector(input2_data, input2_shape);
+ auto output_map = MapAsVector(output_data, output_shape);
+ if (input1_shape == input2_shape) {
output_map.array() = input1_map.array() + input2_map.array();
- } else if (FlatSize(input2_dims) == 1) {
+ } else if (input2_shape.FlatSize() == 1) {
auto scalar = input2_data[0];
output_map.array() = input1_map.array() + scalar;
- } else if (FlatSize(input1_dims) == 1) {
+ } else if (input1_shape.FlatSize() == 1) {
auto scalar = input1_data[0];
output_map.array() = scalar + input2_map.array();
} else {
// Should not come here.
TFLITE_DCHECK(false);
}
+ output_map = output_map.cwiseMax(params.quantized_activation_min);
+ output_map = output_map.cwiseMin(params.quantized_activation_max);
}
-// TODO(jiawen): We can implement BroadcastAdd on buffers of arbitrary
-// dimensionality if the runtime code does a single loop over one dimension
-// that handles broadcasting as the base case. The code generator would then
-// generate max(D1, D2) nested for loops.
-// TODO(benoitjacob): BroadcastAdd is intentionally duplicated from
-// reference_ops.h. Once an optimized version is implemented and NdArrayDesc<T>
-// is no longer referenced in this file, move NdArrayDesc<T> from types.h to
-// reference_ops.h.
-template <typename T>
-void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T output_activation_min, T output_activation_max,
- T* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("BroadcastAdd");
-
- NdArrayDesc<4> desc1;
- NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
-
- // In Tensorflow, the dimensions are canonically named (batch_number, row,
- // col, channel), with extents (batches, height, width, depth), with the
- // trailing dimension changing most rapidly (channels has the smallest stride,
- // typically 1 element).
- //
- // In generated C code, we store arrays with the dimensions reversed. The
- // first dimension has smallest stride.
- //
- // We name our variables by their Tensorflow convention, but generate C code
- // nesting loops such that the innermost loop has the smallest stride for the
- // best cache behavior.
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- output_data[Offset(output_dims, c, x, y, b)] =
- ActivationFunctionWithMinMax(
- input1_data[SubscriptToIndex(desc1, c, x, y, b)] +
- input2_data[SubscriptToIndex(desc2, c, x, y, b)],
- output_activation_min, output_activation_max);
- }
- }
- }
- }
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac, typename T>
-void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T* output_data, const Dims<4>& output_dims) {
- T output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
-
- BroadcastAdd(input1_data, input1_dims, input2_data, input2_dims,
- output_activation_min, output_activation_max, output_data,
- output_dims);
-}
-
-inline void BroadcastAdd(int left_shift, const uint8* input1_data,
- const Dims<4>& input1_dims, int32 input1_offset,
- int32 input1_multiplier, int input1_shift,
- const uint8* input2_data, const Dims<4>& input2_dims,
- int32 input2_offset, int32 input2_multiplier,
- int input2_shift, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("BroadcastAddGeneric/8bit");
-
- NdArrayDesc<4> desc1;
- NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
-
- // In Tensorflow, the dimensions are canonically named (batch_number, row,
- // col, channel), with extents (batches, height, width, depth), with the
- // trailing dimension changing most rapidly (channels has the smallest stride,
- // typically 1 element).
- //
- // In generated C code, we store arrays with the dimensions reversed. The
- // first dimension has smallest stride.
- //
- // We name our variables by their Tensorflow convention, but generate C code
- // nesting loops such that the innermost loop has the smallest stride for the
- // best cache behavior.
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- const int32 input1_val =
- input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)];
- const int32 input2_val =
- input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)];
- const int32 shifted_input1_val = input1_val * (1 << left_shift);
- const int32 shifted_input2_val = input2_val * (1 << left_shift);
- const int32 scaled_input1_val =
- MultiplyByQuantizedMultiplierSmallerThanOneExp(
- shifted_input1_val, input1_multiplier,
- kReverseShift * input1_shift);
- const int32 scaled_input2_val =
- MultiplyByQuantizedMultiplierSmallerThanOneExp(
- shifted_input2_val, input2_multiplier,
- kReverseShift * input2_shift);
- const int32 raw_sum = scaled_input1_val + scaled_input2_val;
- const int32 raw_output =
- MultiplyByQuantizedMultiplierSmallerThanOneExp(
- raw_sum, output_multiplier, kReverseShift * output_shift) +
- output_offset;
- const int32 clamped_output =
- std::min(output_activation_max,
- std::max(output_activation_min, raw_output));
- output_data[Offset(output_dims, c, x, y, b)] =
- static_cast<uint8>(clamped_output);
- }
- }
- }
- }
-}
-
-inline void BroadcastAddFivefold(
- int y0, int y1, int y2, int y3, int y4, int left_shift,
- const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset,
- int32 input1_multiplier, int input1_shift, const uint8* input2_data,
- const Dims<4>& input2_dims, int32 input2_offset, int32 input2_multiplier,
- int input2_shift, int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims) {
+inline void BroadcastAddFivefold(const ArithmeticParams& unswitched_params,
+ const RuntimeShape& unswitched_input1_shape,
+ const uint8* unswitched_input1_data,
+ const RuntimeShape& unswitched_input2_shape,
+ const uint8* unswitched_input2_data,
+ const RuntimeShape& output_shape,
+ uint8* output_data) {
gemmlowp::ScopedProfilingLabel label("BroadcastAddFivefold/8bit");
+ ArithmeticParams switched_params = unswitched_params;
+ switched_params.input1_offset = unswitched_params.input2_offset;
+ switched_params.input1_multiplier = unswitched_params.input2_multiplier;
+ switched_params.input1_shift = unswitched_params.input2_shift;
+ switched_params.input2_offset = unswitched_params.input1_offset;
+ switched_params.input2_multiplier = unswitched_params.input1_multiplier;
+ switched_params.input2_shift = unswitched_params.input1_shift;
+
+ const bool use_unswitched =
+ unswitched_params.broadcast_category ==
+ tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;
+
+ const ArithmeticParams& params =
+ use_unswitched ? unswitched_params : switched_params;
+ const uint8* input1_data =
+ use_unswitched ? unswitched_input1_data : unswitched_input2_data;
+ const uint8* input2_data =
+ use_unswitched ? unswitched_input2_data : unswitched_input1_data;
+
// Fivefold nested loops. The second input resets its position for each
// iteration of the second loop. The first input resets its position at the
// beginning of the fourth loop. The innermost loop is an elementwise add of
@@ -2905,82 +2646,29 @@ inline void BroadcastAddFivefold(
uint8* output_data_ptr = output_data;
const uint8* input1_data_ptr = input1_data;
const uint8* input2_data_reset = input2_data;
- for (int i4 = 0; i4 < y4; ++i4) {
+ int y0 = params.broadcast_shape[0];
+ int y1 = params.broadcast_shape[1];
+ int y2 = params.broadcast_shape[2];
+ int y3 = params.broadcast_shape[3];
+ int y4 = params.broadcast_shape[4];
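+ // broadcast_shape[0..3] (y0..y3) give the four outer loop trip counts;
+ // broadcast_shape[4] (y4) is the length of the contiguous innermost run
+ // handled by a single AddElementwise call.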
+ for (int i0 = 0; i0 < y0; ++i0) {
const uint8* input2_data_ptr;
- for (int i3 = 0; i3 < y3; ++i3) {
+ for (int i1 = 0; i1 < y1; ++i1) {
input2_data_ptr = input2_data_reset;
for (int i2 = 0; i2 < y2; ++i2) {
- for (int i1 = 0; i1 < y1; ++i1) {
- AddElementwise(
- y0, left_shift, input1_data_ptr, input1_offset, input1_multiplier,
- input1_shift, input2_data_ptr, input2_offset, input2_multiplier,
- input2_shift, output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max, output_data_ptr);
- input2_data_ptr += y0;
- output_data_ptr += y0;
+ for (int i3 = 0; i3 < y3; ++i3) {
+ AddElementwise(y4, params, input1_data_ptr, input2_data_ptr,
+ output_data_ptr);
+ input2_data_ptr += y4;
+ output_data_ptr += y4;
}
- input1_data_ptr += y0;
+ input1_data_ptr += y4;
}
}
input2_data_reset = input2_data_ptr;
}
}
-template <FusedActivationFunctionType Ac>
-inline void BroadcastAdd(int left_shift, const uint8* input1_data,
- const Dims<4>& input1_dims, int32 input1_offset,
- int32 input1_multiplier, int input1_shift,
- const uint8* input2_data, const Dims<4>& input2_dims,
- int32 input2_offset, int32 input2_multiplier,
- int input2_shift, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, 0);
- TFLITE_DCHECK_EQ(output_activation_max, 255);
- }
- BroadcastAdd(left_shift, input1_data, input1_dims, input1_offset,
- input1_multiplier, input1_shift, input2_data, input2_dims,
- input2_offset, input2_multiplier, input2_shift, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-template <FusedActivationFunctionType Ac>
-inline void BroadcastAddFivefold(
- int y0, int y1, int y2, int y3, int y4, int left_shift,
- const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset,
- int32 input1_multiplier, int input1_shift, const uint8* input2_data,
- const Dims<4>& input2_dims, int32 input2_offset, int32 input2_multiplier,
- int input2_shift, int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, 0);
- TFLITE_DCHECK_EQ(output_activation_max, 255);
- }
- BroadcastAddFivefold(y0, y1, y2, y3, y4, left_shift, input1_data, input1_dims,
- input1_offset, input1_multiplier, input1_shift,
- input2_data, input2_dims, input2_offset,
- input2_multiplier, input2_shift, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
inline void Mul(const float* input1_data, const Dims<4>& input1_dims,
const float* input2_data, const Dims<4>& input2_dims,
float output_activation_min, float output_activation_max,
@@ -3305,122 +2993,78 @@ void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
}
// TODO(aselle): This is not actually optimized yet.
-inline void Sub(const float* input1_data, const Dims<4>& input1_dims,
- const float* input2_data, const Dims<4>& input2_dims,
- float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("Sub");
- const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
+inline void SubNonBroadcast(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape,
+ const float* input1_data,
+ const RuntimeShape& input2_shape,
+ const float* input2_data,
+ const RuntimeShape& output_shape,
+ float* output_data) {
+ gemmlowp::ScopedProfilingLabel label("SubNonBroadcast");
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
output_data[i] = ActivationFunctionWithMinMax(
- input1_data[i] - input2_data[i], output_activation_min,
- output_activation_max);
+ input1_data[i] - input2_data[i], params.float_activation_min,
+ params.float_activation_max);
}
}
-// TODO(jiawen): We can implement BroadcastSub on buffers of arbitrary
-// dimensionality if the runtime code does a single loop over one dimension
-// that handles broadcasting as the base case. The code generator would then
-// generate max(D1, D2) nested for loops.
-// TODO(benoitjacob): BroadcastSub is intentionally duplicated from
-// reference_ops.h. Once an optimized version is implemented and NdArrayDesc<T>
-// is no longer referenced in this file, move NdArrayDesc<T> from types.h to
-// reference_ops.h.
-template <typename T>
-void BroadcastSub(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T output_activation_min, T output_activation_max,
- T* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("BroadcastSub");
-
- NdArrayDesc<4> desc1;
- NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
-
- // In Tensorflow, the dimensions are canonically named (batch_number, row,
- // col, channel), with extents (batches, height, width, depth), with the
- // trailing dimension changing most rapidly (channels has the smallest stride,
- // typically 1 element).
- //
- // In generated C code, we store arrays with the dimensions reversed. The
- // first dimension has smallest stride.
- //
- // We name our variables by their Tensorflow convention, but generate C code
- // nesting loops such that the innermost loop has the smallest stride for the
- // best cache behavior.
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- output_data[Offset(output_dims, c, x, y, b)] =
- ActivationFunctionWithMinMax(
- input1_data[SubscriptToIndex(desc1, c, x, y, b)] -
- input2_data[SubscriptToIndex(desc2, c, x, y, b)],
- output_activation_min, output_activation_max);
- }
- }
- }
+inline void SubWithActivation(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape,
+ const int32* input1_data,
+ const RuntimeShape& input2_shape,
+ const int32* input2_data,
+ const RuntimeShape& output_shape,
+ int32* output_data) {
+ gemmlowp::ScopedProfilingLabel label("SubWithActivation/int32");
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
+ for (int i = 0; i < flat_size; ++i) {
+ output_data[i] = ActivationFunctionWithMinMax(
+ input1_data[i] - input2_data[i], params.quantized_activation_min,
+ params.quantized_activation_max);
}
}
-inline void BroadcastSub(int left_shift, const uint8* input1_data,
- const Dims<4>& input1_dims, int32 input1_offset,
- int32 input1_multiplier, int input1_shift,
- const uint8* input2_data, const Dims<4>& input2_dims,
- int32 input2_offset, int32 input2_multiplier,
- int input2_shift, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("BroadcastSub/8bit");
+inline void SubWithActivation(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape,
+ const float* input1_data,
+ const RuntimeShape& input2_shape,
+ const float* input2_data,
+ const RuntimeShape& output_shape,
+ float* output_data) {
+ gemmlowp::ScopedProfilingLabel label("SubWithActivation/float");
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
+ for (int i = 0; i < flat_size; ++i) {
+ output_data[i] = ActivationFunctionWithMinMax(
+ input1_data[i] - input2_data[i], params.float_activation_min,
+ params.float_activation_max);
+ }
+}
- NdArrayDesc<4> desc1;
- NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+template <typename T>
+void Sub(const ArithmeticParams& params, const RuntimeShape& input1_shape,
+ const T* input1_data, const RuntimeShape& input2_shape,
+ const T* input2_data, const RuntimeShape& output_shape,
+ T* output_data) {
+ gemmlowp::ScopedProfilingLabel label("Sub");
- // In Tensorflow, the dimensions are canonically named (batch_number, row,
- // col, channel), with extents (batches, height, width, depth), with the
- // trailing dimension changing most rapidly (channels has the smallest stride,
- // typically 1 element).
- //
- // In generated C code, we store arrays with the dimensions reversed. The
- // first dimension has smallest stride.
- //
- // We name our variables by their Tensorflow convention, but generate C code
- // nesting loops such that the innermost loop has the smallest stride for the
- // best cache behavior.
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- const int32 input1_val =
- input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)];
- const int32 input2_val =
- input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)];
- const int32 shifted_input1_val = input1_val * (1 << left_shift);
- const int32 shifted_input2_val = input2_val * (1 << left_shift);
- const int32 scaled_input1_val =
- MultiplyByQuantizedMultiplierSmallerThanOneExp(
- shifted_input1_val, input1_multiplier,
- kReverseShift * input1_shift);
- const int32 scaled_input2_val =
- MultiplyByQuantizedMultiplierSmallerThanOneExp(
- shifted_input2_val, input2_multiplier,
- kReverseShift * input2_shift);
- const int32 raw_sub = scaled_input1_val - scaled_input2_val;
- const int32 raw_output =
- MultiplyByQuantizedMultiplierSmallerThanOneExp(
- raw_sub, output_multiplier, kReverseShift * output_shift) +
- output_offset;
- const int32 clamped_output =
- std::min(output_activation_max,
- std::max(output_activation_min, raw_output));
- output_data[Offset(output_dims, c, x, y, b)] =
- static_cast<uint8>(clamped_output);
- }
- }
- }
+ auto input1_map = MapAsVector(input1_data, input1_shape);
+ auto input2_map = MapAsVector(input2_data, input2_shape);
+ auto output_map = MapAsVector(output_data, output_shape);
+ if (input1_shape == input2_shape) {
+ output_map.array() = input1_map.array() - input2_map.array();
+ } else if (input1_shape.FlatSize() == 1) {
+ auto scalar = input1_data[0];
+ output_map.array() = scalar - input2_map.array();
+ } else if (input2_shape.FlatSize() == 1) {
+ auto scalar = input2_data[0];
+ output_map.array() = input1_map.array() - scalar;
+ } else {
+ BroadcastSub4DSlow(params, input1_shape, input1_data, input2_shape,
+ input2_data, output_shape, output_data);
}
}
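A minimal usage sketch of the new shape-based Sub() entry point, not part of the patch: the ArithmeticParams fields and the dispatch behaviour are taken from the code above, while SubExample, the shapes and the data are illustrative assumptions (the header also expects the rest of the kernels/internal include set to be available).

#include <cstdint>
#include <limits>

#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"

void SubExample() {
  tflite::ArithmeticParams op_params;
  // No fused activation: leave the float clamp effectively unbounded.
  op_params.float_activation_min = std::numeric_limits<float>::lowest();
  op_params.float_activation_max = std::numeric_limits<float>::max();

  const int32_t dims[4] = {1, 2, 2, 1};
  const tflite::RuntimeShape shape(4, dims);
  const float in1[] = {1.f, 2.f, 3.f, 4.f};
  const float in2[] = {0.5f, 0.5f, 0.5f, 0.5f};
  float out[4];
  // Equal shapes take the Eigen fast path above; mismatched non-scalar shapes
  // fall through to BroadcastSub4DSlow.
  tflite::optimized_ops::Sub(op_params, shape, in1, shape, in2, shape, out);
}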
@@ -5863,63 +5507,6 @@ inline void Slice(const T* input_data, const Dims<4>& input_dims,
}
template <typename T>
-void GenericBroadcastSub(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("GenericBroadcastSub");
-
- NdArrayDesc<4> desc1;
- NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
-
- // In Tensorflow, the dimensions are canonically named (batch_number, row,
- // col, channel), with extents (batches, height, width, depth), with the
- // trailing dimension changing most rapidly (channels has the smallest stride,
- // typically 1 element).
- //
- // In generated C code, we store arrays with the dimensions reversed. The
- // first dimension has smallest stride.
- //
- // We name our variables by their Tensorflow convention, but generate C code
- // nesting loops such that the innermost loop has the smallest stride for the
- // best cache behavior.
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- output_data[Offset(output_dims, c, x, y, b)] =
- input1_data[SubscriptToIndex(desc1, c, x, y, b)] -
- input2_data[SubscriptToIndex(desc2, c, x, y, b)];
- }
- }
- }
- }
-}
-
-template <typename T>
-void Sub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data,
- const Dims<4>& input2_dims, T* output_data,
- const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("Sub");
-
- auto input1_map = MapAsVector(input1_data, input1_dims);
- auto input2_map = MapAsVector(input2_data, input2_dims);
- auto output_map = MapAsVector(output_data, output_dims);
- if (AreSameDims(input1_dims, input2_dims)) {
- output_map.array() = input1_map.array() - input2_map.array();
- } else if (FlatSize(input1_dims) == 1) {
- auto scalar = input1_data[0];
- output_map.array() = scalar - input2_map.array();
- } else if (FlatSize(input2_dims) == 1) {
- auto scalar = input2_data[0];
- output_map.array() = input1_map.array() - scalar;
- } else {
- GenericBroadcastSub(input1_data, input1_dims, input2_data, input2_dims,
- output_data, output_dims);
- }
-}
-
-template <typename T>
void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims,
const T* input2_data, T* output_data,
const Dims<4>& output_dims) {
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h
index f14667090f..db7926df9a 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h
@@ -124,6 +124,12 @@ void PortableCopyVector(const float* vector, int v_size, float* result);
// Fill vector with 0.f.
void PortableZeroVector(float* vector, int v_size);
+// Multiply all elements of a vector by a scalar.
+void PortableVectorScalarMultiply(const int8_t* vector, int v_size, float scale,
+ float* result);
+void NeonVectorScalarMultiply(const int8_t* vector, int v_size, float scale,
+ float* result);
+
// Limit a float input f between +abs_limit and -abs_limit.
float PortableClip(float f, float abs_limit);
diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.h b/tensorflow/contrib/lite/kernels/internal/quantization_util.h
index 525857a2e6..9b3f1823dc 100644
--- a/tensorflow/contrib/lite/kernels/internal/quantization_util.h
+++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.h
@@ -28,8 +28,9 @@ namespace tflite {
// Given the min and max values of a float array, return
// reasonable quantization parameters to use for this array.
template <typename T>
-QuantizationParams ChooseQuantizationParams(double rmin, double rmax) {
- const T qmin = std::numeric_limits<T>::min();
+QuantizationParams ChooseQuantizationParams(double rmin, double rmax,
+ bool narrow_range) {
+ const T qmin = std::numeric_limits<T>::min() + (narrow_range ? 1 : 0);
const T qmax = std::numeric_limits<T>::max();
const double qmin_double = qmin;
const double qmax_double = qmax;
@@ -97,6 +98,11 @@ QuantizationParams ChooseQuantizationParams(double rmin, double rmax) {
return quantization_params;
}
+template <typename T>
+QuantizationParams ChooseQuantizationParams(double rmin, double rmax) {
+ return ChooseQuantizationParams<T>(rmin, rmax, false);
+}
+
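A hypothetical call pattern, not part of the patch, showing what the new narrow_range flag changes: for T = uint8 the quantized target range is [0, 255] by default and [1, 255] with narrow_range set (the convention TensorFlow's FakeQuant ops use), so for the same real range the resulting scale widens slightly, e.g. roughly 2/255 vs. 2/254 for [-1, 1]. NarrowRangeExample is an assumed wrapper name.

#include <cstdint>

#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"

void NarrowRangeExample() {
  // Default path, behaviour unchanged by this change: qmin = 0.
  const auto full = tflite::ChooseQuantizationParams<uint8_t>(-1.0, 1.0);
  // narrow_range = true: qmin = 1, qmax = 255.
  const auto narrow = tflite::ChooseQuantizationParams<uint8_t>(-1.0, 1.0, true);
  (void)full;
  (void)narrow;
}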
// Converts a floating-point number to an integer. For all inputs x where
// static_cast<IntOut>(x) is legal according to the C++ standard, the result
// is identical to that cast (i.e. the result is x with its fractional part
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h
index f715d34bc1..bcf5e4e4f6 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h
@@ -63,6 +63,240 @@ inline void Relu6(const float* input_data, const Dims<4>& input_dims,
DimsToShape(output_dims));
}
+template <FusedActivationFunctionType Ac>
+inline void Add(int left_shift, const uint8* input1_data,
+ const Dims<4>& input1_dims, int32 input1_offset,
+ int32 input1_multiplier, int input1_shift,
+ const uint8* input2_data, const Dims<4>& input2_dims,
+ int32 input2_offset, int32 input2_multiplier, int input2_shift,
+ int32 output_offset, int32 output_multiplier, int output_shift,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ constexpr int kReverseShift = -1;
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+
+ tflite::ArithmeticParams op_params;
+ op_params.left_shift = left_shift;
+ op_params.input1_offset = input1_offset;
+ op_params.input1_multiplier = input1_multiplier;
+ op_params.input1_shift = kReverseShift * input1_shift;
+ op_params.input2_offset = input2_offset;
+ op_params.input2_multiplier = input2_multiplier;
+ op_params.input2_shift = kReverseShift * input2_shift;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+ Add(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <FusedActivationFunctionType Ac>
+void Add(const int32* input1_data, const Dims<4>& input1_dims,
+ const int32* input2_data, const Dims<4>& input2_dims,
+ int32* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("Add/int32");
+ TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
+
+ tflite::ArithmeticParams op_params;
+ op_params.quantized_activation_min = std::numeric_limits<int32>::min();
+ op_params.quantized_activation_max = std::numeric_limits<int32>::max();
+ Add(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <FusedActivationFunctionType Ac>
+inline void BroadcastAdd(int left_shift, const uint8* input1_data,
+ const Dims<4>& input1_dims, int32 input1_offset,
+ int32 input1_multiplier, int input1_shift,
+ const uint8* input2_data, const Dims<4>& input2_dims,
+ int32 input2_offset, int32 input2_multiplier,
+ int input2_shift, int32 output_offset,
+ int32 output_multiplier, int output_shift,
+ int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ constexpr int kReverseShift = -1;
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+
+ tflite::ArithmeticParams op_params;
+ op_params.left_shift = left_shift;
+ op_params.input1_offset = input1_offset;
+ op_params.input1_multiplier = input1_multiplier;
+ op_params.input1_shift = kReverseShift * input1_shift;
+ op_params.input2_offset = input2_offset;
+ op_params.input2_multiplier = input2_multiplier;
+ op_params.input2_shift = kReverseShift * input2_shift;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+ BroadcastAdd4DSlow(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <FusedActivationFunctionType Ac>
+void Add(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+
+ tflite::ArithmeticParams op_params;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+ Add(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <typename T>
+void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T output_activation_min, T output_activation_max,
+ T* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+ BroadcastAdd4DSlow(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <FusedActivationFunctionType Ac>
+inline void BroadcastAddFivefold(
+ int y0, int y1, int y2, int y3, int y4, int left_shift,
+ const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset,
+ int32 input1_multiplier, int input1_shift, const uint8* input2_data,
+ const Dims<4>& input2_dims, int32 input2_offset, int32 input2_multiplier,
+ int input2_shift, int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ constexpr int kReverseShift = -1;
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ tflite::ArithmeticParams op_params;
+ op_params.broadcast_category =
+ tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;
+ op_params.left_shift = left_shift;
+ op_params.input1_offset = input1_offset;
+ op_params.input1_multiplier = input1_multiplier;
+ op_params.input1_shift = kReverseShift * input1_shift;
+ op_params.input2_offset = input2_offset;
+ op_params.input2_multiplier = input2_multiplier;
+ op_params.input2_shift = kReverseShift * input2_shift;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+ op_params.broadcast_shape[4] = y0;
+ op_params.broadcast_shape[3] = y1;
+ op_params.broadcast_shape[2] = y2;
+ op_params.broadcast_shape[1] = y3;
+ op_params.broadcast_shape[0] = y4;
+ BroadcastAddFivefold(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac, typename T>
+void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T* output_data, const Dims<4>& output_dims) {
+ T output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+
+ BroadcastAdd(input1_data, input1_dims, input2_data, input2_dims,
+ output_activation_min, output_activation_max, output_data,
+ output_dims);
+}
+
+template <FusedActivationFunctionType Ac>
+inline void Add(const int16* input1_data, const Dims<4>& input1_dims,
+ int input1_shift, const int16* input2_data,
+ const Dims<4>& input2_dims, int input2_shift,
+ int16 output_activation_min, int16 output_activation_max,
+ int16* output_data, const Dims<4>& output_dims) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, -32768);
+ TFLITE_DCHECK_EQ(output_activation_max, 32767);
+ }
+
+ tflite::ArithmeticParams op_params;
+ op_params.input1_shift = kReverseShift * input1_shift;
+ op_params.input2_shift = kReverseShift * input2_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+ Add(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+inline void Sub(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(FusedActivationFunctionType::kNone,
+ &output_activation_min, &output_activation_max);
+ tflite::ArithmeticParams op_params;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+ Sub(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <typename T>
+void Sub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data,
+ const Dims<4>& input2_dims, T* output_data,
+ const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ op_params.quantized_activation_min = std::numeric_limits<T>::min();
+ op_params.quantized_activation_max = std::numeric_limits<T>::max();
+ Sub(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
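A hypothetical before/after for a caller of the legacy float Sub() above, to make the shim's intent explicit; the variable names are illustrative only.

//   // Old-style call, unchanged for existing code:
//   Sub(in1_data, in1_dims, in2_data, in2_dims, out_data, out_dims);
//
//   // ...which now forwards to the shape-based kernel, i.e. it is equivalent
//   // to building an ArithmeticParams with the unbounded (kNone) float range
//   // and calling:
//   Sub(op_params, DimsToShape(in1_dims), in1_data,
//       DimsToShape(in2_dims), in2_data, DimsToShape(out_dims), out_data);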
inline void AveragePool(const float* input_data, const Dims<4>& input_dims,
int stride_width, int stride_height, int pad_width,
int pad_height, int kwidth, int kheight,
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
index ccf112c990..7ead449ca8 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
@@ -195,6 +195,13 @@ void PortableZeroVector(float* vector, int v_size) {
memset(vector, 0, v_size * sizeof(float));
}
+void PortableVectorScalarMultiply(const int8_t* vector, const int v_size,
+ const float scale, float* result) {
+ for (int v = 0; v < v_size; ++v) {
+ *result++ = scale * *vector++;
+ }
+}
+
void PortableClipVector(const float* vector, int v_size, float abs_limit,
float* result) {
for (int v = 0; v < v_size; v++) {
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h
index d2e1fecd25..d3a4fa8507 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h
@@ -96,6 +96,10 @@ void PortableSub1Vector(const float* vector, int v_size, float* result);
// Fill vector with 0.f.
void PortableZeroVector(float* vector, int v_size);
+// Multiply all elements of a vector by a scalar.
+void PortableVectorScalarMultiply(const int8_t* vector, int v_size, float scale,
+ float* result);
+
// Clip elements of a vector using a abs_limit value.
void PortableClipVector(const float* vector, int v_size, float abs_limit,
float* result);
@@ -199,6 +203,12 @@ void ZeroVector(float* vector, int v_size) {
PortableZeroVector(vector, v_size);
}
+// Multiply all elements of a vector by a scalar.
+void VectorScalarMultiply(const int8_t* vector, int v_size, float scale,
+ float* result) {
+ PortableVectorScalarMultiply(vector, v_size, scale, result);
+}
+
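A minimal usage sketch for the new scalar-multiply hook, not part of the patch: it assumes the aggregate tensor_utils.h header and the tflite::tensor_utils namespace used for the other wrappers in this file; the data and the 0.5f scale are illustrative, e.g. dequantizing an int8 row with a per-tensor scale.

#include <cstdint>

#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"

void VectorScalarMultiplyExample() {
  const int8_t quantized[4] = {-128, -1, 1, 127};
  float dequantized[4];
  // On the portable path this dispatches to PortableVectorScalarMultiply.
  tflite::tensor_utils::VectorScalarMultiply(quantized, 4, 0.5f, dequantized);
  // dequantized is now {-64.f, -0.5f, 0.5f, 63.5f}.
}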
void ClipVector(const float* vector, int v_size, float abs_limit,
float* result) {
PortableClipVector(vector, v_size, abs_limit, result);
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index 2d40f1769b..31a54c2b62 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -158,98 +158,6 @@ SaturatingRoundingMultiplyByPOTParam(
SaturatingRoundingMultiplyByPOTParam(a.raw(), exponent));
}
-// DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING ELEMENT-WISE
-// BROADCASTING.
-//
-// NdArrayDesc<N> describes the shape and memory layout of an N-dimensional
-// rectangular array of numbers.
-//
-// NdArrayDesc<N> is basically identical to Dims<N> defined in types.h.
-// However, as Dims<N> is to be deprecated, this class exists as an adaptor
-// to enable simple unoptimized implementations of element-wise broadcasting
-// operations.
-template <int N>
-struct NdArrayDesc {
- // The "extent" of each dimension. Indices along dimension d must be in the
- // half-open interval [0, extents[d]).
- int extents[N];
-
- // The number of *elements* (not bytes) between consecutive indices of each
- // dimension.
- int strides[N];
-};
-
-// DO NOT USE THIS FUNCTION FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING
-// ELEMENT-WISE BROADCASTING.
-//
-// Same as Offset(), except takes as NdArrayDesc<N> instead of Dims<N>.
-inline int SubscriptToIndex(const NdArrayDesc<4>& desc, int i0, int i1, int i2,
- int i3) {
- TFLITE_DCHECK(i0 >= 0 && i0 < desc.extents[0]);
- TFLITE_DCHECK(i1 >= 0 && i1 < desc.extents[1]);
- TFLITE_DCHECK(i2 >= 0 && i2 < desc.extents[2]);
- TFLITE_DCHECK(i3 >= 0 && i3 < desc.extents[3]);
- return i0 * desc.strides[0] + i1 * desc.strides[1] + i2 * desc.strides[2] +
- i3 * desc.strides[3];
-}
-
-// Given the dimensions of the operands for an element-wise binary broadcast,
-// adjusts them so that they can be directly iterated over with simple loops.
-// Returns the adjusted dims as instances of NdArrayDesc in 'desc0_out' and
-// 'desc1_out'. 'desc0_out' and 'desc1_out' cannot be nullptr.
-//
-// This function assumes that the two input shapes are compatible up to
-// broadcasting and the shorter one has already been prepended with 1s to be the
-// same length. E.g., if shape0 is (1, 16, 16, 64) and shape1 is (1, 64),
-// shape1 must already have been prepended to be (1, 1, 1, 64). Recall that
-// Dims<N> refer to shapes in reverse order. In this case, input0_dims will be
-// (64, 16, 16, 1) and input1_dims will be (64, 1, 1, 1).
-//
-// When two shapes are compatible up to broadcasting, for each dimension d,
-// the input extents are either equal, or one of them is 1.
-//
-// This function performs the following for each dimension d:
-// - If the extents are equal, then do nothing since the loop that walks over
-// both of the input arrays is correct.
-// - Otherwise, one (and only one) of the extents must be 1. Say extent0 is 1
-// and extent1 is e1. Then set extent0 to e1 and stride0 *to 0*. This allows
-// array0 to be referenced *at any index* in dimension d and still access the
-// same slice.
-template <int N>
-inline void NdArrayDescsForElementwiseBroadcast(const Dims<N>& input0_dims,
- const Dims<N>& input1_dims,
- NdArrayDesc<N>* desc0_out,
- NdArrayDesc<N>* desc1_out) {
- TFLITE_DCHECK(desc0_out != nullptr);
- TFLITE_DCHECK(desc1_out != nullptr);
-
- // Copy dims to desc.
- for (int i = 0; i < N; ++i) {
- desc0_out->extents[i] = input0_dims.sizes[i];
- desc0_out->strides[i] = input0_dims.strides[i];
- desc1_out->extents[i] = input1_dims.sizes[i];
- desc1_out->strides[i] = input1_dims.strides[i];
- }
-
- // Walk over each dimension. If the extents are equal do nothing.
- // Otherwise, set the desc with extent 1 to have extent equal to the other and
- // stride 0.
- for (int i = 0; i < N; ++i) {
- const int extent0 = ArraySize(input0_dims, i);
- const int extent1 = ArraySize(input1_dims, i);
- if (extent0 != extent1) {
- if (extent0 == 1) {
- desc0_out->strides[i] = 0;
- desc0_out->extents[i] = extent1;
- } else {
- TFLITE_DCHECK_EQ(extent1, 1);
- desc1_out->strides[i] = 0;
- desc1_out->extents[i] = extent0;
- }
- }
- }
-}
-
inline void Conv(const float* input_data, const Dims<4>& input_dims,
const float* filter_data, const Dims<4>& filter_dims,
const float* bias_data, const Dims<4>& bias_dims,
@@ -1065,114 +973,108 @@ inline void L2Normalization(const uint8* input_data,
}
template <typename T>
-inline void Add(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T output_activation_min, T output_activation_max,
- T* output_data, const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
+inline void Add(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const T* input1_data,
+ const RuntimeShape& input2_shape, const T* input2_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
output_data[i] = ActivationFunctionWithMinMax(
- input1_data[i] + input2_data[i], output_activation_min,
- output_activation_max);
+ input1_data[i] + input2_data[i], params.quantized_activation_min,
+ params.quantized_activation_max);
}
}
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void Add(const float* input1_data, const Dims<4>& input1_dims,
- const float* input2_data, const Dims<4>& input2_dims,
- float* output_data, const Dims<4>& output_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
-
- Add(input1_data, input1_dims, input2_data, input2_dims, output_activation_min,
- output_activation_max, output_data, output_dims);
+inline void Add(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const float* input1_data,
+ const RuntimeShape& input2_shape, const float* input2_data,
+ const RuntimeShape& output_shape, float* output_data) {
+ const int size = MatchingFlatSize(input1_shape, input2_shape, output_shape);
+ for (int i = 0; i < size; i++) {
+ auto x = input1_data[i] + input2_data[i];
+ output_data[i] = ActivationFunctionWithMinMax(
+ x, params.float_activation_min, params.float_activation_max);
+ }
}
-template <FusedActivationFunctionType Ac>
-inline void Add(int left_shift, const uint8* input1_data,
- const Dims<4>& input1_dims, int32 input1_offset,
- int32 input1_multiplier, int input1_shift,
- const uint8* input2_data, const Dims<4>& input2_dims,
- int32 input2_offset, int32 input2_multiplier, int input2_shift,
- int32 output_offset, int32 output_multiplier, int output_shift,
- int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, 0);
- TFLITE_DCHECK_EQ(output_activation_max, 255);
- }
- const int batches =
- MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3);
- const int height =
- MatchingArraySize(input1_dims, 2, input2_dims, 2, output_dims, 2);
- const int width =
- MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1);
- const int depth =
- MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0);
- for (int b = 0; b < batches; ++b) {
- for (int y = 0; y < height; ++y) {
- for (int x = 0; x < width; ++x) {
- for (int c = 0; c < depth; ++c) {
- const int32 input1_val =
- input1_offset + input1_data[Offset(input1_dims, c, x, y, b)];
- const int32 input2_val =
- input2_offset + input2_data[Offset(input2_dims, c, x, y, b)];
- const int32 shifted_input1_val = input1_val * (1 << left_shift);
- const int32 shifted_input2_val = input2_val * (1 << left_shift);
- const int32 scaled_input1_val =
- MultiplyByQuantizedMultiplierSmallerThanOneExp(
- shifted_input1_val, input1_multiplier,
- kReverseShift * input1_shift);
- const int32 scaled_input2_val =
- MultiplyByQuantizedMultiplierSmallerThanOneExp(
- shifted_input2_val, input2_multiplier,
- kReverseShift * input2_shift);
- const int32 raw_sum = scaled_input1_val + scaled_input2_val;
- const int32 raw_output =
- MultiplyByQuantizedMultiplierSmallerThanOneExp(
- raw_sum, output_multiplier, kReverseShift * output_shift) +
- output_offset;
- const int32 clamped_output =
- std::min(output_activation_max,
- std::max(output_activation_min, raw_output));
- output_data[Offset(output_dims, c, x, y, b)] =
- static_cast<uint8>(clamped_output);
- }
- }
- }
+// Element-wise add that can often be used for inner loop of broadcast add as
+// well as the non-broadcast add.
+inline void AddElementwise(int size, const ArithmeticParams& params,
+ const uint8* input1_data, const uint8* input2_data,
+ uint8* output_data) {
+ TFLITE_DCHECK_GT(params.input1_offset, -256);
+ TFLITE_DCHECK_GT(params.input2_offset, -256);
+ TFLITE_DCHECK_LT(params.input1_offset, 256);
+ TFLITE_DCHECK_LT(params.input2_offset, 256);
+
+ for (int i = 0; i < size; ++i) {
+ const int32 input1_val = params.input1_offset + input1_data[i];
+ const int32 input2_val = params.input2_offset + input2_data[i];
+ const int32 shifted_input1_val = input1_val * (1 << params.left_shift);
+ const int32 shifted_input2_val = input2_val * (1 << params.left_shift);
+ const int32 scaled_input1_val =
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ shifted_input1_val, params.input1_multiplier, params.input1_shift);
+ const int32 scaled_input2_val =
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ shifted_input2_val, params.input2_multiplier, params.input2_shift);
+ const int32 raw_sum = scaled_input1_val + scaled_input2_val;
+ const int32 raw_output =
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ raw_sum, params.output_multiplier, params.output_shift) +
+ params.output_offset;
+ const int32 clamped_output =
+ std::min(params.quantized_activation_max,
+ std::max(params.quantized_activation_min, raw_output));
+ output_data[i] = static_cast<uint8>(clamped_output);
}
}
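For readability, the fixed-point pipeline in AddElementwise can be summarised as the following per-element relation (an editor's paraphrase of the code above, with Q(v, m, s) standing for MultiplyByQuantizedMultiplierSmallerThanOneExp(v, m, s), roughly v * m * 2^s for an int32-encoded real multiplier m in [0.5, 1) and s <= 0; all other names are fields of params):

//   out[i] = clamp(Q(Q((in1[i] + input1_offset) << left_shift,
//                      input1_multiplier, input1_shift) +
//                    Q((in2[i] + input2_offset) << left_shift,
//                      input2_multiplier, input2_shift),
//                    output_multiplier, output_shift) + output_offset,
//                  quantized_activation_min, quantized_activation_max)
//
// i.e. both inputs are re-offset and rescaled onto a shared, left-shifted
// intermediate scale, added, then rescaled and re-offset to the output's
// quantization.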
-inline void Add(const int16* input1_data, const Dims<4>& input1_dims,
- int input1_shift, const int16* input2_data,
- const Dims<4>& input2_dims, int input2_shift,
- int16 output_activation_min, int16 output_activation_max,
- int16* output_data, const Dims<4>& output_dims) {
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+inline void Add(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const uint8* input1_data,
+ const RuntimeShape& input2_shape, const uint8* input2_data,
+ const RuntimeShape& output_shape, uint8* output_data) {
+ TFLITE_DCHECK_LE(params.quantized_activation_min,
+ params.quantized_activation_max);
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
- const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims);
+ TFLITE_DCHECK_GT(params.input1_offset, -256);
+ TFLITE_DCHECK_GT(params.input2_offset, -256);
+ TFLITE_DCHECK_LT(params.input1_offset, 256);
+ TFLITE_DCHECK_LT(params.input2_offset, 256);
+ AddElementwise(flat_size, params, input1_data, input2_data, output_data);
+}
+
+inline void Add(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const int16* input1_data,
+ const RuntimeShape& input2_shape, const int16* input2_data,
+ const RuntimeShape& output_shape, int16* output_data) {
+ TFLITE_DCHECK_LE(params.quantized_activation_min,
+ params.quantized_activation_max);
+
+ const int input1_shift = params.input1_shift;
+ const int flat_size =
+ MatchingFlatSize(output_shape, input1_shape, input2_shape);
+ const int16 output_activation_min = params.quantized_activation_min;
+ const int16 output_activation_max = params.quantized_activation_max;
- TFLITE_DCHECK(input1_shift == 0 || input2_shift == 0);
- TFLITE_DCHECK_GE(input1_shift, 0);
- TFLITE_DCHECK_GE(input2_shift, 0);
+ TFLITE_DCHECK(input1_shift == 0 || params.input2_shift == 0);
+ TFLITE_DCHECK_LE(input1_shift, 0);
+ TFLITE_DCHECK_LE(params.input2_shift, 0);
const int16* not_shift_input = input1_shift == 0 ? input1_data : input2_data;
const int16* shift_input = input1_shift == 0 ? input2_data : input1_data;
- const int input_shift = input1_shift == 0 ? input2_shift : input1_shift;
+ const int input_right_shift =
+ input1_shift == 0 ? -params.input2_shift : -input1_shift;
for (int i = 0; i < flat_size; i++) {
// F0 uses 0 integer bits, range [-1, 1].
using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
F0 input_ready_scaled = F0::FromRaw(not_shift_input[i]);
- F0 scaled_input =
- F0::FromRaw(gemmlowp::RoundingDivideByPOT(shift_input[i], input_shift));
+ F0 scaled_input = F0::FromRaw(
+ gemmlowp::RoundingDivideByPOT(shift_input[i], input_right_shift));
F0 result = gemmlowp::SaturatingAdd(scaled_input, input_ready_scaled);
const int16 raw_output = result.raw();
const int16 clamped_output = std::min(
@@ -1181,42 +1083,28 @@ inline void Add(const int16* input1_data, const Dims<4>& input1_dims,
}
}
-template <FusedActivationFunctionType Ac>
-inline void Add(const int16* input1_data, const Dims<4>& input1_dims,
- int input1_shift, const int16* input2_data,
- const Dims<4>& input2_dims, int input2_shift,
- int16 output_activation_min, int16 output_activation_max,
- int16* output_data, const Dims<4>& output_dims) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, -32768);
- TFLITE_DCHECK_EQ(output_activation_max, 32767);
- }
-
- Add(input1_data, input1_dims, input1_shift, input2_data, input2_dims,
- input2_shift, output_activation_min, output_activation_max, output_data,
- output_dims);
-}
-
// TODO(jiawen): We can implement BroadcastAdd on buffers of arbitrary
// dimensionality if the runtime code does a single loop over one dimension
// that handles broadcasting as the base case. The code generator would then
// generate max(D1, D2) nested for loops.
-template <typename T>
-void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T output_activation_min, T output_activation_max,
- T* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("BroadcastAdd");
-
+// TODO(benoitjacob): BroadcastAdd is intentionally duplicated from
+// reference_ops.h. Once an optimized version is implemented and NdArrayDesc<T>
+// is no longer referenced in this file, move NdArrayDesc<T> from types.h to
+// reference_ops.h.
+inline void BroadcastAdd4DSlow(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape,
+ const float* input1_data,
+ const RuntimeShape& input2_shape,
+ const float* input2_data,
+ const RuntimeShape& output_shape,
+ float* output_data) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/float");
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+ NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
+ &desc2);
+ RuntimeShape extended_output_shape =
+ RuntimeShape::ExtendedShape(4, output_shape);
// In Tensorflow, the dimensions are canonically named (batch_number, row,
// col, channel), with extents (batches, height, width, depth), with the
@@ -1229,49 +1117,77 @@ void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
// We name our variables by their Tensorflow convention, but generate C code
// nesting loops such that the innermost loop has the smallest stride for the
// best cache behavior.
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- output_data[Offset(output_dims, c, x, y, b)] =
+ for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
+ for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
+ for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
+ for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
+ output_data[Offset(extended_output_shape, b, y, x, c)] =
ActivationFunctionWithMinMax(
- input1_data[SubscriptToIndex(desc1, c, x, y, b)] +
- input2_data[SubscriptToIndex(desc2, c, x, y, b)],
- output_activation_min, output_activation_max);
+ input1_data[SubscriptToIndex(desc1, b, y, x, c)] +
+ input2_data[SubscriptToIndex(desc2, b, y, x, c)],
+ params.float_activation_min, params.float_activation_max);
}
}
}
}
}
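A small worked case for the broadcast indexing used above, assuming the RuntimeShape overload of NdArrayDescsForElementwiseBroadcast behaves like the Dims<4> version whose comment block this patch removes (a mismatched extent-1 dimension takes the other input's extent and stride 0):

//   // input1_shape = {1, 2, 2, 3} (NHWC), input2_shape = {1, 1, 1, 3}.
//   // desc2 then reports extents {1, 2, 2, 3} with stride 0 on the two
//   // mismatched spatial dimensions, so SubscriptToIndex(desc2, b, y, x, c)
//   // collapses to c: the same three input2 values are re-read at every
//   // (y, x), i.e. input2 is broadcast across input1's 2x2 spatial grid.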
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac, typename T>
-void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T* output_data, const Dims<4>& output_dims) {
- T output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+inline void BroadcastAdd4DSlow(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape,
+ const int32* input1_data,
+ const RuntimeShape& input2_shape,
+ const int32* input2_data,
+ const RuntimeShape& output_shape,
+ int32* output_data) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/int32");
+ NdArrayDesc<4> desc1;
+ NdArrayDesc<4> desc2;
+ NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
+ &desc2);
+ RuntimeShape extended_output_shape =
+ RuntimeShape::ExtendedShape(4, output_shape);
- BroadcastAdd(input1_data, input1_dims, input2_data, input2_dims,
- output_activation_min, output_activation_max, output_data,
- output_dims);
+ // In Tensorflow, the dimensions are canonically named (batch_number, row,
+ // col, channel), with extents (batches, height, width, depth), with the
+ // trailing dimension changing most rapidly (channels has the smallest stride,
+ // typically 1 element).
+ //
+ // In generated C code, we store arrays with the dimensions reversed. The
+ // first dimension has smallest stride.
+ //
+ // We name our variables by their Tensorflow convention, but generate C code
+ // nesting loops such that the innermost loop has the smallest stride for the
+ // best cache behavior.
+ for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
+ for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
+ for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
+ for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
+ output_data[Offset(extended_output_shape, b, y, x, c)] =
+ ActivationFunctionWithMinMax(
+ input1_data[SubscriptToIndex(desc1, b, y, x, c)] +
+ input2_data[SubscriptToIndex(desc2, b, y, x, c)],
+ params.quantized_activation_min,
+ params.quantized_activation_max);
+ }
+ }
+ }
+ }
}
-inline void BroadcastAdd(int left_shift, const uint8* input1_data,
- const Dims<4>& input1_dims, int32 input1_offset,
- int32 input1_multiplier, int input1_shift,
- const uint8* input2_data, const Dims<4>& input2_dims,
- int32 input2_offset, int32 input2_multiplier,
- int input2_shift, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("BroadcastAdd/8bit");
-
+inline void BroadcastAdd4DSlow(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape,
+ const uint8* input1_data,
+ const RuntimeShape& input2_shape,
+ const uint8* input2_data,
+ const RuntimeShape& output_shape,
+ uint8* output_data) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/uint8");
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+ NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
+ &desc2);
+ RuntimeShape extended_output_shape =
+ RuntimeShape::ExtendedShape(4, output_shape);
// In Tensorflow, the dimensions are canonically named (batch_number, row,
// col, channel), with extents (batches, height, width, depth), with the
@@ -1284,33 +1200,37 @@ inline void BroadcastAdd(int left_shift, const uint8* input1_data,
// We name our variables by their Tensorflow convention, but generate C code
// nesting loops such that the innermost loop has the smallest stride for the
// best cache behavior.
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
+ for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
+ for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
+ for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
+ for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
const int32 input1_val =
- input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)];
+ params.input1_offset +
+ input1_data[SubscriptToIndex(desc1, b, y, x, c)];
const int32 input2_val =
- input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)];
- const int32 shifted_input1_val = input1_val * (1 << left_shift);
- const int32 shifted_input2_val = input2_val * (1 << left_shift);
+ params.input2_offset +
+ input2_data[SubscriptToIndex(desc2, b, y, x, c)];
+ const int32 shifted_input1_val =
+ input1_val * (1 << params.left_shift);
+ const int32 shifted_input2_val =
+ input2_val * (1 << params.left_shift);
const int32 scaled_input1_val =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
- shifted_input1_val, input1_multiplier,
- kReverseShift * input1_shift);
+ shifted_input1_val, params.input1_multiplier,
+ params.input1_shift);
const int32 scaled_input2_val =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
- shifted_input2_val, input2_multiplier,
- kReverseShift * input2_shift);
+ shifted_input2_val, params.input2_multiplier,
+ params.input2_shift);
const int32 raw_sum = scaled_input1_val + scaled_input2_val;
const int32 raw_output =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
- raw_sum, output_multiplier, kReverseShift * output_shift) +
- output_offset;
+ raw_sum, params.output_multiplier, params.output_shift) +
+ params.output_offset;
const int32 clamped_output =
- std::min(output_activation_max,
- std::max(output_activation_min, raw_output));
- output_data[Offset(output_dims, c, x, y, b)] =
+ std::min(params.quantized_activation_max,
+ std::max(params.quantized_activation_min, raw_output));
+ output_data[Offset(extended_output_shape, b, y, x, c)] =
static_cast<uint8>(clamped_output);
}
}
@@ -1318,117 +1238,62 @@ inline void BroadcastAdd(int left_shift, const uint8* input1_data,
}
}
-inline void BroadcastAddFivefold(
- int y0, int y1, int y2, int y3, int y4, int left_shift,
- const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset,
- int32 input1_multiplier, int input1_shift, const uint8* input2_data,
- const Dims<4>& input2_dims, int32 input2_offset, int32 input2_multiplier,
- int input2_shift, int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("BroadcastAddFivefold/8bit");
-
- int sb1 = y0;
- int sa2 = y0;
- int sb2 = y0 * y1;
- int sa3 = y0 * y2;
- int sa4 = y0 * y2 * y3;
- int sb4 = y0 * y1 * y2;
-
+inline void BroadcastAddFivefold(const ArithmeticParams& unswitched_params,
+ const RuntimeShape& unswitched_input1_shape,
+ const uint8* unswitched_input1_data,
+ const RuntimeShape& unswitched_input2_shape,
+ const uint8* unswitched_input2_data,
+ const RuntimeShape& output_shape,
+ uint8* output_data) {
+ ArithmeticParams switched_params = unswitched_params;
+ switched_params.input1_offset = unswitched_params.input2_offset;
+ switched_params.input1_multiplier = unswitched_params.input2_multiplier;
+ switched_params.input1_shift = unswitched_params.input2_shift;
+ switched_params.input2_offset = unswitched_params.input1_offset;
+ switched_params.input2_multiplier = unswitched_params.input1_multiplier;
+ switched_params.input2_shift = unswitched_params.input1_shift;
+
+ const bool use_unswitched =
+ unswitched_params.broadcast_category ==
+ tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;
+
+ const ArithmeticParams& params =
+ use_unswitched ? unswitched_params : switched_params;
+ const uint8* input1_data =
+ use_unswitched ? unswitched_input1_data : unswitched_input2_data;
+ const uint8* input2_data =
+ use_unswitched ? unswitched_input2_data : unswitched_input1_data;
+
+ // Fivefold nested loops. The second input resets its position for each
+ // iteration of the second loop. The first input resets its position at the
+ // beginning of the fourth loop. The innermost loop is an elementwise add of
+ // sections of the arrays.
uint8* output_data_ptr = output_data;
- for (int i4 = 0; i4 < y4; ++i4) {
- for (int i3 = 0; i3 < y3; ++i3) {
+ const uint8* input1_data_ptr = input1_data;
+ const uint8* input2_data_reset = input2_data;
+ int y0 = params.broadcast_shape[0];
+ int y1 = params.broadcast_shape[1];
+ int y2 = params.broadcast_shape[2];
+ int y3 = params.broadcast_shape[3];
+ int y4 = params.broadcast_shape[4];
+ for (int i0 = 0; i0 < y0; ++i0) {
+ const uint8* input2_data_ptr;
+ for (int i1 = 0; i1 < y1; ++i1) {
+ input2_data_ptr = input2_data_reset;
for (int i2 = 0; i2 < y2; ++i2) {
- for (int i1 = 0; i1 < y1; ++i1) {
- for (int i0 = 0; i0 < y0; ++i0) {
- const int32 input1_val =
- input1_offset +
- input1_data[i4 * sa4 + i3 * sa3 + i2 * sa2 + i0];
- const int32 input2_val =
- input2_offset +
- input2_data[i4 * sb4 + i2 * sb2 + i1 * sb1 + i0];
- const int32 shifted_input1_val = input1_val * (1 << left_shift);
- const int32 shifted_input2_val = input2_val * (1 << left_shift);
- const int32 scaled_input1_val =
- MultiplyByQuantizedMultiplierSmallerThanOneExp(
- shifted_input1_val, input1_multiplier,
- kReverseShift * input1_shift);
- const int32 scaled_input2_val =
- MultiplyByQuantizedMultiplierSmallerThanOneExp(
- shifted_input2_val, input2_multiplier,
- kReverseShift * input2_shift);
- const int32 raw_sum = scaled_input1_val + scaled_input2_val;
- const int32 raw_output =
- MultiplyByQuantizedMultiplierSmallerThanOneExp(
- raw_sum, output_multiplier, kReverseShift * output_shift) +
- output_offset;
- const int32 clamped_output =
- std::min(output_activation_max,
- std::max(output_activation_min, raw_output));
- *output_data_ptr = static_cast<uint8>(clamped_output);
- ++output_data_ptr;
- }
+ for (int i3 = 0; i3 < y3; ++i3) {
+ AddElementwise(y4, params, input1_data_ptr, input2_data_ptr,
+ output_data_ptr);
+ input2_data_ptr += y4;
+ output_data_ptr += y4;
}
+ input1_data_ptr += y4;
}
}
+ input2_data_reset = input2_data_ptr;
}
}
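One decomposition consistent with the loop structure above, as an illustration only (whether the surrounding runtime code would pick exactly this split is not shown in this hunk): a per-channel add of a length-C vector (input2) onto an N x H x W x C tensor (input1) can be run as

//   params.broadcast_shape[0] = 1;          // y0: outermost; input2 resumes here
//   params.broadcast_shape[1] = N * H * W;  // y1: input1 advances, input2 rewinds
//   params.broadcast_shape[2] = 1;          // y2
//   params.broadcast_shape[3] = 1;          // y3: input1 is held fixed across it
//   params.broadcast_shape[4] = C;          // y4: length of each AddElementwise run
//
// so each inner call adds C contiguous values and the small input's pointer is
// rewound for every one of the N * H * W sections of the large input.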
-template <FusedActivationFunctionType Ac>
-inline void BroadcastAdd(int left_shift, const uint8* input1_data,
- const Dims<4>& input1_dims, int32 input1_offset,
- int32 input1_multiplier, int input1_shift,
- const uint8* input2_data, const Dims<4>& input2_dims,
- int32 input2_offset, int32 input2_multiplier,
- int input2_shift, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, 0);
- TFLITE_DCHECK_EQ(output_activation_max, 255);
- }
- BroadcastAdd(left_shift, input1_data, input1_dims, input1_offset,
- input1_multiplier, input1_shift, input2_data, input2_dims,
- input2_offset, input2_multiplier, input2_shift, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-template <FusedActivationFunctionType Ac>
-inline void BroadcastAddFivefold(
- int y0, int y1, int y2, int y3, int y4, int left_shift,
- const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset,
- int32 input1_multiplier, int input1_shift, const uint8* input2_data,
- const Dims<4>& input2_dims, int32 input2_offset, int32 input2_multiplier,
- int input2_shift, int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, 0);
- TFLITE_DCHECK_EQ(output_activation_max, 255);
- }
- BroadcastAddFivefold(y0, y1, y2, y3, y4, left_shift, input1_data, input1_dims,
- input1_offset, input1_multiplier, input1_shift,
- input2_data, input2_dims, input2_offset,
- input2_multiplier, input2_shift, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
template <typename T>
inline void Mul(const T* input1_data, const Dims<4>& input1_dims,
const T* input2_data, const Dims<4>& input2_dims,
@@ -1654,10 +1519,11 @@ void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
}
}
-inline void Div(const float* input1_data, const Dims<4>& input1_dims,
- const float* input2_data, const Dims<4>& input2_dims,
- float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims) {
+template <typename T>
+inline void Div(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T output_activation_min, T output_activation_max,
+ T* output_data, const Dims<4>& output_dims) {
const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
for (int i = 0; i < flat_size; ++i) {
output_data[i] = ActivationFunctionWithMinMax(
@@ -1666,15 +1532,35 @@ inline void Div(const float* input1_data, const Dims<4>& input1_dims,
}
}
-inline void Sub(const float* input1_data, const Dims<4>& input1_dims,
- const float* input2_data, const Dims<4>& input2_dims,
- float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
+inline void SubNonBroadcast(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape,
+ const float* input1_data,
+ const RuntimeShape& input2_shape,
+ const float* input2_data,
+ const RuntimeShape& output_shape,
+ float* output_data) {
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
output_data[i] = ActivationFunctionWithMinMax(
- input1_data[i] - input2_data[i], output_activation_min,
- output_activation_max);
+ input1_data[i] - input2_data[i], params.float_activation_min,
+ params.float_activation_max);
+ }
+}
+
+inline void SubNonBroadcast(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape,
+ const int32* input1_data,
+ const RuntimeShape& input2_shape,
+ const int32* input2_data,
+ const RuntimeShape& output_shape,
+ int32* output_data) {
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
+ for (int i = 0; i < flat_size; ++i) {
+ output_data[i] = ActivationFunctionWithMinMax(
+ input1_data[i] - input2_data[i], params.quantized_activation_min,
+ params.quantized_activation_max);
}
}
@@ -1682,16 +1568,24 @@ inline void Sub(const float* input1_data, const Dims<4>& input1_dims,
// dimensionality if the runtime code does a single loop over one dimension
// that handles broadcasting as the base case. The code generator would then
// generate max(D1, D2) nested for loops.
-template <typename T>
-void BroadcastSub(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T output_activation_min, T output_activation_max,
- T* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("BroadcastSub");
-
+// TODO(benoitjacob): BroadcastSub is intentionally duplicated from
+// reference_ops.h. Once an optimized version is implemented and NdArrayDesc<T>
+// is no longer referenced in this file, move NdArrayDesc<T> from types.h to
+// reference_ops.h.
+inline void BroadcastSub4DSlow(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape,
+ const float* input1_data,
+ const RuntimeShape& input2_shape,
+ const float* input2_data,
+ const RuntimeShape& output_shape,
+ float* output_data) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastSub4DSlow/float");
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+ NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
+ &desc2);
+ RuntimeShape extended_output_shape =
+ RuntimeShape::ExtendedShape(4, output_shape);
// In Tensorflow, the dimensions are canonically named (batch_number, row,
// col, channel), with extents (batches, height, width, depth), with the
@@ -1704,36 +1598,35 @@ void BroadcastSub(const T* input1_data, const Dims<4>& input1_dims,
// We name our variables by their Tensorflow convention, but generate C code
// nesting loops such that the innermost loop has the smallest stride for the
// best cache behavior.
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- output_data[Offset(output_dims, c, x, y, b)] =
+ for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
+ for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
+ for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
+ for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
+ output_data[Offset(extended_output_shape, b, y, x, c)] =
ActivationFunctionWithMinMax(
- input1_data[SubscriptToIndex(desc1, c, x, y, b)] -
- input2_data[SubscriptToIndex(desc2, c, x, y, b)],
- output_activation_min, output_activation_max);
+ input1_data[SubscriptToIndex(desc1, b, y, x, c)] -
+ input2_data[SubscriptToIndex(desc2, b, y, x, c)],
+ params.float_activation_min, params.float_activation_max);
}
}
}
}
}
-inline void BroadcastSub(int left_shift, const uint8* input1_data,
- const Dims<4>& input1_dims, int32 input1_offset,
- int32 input1_multiplier, int input1_shift,
- const uint8* input2_data, const Dims<4>& input2_dims,
- int32 input2_offset, int32 input2_multiplier,
- int input2_shift, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("BroadcastSub/8bit");
-
+inline void BroadcastSub4DSlow(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape,
+ const uint8* input1_data,
+ const RuntimeShape& input2_shape,
+ const uint8* input2_data,
+ const RuntimeShape& output_shape,
+ uint8* output_data) {
+  gemmlowp::ScopedProfilingLabel label("BroadcastSub4DSlow/uint8");
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+ NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
+ &desc2);
+ RuntimeShape extended_output_shape =
+ RuntimeShape::ExtendedShape(4, output_shape);
// In Tensorflow, the dimensions are canonically named (batch_number, row,
// col, channel), with extents (batches, height, width, depth), with the
@@ -1746,33 +1639,37 @@ inline void BroadcastSub(int left_shift, const uint8* input1_data,
// We name our variables by their Tensorflow convention, but generate C code
// nesting loops such that the innermost loop has the smallest stride for the
// best cache behavior.
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
+ for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
+ for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
+ for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
+ for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
const int32 input1_val =
- input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)];
+ params.input1_offset +
+ input1_data[SubscriptToIndex(desc1, b, y, x, c)];
const int32 input2_val =
- input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)];
- const int32 shifted_input1_val = input1_val * (1 << left_shift);
- const int32 shifted_input2_val = input2_val * (1 << left_shift);
+ params.input2_offset +
+ input2_data[SubscriptToIndex(desc2, b, y, x, c)];
+ const int32 shifted_input1_val =
+ input1_val * (1 << params.left_shift);
+ const int32 shifted_input2_val =
+ input2_val * (1 << params.left_shift);
const int32 scaled_input1_val =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
- shifted_input1_val, input1_multiplier,
- kReverseShift * input1_shift);
+ shifted_input1_val, params.input1_multiplier,
+ params.input1_shift);
const int32 scaled_input2_val =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
- shifted_input2_val, input2_multiplier,
- kReverseShift * input2_shift);
+ shifted_input2_val, params.input2_multiplier,
+ params.input2_shift);
const int32 raw_sub = scaled_input1_val - scaled_input2_val;
const int32 raw_output =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
- raw_sub, output_multiplier, kReverseShift * output_shift) +
- output_offset;
+ raw_sub, params.output_multiplier, params.output_shift) +
+ params.output_offset;
const int32 clamped_output =
- std::min(output_activation_max,
- std::max(output_activation_min, raw_output));
- output_data[Offset(output_dims, c, x, y, b)] =
+ std::min(params.quantized_activation_max,
+ std::max(params.quantized_activation_min, raw_output));
+ output_data[Offset(extended_output_shape, b, y, x, c)] =
static_cast<uint8>(clamped_output);
}
}
@@ -1780,6 +1677,156 @@ inline void BroadcastSub(int left_shift, const uint8* input1_data,
}
}
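The uint8 path above expects the caller to precompute the fixed-point parameters. A rough sketch of how a kernel would typically fill them (illustrative only; the helper name is hypothetical and the value 20 follows the usual TFLite uint8 Add/Sub pattern, none of which is part of this hunk):

    // Hypothetical helper, for illustration only.
    inline void FillUint8SubParams(int32_t input1_zero_point,
                                   int32_t input2_zero_point,
                                   int32_t output_zero_point,
                                   tflite::ArithmeticParams* op_params) {
      op_params->left_shift = 20;  // headroom consumed by shifted_input*_val
      op_params->input1_offset = -input1_zero_point;  // offsets negate zero points
      op_params->input2_offset = -input2_zero_point;
      op_params->output_offset = output_zero_point;
      // The input/output multiplier+shift pairs are derived from the tensor
      // scales by the QuantizeMultiplier-style helpers in quantization_util.h,
      // and the clamping range goes into quantized_activation_{min,max}.
    }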
+inline void BroadcastSub4DSlow(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape,
+ const int32* input1_data,
+ const RuntimeShape& input2_shape,
+ const int32* input2_data,
+ const RuntimeShape& output_shape,
+ int32* output_data) {
+  gemmlowp::ScopedProfilingLabel label("BroadcastSub4DSlow/int32");
+ NdArrayDesc<4> desc1;
+ NdArrayDesc<4> desc2;
+ NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
+ &desc2);
+ RuntimeShape extended_output_shape =
+ RuntimeShape::ExtendedShape(4, output_shape);
+
+ // In Tensorflow, the dimensions are canonically named (batch_number, row,
+ // col, channel), with extents (batches, height, width, depth), with the
+ // trailing dimension changing most rapidly (channels has the smallest stride,
+ // typically 1 element).
+ //
+ // In generated C code, we store arrays with the dimensions reversed. The
+ // first dimension has smallest stride.
+ //
+ // We name our variables by their Tensorflow convention, but generate C code
+ // nesting loops such that the innermost loop has the smallest stride for the
+ // best cache behavior.
+ for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
+ for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
+ for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
+ for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
+ output_data[Offset(extended_output_shape, b, y, x, c)] =
+ ActivationFunctionWithMinMax(
+ input1_data[SubscriptToIndex(desc1, b, y, x, c)] -
+ input2_data[SubscriptToIndex(desc2, b, y, x, c)],
+ params.quantized_activation_min,
+ params.quantized_activation_max);
+ }
+ }
+ }
+ }
+}
+
+template <typename T>
+void BroadcastSub4DSlow(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const T* input1_data,
+ const RuntimeShape& input2_shape, const T* input2_data,
+ const RuntimeShape& output_shape, T* output_data) {
+  gemmlowp::ScopedProfilingLabel label("BroadcastSub4DSlow/templated");
+ NdArrayDesc<4> desc1;
+ NdArrayDesc<4> desc2;
+ NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
+ &desc2);
+ RuntimeShape extended_output_shape =
+ RuntimeShape::ExtendedShape(4, output_shape);
+
+ // In Tensorflow, the dimensions are canonically named (batch_number, row,
+ // col, channel), with extents (batches, height, width, depth), with the
+ // trailing dimension changing most rapidly (channels has the smallest stride,
+ // typically 1 element).
+ //
+ // In generated C code, we store arrays with the dimensions reversed. The
+ // first dimension has smallest stride.
+ //
+ // We name our variables by their Tensorflow convention, but generate C code
+ // nesting loops such that the innermost loop has the smallest stride for the
+ // best cache behavior.
+ for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
+ for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
+ for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
+ for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
+ output_data[Offset(extended_output_shape, b, y, x, c)] =
+ ActivationFunctionWithMinMax(
+ input1_data[SubscriptToIndex(desc1, b, y, x, c)] -
+ input2_data[SubscriptToIndex(desc2, b, y, x, c)],
+ params.quantized_activation_min,
+ params.quantized_activation_max);
+ }
+ }
+ }
+ }
+}
+
+template <typename T>
+void Sub(const ArithmeticParams& params, const RuntimeShape& input1_shape,
+ const T* input1_data, const RuntimeShape& input2_shape,
+ const T* input2_data, const RuntimeShape& output_shape,
+ T* output_data) {
+ NdArrayDesc<4> desc1;
+ NdArrayDesc<4> desc2;
+ NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
+ &desc2);
+ RuntimeShape extended_output_shape =
+ RuntimeShape::ExtendedShape(4, output_shape);
+
+ // In Tensorflow, the dimensions are canonically named (batch_number, row,
+ // col, channel), with extents (batches, height, width, depth), with the
+ // trailing dimension changing most rapidly (channels has the smallest stride,
+ // typically 1 element).
+ //
+ // In generated C code, we store arrays with the dimensions reversed. The
+ // first dimension has smallest stride.
+ //
+ // We name our variables by their Tensorflow convention, but generate C code
+ // nesting loops such that the innermost loop has the smallest stride for the
+ // best cache behavior.
+ for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
+ for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
+ for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
+ for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
+ output_data[Offset(extended_output_shape, b, y, x, c)] =
+ input1_data[SubscriptToIndex(desc1, b, y, x, c)] -
+ input2_data[SubscriptToIndex(desc2, b, y, x, c)];
+ }
+ }
+ }
+ }
+}
+
+inline void SubWithActivation(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape,
+ const int32* input1_data,
+ const RuntimeShape& input2_shape,
+ const int32* input2_data,
+ const RuntimeShape& output_shape,
+ int32* output_data) {
+  const int flat_size =
+      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+ for (int i = 0; i < flat_size; ++i) {
+ output_data[i] = ActivationFunctionWithMinMax(
+ input1_data[i] - input2_data[i], params.quantized_activation_min,
+ params.quantized_activation_max);
+ }
+}
+
+inline void SubWithActivation(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape,
+ const float* input1_data,
+ const RuntimeShape& input2_shape,
+ const float* input2_data,
+ const RuntimeShape& output_shape,
+ float* output_data) {
+  const int flat_size =
+      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+ for (int i = 0; i < flat_size; ++i) {
+ output_data[i] = ActivationFunctionWithMinMax(
+ input1_data[i] - input2_data[i], params.float_activation_min,
+ params.float_activation_max);
+ }
+}
+
template <FusedActivationFunctionType Ac, typename Scalar>
void Concatenation(int concat_dim, const Scalar* const* input_data,
const Dims<4>* const* input_dims, int inputs_count,
@@ -1813,6 +1860,26 @@ void Concatenation(int concat_dim, const Scalar* const* input_data,
}
}
+template <typename Scalar>
+void Pack(int dim, const Scalar* const* input_data,
+ const Dims<4>* const* input_dims, int inputs_count,
+ Scalar* output_data, const Dims<4>& output_dims) {
+ TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
+ int outer_size = 1;
+ for (int i = dim + 1; i < 4; i++) {
+ outer_size *= output_dims.sizes[i];
+ }
+ Scalar* output_ptr = output_data;
+ const int copy_size = FlatSize(**input_dims) / outer_size;
+ for (int k = 0; k < outer_size; k++) {
+ for (int i = 0; i < inputs_count; ++i) {
+ memcpy(output_ptr, input_data[i] + k * copy_size,
+ copy_size * sizeof(Scalar));
+ output_ptr += copy_size;
+ }
+ }
+}
+
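Concretely, matching the pack_test.cc cases added later in this change:

    // Packing three {2}-shaped inputs {1,4}, {2,5}, {3,6} along dim 0 yields a
    // {3,2} output laid out as {1,4, 2,5, 3,6}; packing them along dim 1 yields
    // {2,3} laid out as {1,2,3, 4,5,6} -- each outer step copies copy_size
    // elements from every input in turn.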
// TODO(prabhumk): This is the same as the optimized implementation.
// TODO(prabhumk): The quantized implementation of concatenation isn't fully
// quantized as it takes scale as a floating point value. This should be fixed
@@ -3467,9 +3534,9 @@ inline bool Reduce(const In* input_data, const int* input_dims,
const int* output_dims, const int input_num_dims,
const int output_num_dims, const int* axis,
const int num_axis, int* input_iter,
- Out reducer(Out current, const In in), Out* output_data) {
+ Out reducer(const Out current, const In in),
+ Out* output_data) {
// Reset input iterator.
- TFLITE_DCHECK(input_num_dims > 0);
for (int idx = 0; idx < input_num_dims; ++idx) {
input_iter[idx] = 0;
}
@@ -3485,11 +3552,16 @@ inline bool Reduce(const In* input_data, const int* input_dims,
return true;
}
-inline bool ResolveAxis(const int num_dims, const int* axis, const int num_axis,
- int* out_axis, int* out_num_axis) {
+inline bool ResolveAxis(const int num_dims, const int* axis,
+ const int64_t num_axis, int* out_axis,
+ int* out_num_axis) {
*out_num_axis = 0; // Just in case.
+ // Short-circuit axis resolution for scalars; the axis will go unused.
+ if (num_dims == 0) {
+ return true;
+ }
   // O(n^2) is fine since out_num_axis should be really small, mostly <= 4
- for (int idx = 0; idx < num_axis; ++idx) {
+ for (int64_t idx = 0; idx < num_axis; ++idx) {
// Handle negative index.
int current = axis[idx] < 0 ? (axis[idx] + num_dims) : axis[idx];
TFLITE_DCHECK(current >= 0 && current < num_dims);
@@ -3515,7 +3587,7 @@ inline bool ReduceSumImpl(const In* input_data, const int* input_dims,
const int output_num_dims, const int* axis,
const int num_axis, int* input_iter,
Out* output_data) {
- auto reducer = [](Out current, const In in) -> Out {
+ auto reducer = [](const Out current, const In in) -> Out {
const Out actual_in = static_cast<Out>(in);
return current + actual_in;
};
@@ -3524,6 +3596,24 @@ inline bool ReduceSumImpl(const In* input_data, const int* input_dims,
output_data);
}
+template <typename T>
+inline bool InitTensorDataForReduce(const int* dims, const int num_dims,
+ const T init_value, T* data) {
+ size_t num_elements = 1;
+ for (int idx = 0; idx < num_dims; ++idx) {
+ size_t current = static_cast<size_t>(dims[idx]);
+ // Overflow prevention.
+ if (num_elements > std::numeric_limits<size_t>::max() / current) {
+ return false;
+ }
+ num_elements *= current;
+ }
+ for (size_t idx = 0; idx < num_elements; ++idx) {
+ data[idx] = init_value;
+ }
+ return true;
+}
+
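A quick illustration of the failure mode the guard prevents (sketch):

    // With dims = {1 << 20, 1 << 20, 1 << 24} and a 32-bit size_t, the running
    // product would wrap; the check num_elements > SIZE_MAX / current fires
    // before the multiply, so InitTensorDataForReduce() returns false and the
    // caller can bail out instead of initializing too few elements.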
// Computes the sum of elements across dimensions given in axis.
template <typename T>
inline bool Sum(const T* input_data, const int* input_dims,
@@ -3532,17 +3622,9 @@ inline bool Sum(const T* input_data, const int* input_dims,
const int* axis, const int num_axis_dimensions, bool keep_dims,
int* temp_index, int* resolved_axis) {
// Reset output data.
- size_t num_outputs = 1;
- for (int idx = 0; idx < output_num_dims; ++idx) {
- size_t current = static_cast<size_t>(output_dims[idx]);
- // Overflow prevention.
- if (num_outputs > std::numeric_limits<size_t>::max() / current) {
- return false;
- }
- num_outputs *= current;
- }
- for (size_t idx = 0; idx < num_outputs; ++idx) {
- output_data[idx] = T();
+ if (!InitTensorDataForReduce(output_dims, output_num_dims, static_cast<T>(0),
+ output_data)) {
+ return false;
}
// Resolve axis.
@@ -3557,6 +3639,61 @@ inline bool Sum(const T* input_data, const int* input_dims,
num_resolved_axis, temp_index, output_data);
}
+// Computes the max of elements across dimensions given in axis.
+template <typename T>
+inline bool ReduceMax(const T* input_data, const int* input_dims,
+ const int input_num_dims, T* output_data,
+ const int* output_dims, const int output_num_dims,
+ const int* axis, const int64_t num_axis_dimensions,
+ bool keep_dims, int* temp_index, int* resolved_axis) {
+ T init_value = std::numeric_limits<T>::lowest();
+ // Reset output data.
+ if (!InitTensorDataForReduce(output_dims, output_num_dims, init_value,
+ output_data)) {
+ return false;
+ }
+
+ // Resolve axis.
+ int num_resolved_axis = 0;
+ if (!ResolveAxis(input_num_dims, axis, num_axis_dimensions, resolved_axis,
+ &num_resolved_axis)) {
+ return false;
+ }
+
+ auto reducer = [](const T current, const T in) -> T {
+ return (in > current) ? in : current;
+ };
+ return Reduce<T, T>(input_data, input_dims, output_dims, input_num_dims,
+ output_num_dims, resolved_axis, num_resolved_axis,
+ temp_index, reducer, output_data);
+}
+
+// Computes the prod of elements across dimensions given in axis.
+template <typename T>
+inline bool ReduceProd(const T* input_data, const int* input_dims,
+ const int input_num_dims, T* output_data,
+ const int* output_dims, const int output_num_dims,
+ const int* axis, const int64_t num_axis_dimensions,
+ bool keep_dims, int* temp_index, int* resolved_axis) {
+ // Reset output data.
+ if (!InitTensorDataForReduce(output_dims, output_num_dims, static_cast<T>(1),
+ output_data)) {
+ return false;
+ }
+
+ // Resolve axis.
+ int num_resolved_axis = 0;
+ if (!ResolveAxis(input_num_dims, axis, num_axis_dimensions, resolved_axis,
+ &num_resolved_axis)) {
+ return false;
+ }
+
+ auto reducer = [](const T current, const T in) -> T { return in * current; };
+ return Reduce<T, T>(input_data, input_dims, output_dims, input_num_dims,
+ output_num_dims, resolved_axis, num_resolved_axis,
+ temp_index, reducer, output_data);
+}
+
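With the shared Reduce driver and InitTensorDataForReduce in place, further reductions are mechanical. For illustration only (not added by this change), a ReduceMin would differ from ReduceMax only in the init value and the reducer:

    // Hypothetical ReduceMin, mirroring ReduceMax above.
    template <typename T>
    inline bool ReduceMin(const T* input_data, const int* input_dims,
                          const int input_num_dims, T* output_data,
                          const int* output_dims, const int output_num_dims,
                          const int* axis, const int64_t num_axis_dimensions,
                          bool keep_dims, int* temp_index, int* resolved_axis) {
      T init_value = std::numeric_limits<T>::max();
      if (!InitTensorDataForReduce(output_dims, output_num_dims, init_value,
                                   output_data)) {
        return false;
      }
      int num_resolved_axis = 0;
      if (!ResolveAxis(input_num_dims, axis, num_axis_dimensions, resolved_axis,
                       &num_resolved_axis)) {
        return false;
      }
      auto reducer = [](const T current, const T in) -> T {
        return (in < current) ? in : current;
      };
      return Reduce<T, T>(input_data, input_dims, output_dims, input_num_dims,
                          output_num_dims, resolved_axis, num_resolved_axis,
                          temp_index, reducer, output_data);
    }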
// Computes the mean of elements across dimensions given in axis.
// It does so in two stages, first calculates the sum of elements along the axis
// then divides it by the number of element in axis.
@@ -3649,38 +3786,6 @@ inline void Mean(const T* input_data, const Dims<4>& input_dims,
}
template <typename T>
-void Sub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data,
- const Dims<4>& input2_dims, T* output_data,
- const Dims<4>& output_dims) {
- NdArrayDesc<4> desc1;
- NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
-
- // In Tensorflow, the dimensions are canonically named (batch_number, row,
- // col, channel), with extents (batches, height, width, depth), with the
- // trailing dimension changing most rapidly (channels has the smallest stride,
- // typically 1 element).
- //
- // In generated C code, we store arrays with the dimensions reversed. The
- // first dimension has smallest stride.
- //
- // We name our variables by their Tensorflow convention, but generate C code
- // nesting loops such that the innermost loop has the smallest stride for the
- // best cache behavior.
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- output_data[Offset(output_dims, c, x, y, b)] =
- input1_data[SubscriptToIndex(desc1, c, x, y, b)] -
- input2_data[SubscriptToIndex(desc2, c, x, y, b)];
- }
- }
- }
- }
-}
-
-template <typename T>
void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims,
const T* input2_data, T* output_data,
const Dims<4>& output_dims) {
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h
index 5160e22307..82f4503127 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h
@@ -124,6 +124,10 @@ void Sub1Vector(const float* vector, int v_size, float* result);
// Fill vector with 0.f.
void ZeroVector(float* vector, int v_size);
+// Multiply all elements of a vector by a scalar.
+void VectorScalarMultiply(const int8_t* vector, int v_size, float scale,
+ float* result);
+
// Clip elements of a vector using a abs_limit value.
void ClipVector(const float* vector, int v_size, float abs_limit,
float* result);
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc
index aa0d49ae4d..372a6efec5 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc
@@ -32,6 +32,22 @@ TEST(uKernels, ClipTest) {
{0.0, -0.5, 1.0, -1.5, 2.0, -2.0, 2.0, -2.0, 2.0, -2.0})));
}
+TEST(uKernels, VectorScalarMultiply) {
+ constexpr int kVectorSize = 29;
+ static int8_t input[kVectorSize];
+  for (int i = 0; i < kVectorSize; ++i) {
+ input[i] = static_cast<int8_t>(i - 14);
+ }
+ const float scale = 0.1f;
+ std::vector<float> output(kVectorSize, 0.0f);
+ VectorScalarMultiply(input, kVectorSize, scale, output.data());
+ EXPECT_THAT(output,
+ ElementsAreArray(ArrayFloatNear(
+ {-1.4, -1.3, -1.2, -1.1, -1.0, -0.9, -0.8, -0.7, -0.6, -0.5,
+ -0.4, -0.3, -0.2, -0.1, 0, 0.1, 0.2, 0.3, 0.4, 0.5,
+ 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4})));
+}
+
TEST(uKernels, IsZeroTest) {
constexpr int kVectorSize = 21;
static float zeros[kVectorSize] = {0.0};
diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h
index 737cfb69c9..c44698b677 100644
--- a/tensorflow/contrib/lite/kernels/internal/types.h
+++ b/tensorflow/contrib/lite/kernels/internal/types.h
@@ -119,6 +119,8 @@ class RuntimeShape {
// larger shapes are separately allocated.
static constexpr int kMaxSmallSize = 4;
+ RuntimeShape& operator=(RuntimeShape const&) = delete;
+
RuntimeShape() : size_(0) {}
explicit RuntimeShape(int dimensions_count) : size_(dimensions_count) {
@@ -135,6 +137,20 @@ class RuntimeShape {
BuildFrom(init_list);
}
+ // Avoid using this constructor. We should be able to delete it when C++17
+ // rolls out.
+ RuntimeShape(RuntimeShape const& other) : size_(other.DimensionsCount()) {
+ if (size_ > kMaxSmallSize) {
+ dims_pointer_ = new int32[size_];
+ }
+ std::memcpy(DimsData(), other.DimsData(), sizeof(int32) * size_);
+ }
+
+ bool operator==(const RuntimeShape& comp) const {
+ return this->size_ == comp.size_ &&
+ std::memcmp(DimsData(), comp.DimsData(), size_ * sizeof(int32)) == 0;
+ }
+
~RuntimeShape() {
if (size_ > kMaxSmallSize) {
delete[] dims_pointer_;
@@ -191,6 +207,16 @@ class RuntimeShape {
}
}
+ // This will probably be factored out. Old code made substantial use of 4-D
+ // shapes, and so this function is used to extend smaller shapes. Note that
+ // (a) as Dims<4>-dependent code is eliminated, the reliance on this should be
+  // reduced, and (b) some kernels are strictly 4-D, but then the shapes of their
+ // inputs should already be 4-D, so this function should not be needed.
+ inline static RuntimeShape ExtendedShape(int new_shape_size,
+ const RuntimeShape& shape) {
+ return RuntimeShape(new_shape_size, shape, 1);
+ }
+
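For example (sketch, using the initializer-list overload referenced just below):

    // tflite::RuntimeShape small({3, 4});
    // tflite::RuntimeShape ext = tflite::RuntimeShape::ExtendedShape(4, small);
    // ext has dims {1, 1, 3, 4}: left-padded with 1s, so FlatSize() stays 12
    // and strictly 4-D kernels can index lower-rank tensors unchanged.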
inline void BuildFrom(const std::initializer_list<int> init_list) {
BuildFrom<const std::initializer_list<int>>(init_list);
}
@@ -208,7 +234,25 @@ class RuntimeShape {
return buffer_size;
}
+ bool operator!=(const RuntimeShape& comp) const { return !((*this) == comp); }
+
private:
+  // For use only by ExtendedShape(), written to guarantee (return-value) copy
+ // elision in C++17.
+ // This creates a shape padded to the desired size with the specified value.
+ RuntimeShape(int new_shape_size, const RuntimeShape& shape, int pad_value)
+ : size_(0) {
+ TFLITE_CHECK_GE(new_shape_size, shape.DimensionsCount());
+ TFLITE_CHECK_LE(new_shape_size, kMaxSmallSize);
+ Resize(new_shape_size);
+ const int size_increase = new_shape_size - shape.DimensionsCount();
+ for (int i = 0; i < size_increase; ++i) {
+ SetDim(i, pad_value);
+ }
+ std::memcpy(DimsData() + size_increase, shape.DimsData(),
+ sizeof(int32) * shape.DimensionsCount());
+ }
+
int32 size_;
union {
int32 dims_[kMaxSmallSize];
@@ -234,7 +278,9 @@ inline tflite::Dims<4> ToRuntimeDims(const tflite::RuntimeShape& array_shape) {
// Gets next index to iterate through a multidimensional array.
inline bool NextIndex(const int num_dims, const int* dims, int* current) {
- TFLITE_DCHECK_GT(num_dims, 0);
+ if (num_dims == 0) {
+ return false;
+ }
TFLITE_DCHECK(dims != nullptr);
TFLITE_DCHECK(current != nullptr);
int carry = 1;
@@ -261,7 +307,9 @@ inline bool NextIndex(const int num_dims, const int* dims, int* current) {
inline size_t ReducedOutputOffset(const int num_dims, const int* dims,
const int* index, const int num_axis,
const int* axis) {
- TFLITE_DCHECK_GT(num_dims, 0);
+ if (num_dims == 0) {
+ return 0;
+ }
TFLITE_DCHECK(dims != nullptr);
TFLITE_DCHECK(index != nullptr);
size_t offset = 0;
@@ -364,6 +412,7 @@ inline int RequiredBufferSizeForDims(const Dims<4>& dims) {
// arrays.
inline int MatchingFlatSize(const RuntimeShape& shape,
const RuntimeShape& check_shape_0) {
+ TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
const int dims_count = shape.DimensionsCount();
for (int i = 0; i < dims_count; ++i) {
TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
@@ -374,6 +423,7 @@ inline int MatchingFlatSize(const RuntimeShape& shape,
inline int MatchingFlatSize(const RuntimeShape& shape,
const RuntimeShape& check_shape_0,
const RuntimeShape& check_shape_1) {
+ TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
const int dims_count = shape.DimensionsCount();
for (int i = 0; i < dims_count; ++i) {
TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
@@ -385,6 +435,7 @@ inline int MatchingFlatSize(const RuntimeShape& shape,
const RuntimeShape& check_shape_0,
const RuntimeShape& check_shape_1,
const RuntimeShape& check_shape_2) {
+ TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
const int dims_count = shape.DimensionsCount();
for (int i = 0; i < dims_count; ++i) {
TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
@@ -397,6 +448,7 @@ inline int MatchingFlatSize(const RuntimeShape& shape,
const RuntimeShape& check_shape_1,
const RuntimeShape& check_shape_2,
const RuntimeShape& check_shape_3) {
+ TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
const int dims_count = shape.DimensionsCount();
for (int i = 0; i < dims_count; ++i) {
TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
@@ -601,14 +653,74 @@ struct PoolParams {
int stride_width;
int filter_height;
int filter_width;
- // uint8, etc, inference params.
+ // uint8, etc, activation params.
int32 quantized_activation_min;
int32 quantized_activation_max;
- // float inference params.
+ // float activation params.
float float_activation_min;
float float_activation_max;
};
+enum class BroadcastableOpCategory : uint8 {
+ kNone,
+ kNonBroadcast, // Matching input shapes.
+ kFirstInputBroadcastsFast, // Fivefold nested loops.
+ kSecondInputBroadcastsFast, // Fivefold nested loops.
+ kGenericBroadcast, // Fall-back.
+};
+
+// For Add, Sub, Mul ops.
+struct ArithmeticParams {
+ // Shape dependent / common to data / op types.
+ BroadcastableOpCategory broadcast_category;
+ // uint8 inference params.
+ int32 input1_offset;
+ int32 input2_offset;
+ int32 output_offset;
+ int32 output_multiplier;
+ int output_shift;
+ // Add / Sub, not Mul, uint8 inference params.
+ int left_shift;
+ int32 input1_multiplier;
+ int input1_shift;
+ int32 input2_multiplier;
+ int input2_shift;
+ // uint8, etc, activation params.
+ int32 quantized_activation_min;
+ int32 quantized_activation_max;
+ // float activation params.
+ float float_activation_min;
+ float float_activation_max;
+
+ // Processed output dimensions.
+ // Let input "a" be the one that broadcasts in the faster-changing dimension.
+ // Then, after coalescing, for shapes {a0, a1, a2, a3, a4} and
+ // {b0, b1, b2, b3, b4},
+ // broadcast_shape[4] = b0 = a0.
+ // broadcast_shape[3] = b1; a1 = 1.
+ // broadcast_shape[2] = b2 = a2.
+ // broadcast_shape[1] = a3; b3 = 1.
+ // broadcast_shape[0] = b4 = a4.
+ int broadcast_shape[5];
+};
+
+template <typename T>
+inline void SetActivationParams(T min, T max, ArithmeticParams* params);
+
+template <>
+inline void SetActivationParams(float min, float max,
+ ArithmeticParams* params) {
+ params->float_activation_min = min;
+ params->float_activation_max = max;
+}
+
+template <>
+inline void SetActivationParams(int32 min, int32 max,
+ ArithmeticParams* params) {
+ params->quantized_activation_min = min;
+ params->quantized_activation_max = max;
+}
+
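A small usage sketch, e.g. inside a kernel's Prepare/Eval (illustrative only; int32 here is TF Lite's internal typedef for std::int32_t):

    tflite::ArithmeticParams float_params;
    tflite::SetActivationParams(0.0f, 6.0f, &float_params);
    // -> fills float_activation_min/max; the quantized fields are untouched.

    tflite::ArithmeticParams quant_params;
    tflite::SetActivationParams<tflite::int32>(0, 255, &quant_params);
    // -> fills quantized_activation_min/max, keeping one call site per kernel
    //    regardless of whether the op runs in float or uint8 mode.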
} // namespace tflite
#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_
diff --git a/tensorflow/contrib/lite/kernels/lsh_projection.cc b/tensorflow/contrib/lite/kernels/lsh_projection.cc
index 25d2dc2cdd..69523b02cc 100644
--- a/tensorflow/contrib/lite/kernels/lsh_projection.cc
+++ b/tensorflow/contrib/lite/kernels/lsh_projection.cc
@@ -50,7 +50,6 @@ limitations under the License.
// Output.Dim == { Tensor[0].Dim[0] * Tensor[0].Dim[1] }
// A flattened tensor represents projected bit vectors.
-#include <unistd.h>
#include <cassert>
#include <cmath>
#include <cstdio>
diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc
index 3577ae6caa..ba251c451e 100644
--- a/tensorflow/contrib/lite/kernels/lstm.cc
+++ b/tensorflow/contrib/lite/kernels/lstm.cc
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <unistd.h>
#include <cassert>
#include <cmath>
#include <cstdio>
@@ -97,7 +96,7 @@ constexpr int kCellStateTensor = 1;
constexpr int kOutputTensor = 2;
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
- auto* op_data = new OpData;
+ auto* op_data = new OpData();
op_data->kernel_type = kTfLiteLSTMFullKernel;
context->AddTensors(context, /*tensors_to_add=*/7,
&op_data->scratch_tensor_index);
@@ -306,7 +305,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const int n_output = recurrent_to_output_weights->dims->data[1];
// Check that input tensor dimensions matches with each other.
- CheckInputTensorDimensions(context, node, n_input, n_output, n_cell);
+ TF_LITE_ENSURE_OK(context, CheckInputTensorDimensions(context, node, n_input,
+ n_output, n_cell));
// Get the pointer to output, activation_state and cell_state tensors.
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
@@ -846,7 +846,7 @@ enum OutputTensor {
};
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
- auto* op_data = new OpData;
+ auto* op_data = new OpData();
op_data->kernel_type = kTfLiteLSTMBasicKernel;
// `scratch_tensor_index` is unused in this kernel.
op_data->scratch_tensor_index = -1;
diff --git a/tensorflow/contrib/lite/kernels/lstm_test.cc b/tensorflow/contrib/lite/kernels/lstm_test.cc
index 0b7c56133e..0266f5fe57 100644
--- a/tensorflow/contrib/lite/kernels/lstm_test.cc
+++ b/tensorflow/contrib/lite/kernels/lstm_test.cc
@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Unit test for TFLite LSTM op.
+//
+// TODO(alanchiao): add unit test with invalid input dimensions for this and its
+// variants.
#include <memory>
#include <vector>
diff --git a/tensorflow/contrib/lite/kernels/pack.cc b/tensorflow/contrib/lite/kernels/pack.cc
new file mode 100644
index 0000000000..bb3416f6a6
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/pack.cc
@@ -0,0 +1,131 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace pack {
+namespace {
+
+constexpr int kOutputTensor = 0;
+
+// Op data for pack op.
+struct OpData {
+ int values_count;
+ int axis;
+};
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* data = new OpData;
+ data->axis = 0;
+ return data;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<OpData*>(buffer);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ const OpData* data = reinterpret_cast<OpData*>(node->builtin_data);
+
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), data->values_count);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ const TfLiteTensor* input0 = GetInput(context, node, 0);
+ TF_LITE_ENSURE(context, NumDimensions(input0) < 4);
+ TF_LITE_ENSURE(context, NumDimensions(input0) >= data->axis);
+ // TODO(renjieliu): Support negative axis.
+ TF_LITE_ENSURE(context, data->axis >= 0);
+ if (input0->type != kTfLiteInt32 && input0->type != kTfLiteFloat32) {
+ context->ReportError(context,
+ "Currently pack only supports int32 and float32.");
+ return kTfLiteError;
+ }
+ // Make sure all inputs have the same shape and type.
+ for (int i = 1; i < data->values_count; ++i) {
+ const TfLiteTensor* input = GetInput(context, node, i);
+ TF_LITE_ENSURE(context, HaveSameShapes(input0, input));
+ TF_LITE_ENSURE_EQ(context, input0->type, input->type);
+ }
+
+  // Resize output: rank R becomes rank R + 1.
+ const int dimension_size = NumDimensions(input0) + 1;
+ const TfLiteIntArray* input_shape = input0->dims;
+ TfLiteIntArray* output_shape = TfLiteIntArrayCreate(dimension_size);
+ int i = 0;
+ for (int index = 0; index < dimension_size; ++index) {
+ if (index == data->axis) {
+ output_shape->data[index] = data->values_count;
+ } else {
+ output_shape->data[index] = input_shape->data[i++];
+ }
+ }
+
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ TF_LITE_ENSURE_EQ(context, output->type, input0->type);
+
+ return context->ResizeTensor(context, output, output_shape);
+}
+
+template <typename T>
+void PackImpl(TfLiteContext* context, TfLiteNode* node, TfLiteTensor* output,
+ int values_count, int axis) {
+ VectorOfTensors<T> all_inputs(*context, *node->inputs);
+ reference_ops::Pack<T>(RemapDim(NumDimensions(output), axis),
+ all_inputs.data(), all_inputs.dims(), values_count,
+ GetTensorData<T>(output), GetTensorDims(output));
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const OpData* data = reinterpret_cast<OpData*>(node->builtin_data);
+
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ switch (output->type) {
+ case kTfLiteFloat32: {
+ PackImpl<float>(context, node, output, data->values_count, data->axis);
+ break;
+ }
+ case kTfLiteInt32: {
+ PackImpl<int32_t>(context, node, output, data->values_count, data->axis);
+ break;
+ }
+ default: {
+ context->ReportError(context,
+ "Currently pack only supports int32 and float32.");
+ return kTfLiteError;
+ }
+ }
+
+ return kTfLiteOk;
+}
+
+} // namespace
+} // namespace pack
+
+TfLiteRegistration* Register_PACK() {
+ static TfLiteRegistration r = {pack::Init, pack::Free, pack::Prepare,
+ pack::Eval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/pack_test.cc b/tensorflow/contrib/lite/kernels/pack_test.cc
new file mode 100644
index 0000000000..485a50ad3a
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/pack_test.cc
@@ -0,0 +1,120 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::ElementsAreArray;
+
+template <typename T>
+class PackOpModel : public SingleOpModel {
+ public:
+ PackOpModel(const TensorData& input_template, int axis, int values_count) {
+ std::vector<std::vector<int>> all_input_shapes;
+ for (int i = 0; i < values_count; ++i) {
+ all_input_shapes.push_back(input_template.shape);
+ AddInput(input_template);
+ }
+ output_ = AddOutput({input_template.type, /*shape=*/{}, input_template.min,
+ input_template.max});
+ SetBuiltinOp(BuiltinOperator_PACK, BuiltinOptions_PackOptions,
+ CreatePackOptions(builder_, values_count, axis).Union());
+ BuildInterpreter(all_input_shapes);
+ }
+
+ void SetInput(int index, std::initializer_list<T> data) {
+ PopulateTensor(index, data);
+ }
+
+ std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+ int output_;
+};
+
+TEST(PackOpTest, FloatThreeInputs) {
+ PackOpModel<float> model({TensorType_FLOAT32, {2}}, 0, 3);
+ model.SetInput(0, {1, 4});
+ model.SetInput(1, {2, 5});
+ model.SetInput(2, {3, 6});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(3, 2));
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 4, 2, 5, 3, 6}));
+}
+
+TEST(PackOpTest, FloatThreeInputsDifferentAxis) {
+ PackOpModel<float> model({TensorType_FLOAT32, {2}}, 1, 3);
+ model.SetInput(0, {1, 4});
+ model.SetInput(1, {2, 5});
+ model.SetInput(2, {3, 6});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3));
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
+}
+
+TEST(PackOpTest, FloatMultiDimensions) {
+ PackOpModel<float> model({TensorType_FLOAT32, {2, 3}}, 1, 2);
+ model.SetInput(0, {1, 2, 3, 4, 5, 6});
+ model.SetInput(1, {7, 8, 9, 10, 11, 12});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 2, 3));
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12}));
+}
+
+TEST(PackOpTest, IntThreeInputs) {
+ PackOpModel<int32_t> model({TensorType_INT32, {2}}, 0, 3);
+ model.SetInput(0, {1, 4});
+ model.SetInput(1, {2, 5});
+ model.SetInput(2, {3, 6});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(3, 2));
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 4, 2, 5, 3, 6}));
+}
+
+TEST(PackOpTest, IntThreeInputsDifferentAxis) {
+ PackOpModel<int32_t> model({TensorType_INT32, {2}}, 1, 3);
+ model.SetInput(0, {1, 4});
+ model.SetInput(1, {2, 5});
+ model.SetInput(2, {3, 6});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3));
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
+}
+
+TEST(PackOpTest, IntMultiDimensions) {
+ PackOpModel<int32_t> model({TensorType_INT32, {2, 3}}, 1, 2);
+ model.SetInput(0, {1, 2, 3, 4, 5, 6});
+ model.SetInput(1, {7, 8, 9, 10, 11, 12});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 2, 3));
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12}));
+}
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/pooling.cc b/tensorflow/contrib/lite/kernels/pooling.cc
index 9b0487ae16..29a5be0683 100644
--- a/tensorflow/contrib/lite/kernels/pooling.cc
+++ b/tensorflow/contrib/lite/kernels/pooling.cc
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <unistd.h>
#include <cassert>
#include <cmath>
#include <cstdio>
diff --git a/tensorflow/contrib/lite/kernels/pow_test.cc b/tensorflow/contrib/lite/kernels/pow_test.cc
index 474d323bc3..74b3aef5bd 100644
--- a/tensorflow/contrib/lite/kernels/pow_test.cc
+++ b/tensorflow/contrib/lite/kernels/pow_test.cc
@@ -50,22 +50,22 @@ class PowOpModel : public SingleOpModel {
};
TEST(PowOpModel, Simple) {
- PowOpModel<int32> model({TensorType_INT32, {1, 2, 2, 1}},
- {TensorType_INT32, {1, 2, 2, 1}},
- {TensorType_INT32, {}});
- model.PopulateTensor<int32>(model.input1(), {12, 2, 7, 8});
- model.PopulateTensor<int32>(model.input2(), {1, 2, 3, 1});
+ PowOpModel<int32_t> model({TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {}});
+ model.PopulateTensor<int32_t>(model.input1(), {12, 2, 7, 8});
+ model.PopulateTensor<int32_t>(model.input2(), {1, 2, 3, 1});
model.Invoke();
EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1));
EXPECT_THAT(model.GetOutput(), ElementsAre(12, 4, 343, 8));
}
TEST(PowOpModel, NegativeAndZeroValue) {
- PowOpModel<int32> model({TensorType_INT32, {1, 2, 2, 1}},
- {TensorType_INT32, {1, 2, 2, 1}},
- {TensorType_INT32, {}});
- model.PopulateTensor<int32>(model.input1(), {0, 2, -7, 8});
- model.PopulateTensor<int32>(model.input2(), {1, 2, 3, 0});
+ PowOpModel<int32_t> model({TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {}});
+ model.PopulateTensor<int32_t>(model.input1(), {0, 2, -7, 8});
+ model.PopulateTensor<int32_t>(model.input2(), {1, 2, 3, 0});
model.Invoke();
EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1));
EXPECT_THAT(model.GetOutput(), ElementsAre(0, 4, -343, 1));
@@ -98,10 +98,10 @@ TEST(PowOpModel, NegativeFloatTest) {
}
TEST(PowOpModel, BroadcastTest) {
- PowOpModel<int32> model({TensorType_INT32, {1, 2, 2, 1}},
- {TensorType_INT32, {1}}, {TensorType_INT32, {}});
- model.PopulateTensor<int32>(model.input1(), {12, 2, 7, 8});
- model.PopulateTensor<int32>(model.input2(), {4});
+ PowOpModel<int32_t> model({TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {1}}, {TensorType_INT32, {}});
+ model.PopulateTensor<int32_t>(model.input1(), {12, 2, 7, 8});
+ model.PopulateTensor<int32_t>(model.input2(), {4});
model.Invoke();
EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1));
EXPECT_THAT(model.GetOutput(), ElementsAre(20736, 16, 2401, 4096));
diff --git a/tensorflow/contrib/lite/kernels/reduce.cc b/tensorflow/contrib/lite/kernels/reduce.cc
index 31c331a8c6..e99f67c725 100644
--- a/tensorflow/contrib/lite/kernels/reduce.cc
+++ b/tensorflow/contrib/lite/kernels/reduce.cc
@@ -78,6 +78,10 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context, OpContext* op_context) {
size_t num_axis = NumElements(op_context->axis);
const TfLiteIntArray* input_dims = op_context->input->dims;
int input_num_dims = NumDimensions(op_context->input);
+ if (input_num_dims == 0) {
+ return context->ResizeTensor(context, op_context->output,
+ TfLiteIntArrayCreate(0));
+ }
const int* axis = GetTensorData<int>(op_context->axis);
if (op_context->params->keep_dims) {
TfLiteIntArray* output_dims = TfLiteIntArrayCreate(input_num_dims);
@@ -315,6 +319,99 @@ TfLiteStatus EvalSum(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
+template <KernelType kernel_type>
+TfLiteStatus EvalProd(TfLiteContext* context, TfLiteNode* node) {
+ OpContext op_context(context, node);
+ int64_t num_axis = NumElements(op_context.axis);
+ TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
+ TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
+ // Resize the output tensor if the output tensor is dynamic.
+ if (IsDynamicTensor(op_context.output)) {
+ TF_LITE_ENSURE_OK(context,
+ ResizeTempAxis(context, &op_context, resolved_axis));
+ TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
+ }
+
+#define TF_LITE_PROD(kernel_type, data_type) \
+ kernel_type::ReduceProd<>( \
+ GetTensorData<data_type>(op_context.input), \
+ op_context.input->dims->data, op_context.input->dims->size, \
+ GetTensorData<data_type>(op_context.output), \
+ op_context.output->dims->data, op_context.output->dims->size, \
+ GetTensorData<int>(op_context.axis), num_axis, \
+ op_context.params->keep_dims, GetTensorData<int>(temp_index), \
+ GetTensorData<int>(resolved_axis))
+
+ if (kernel_type == kReference) {
+ switch (op_context.input->type) {
+ case kTfLiteFloat32:
+ TF_LITE_ENSURE(context, TF_LITE_PROD(reference_ops, float));
+ break;
+ case kTfLiteInt32:
+ TF_LITE_ENSURE(context, TF_LITE_PROD(reference_ops, int));
+ break;
+ case kTfLiteInt64:
+ TF_LITE_ENSURE(context, TF_LITE_PROD(reference_ops, int64_t));
+ break;
+ case kTfLiteUInt8:
+ // TODO(wangtz): uint8 reduce_prod is not yet supported.
+ default:
+ return kTfLiteError;
+ }
+ }
+#undef TF_LITE_PROD
+ return kTfLiteOk;
+}
+
+template <KernelType kernel_type>
+TfLiteStatus EvalMax(TfLiteContext* context, TfLiteNode* node) {
+ OpContext op_context(context, node);
+ int64_t num_axis = NumElements(op_context.axis);
+ TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
+ TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
+ // Resize the output tensor if the output tensor is dynamic.
+ if (IsDynamicTensor(op_context.output)) {
+ TF_LITE_ENSURE_OK(context,
+ ResizeTempAxis(context, &op_context, resolved_axis));
+ TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
+ }
+
+#define TF_LITE_MAX(kernel_type, data_type) \
+ kernel_type::ReduceMax<>( \
+ GetTensorData<data_type>(op_context.input), \
+ op_context.input->dims->data, op_context.input->dims->size, \
+ GetTensorData<data_type>(op_context.output), \
+ op_context.output->dims->data, op_context.output->dims->size, \
+ GetTensorData<int>(op_context.axis), num_axis, \
+ op_context.params->keep_dims, GetTensorData<int>(temp_index), \
+ GetTensorData<int>(resolved_axis))
+
+ if (kernel_type == kReference) {
+ switch (op_context.input->type) {
+ case kTfLiteFloat32:
+ TF_LITE_ENSURE(context, TF_LITE_MAX(reference_ops, float));
+ break;
+ case kTfLiteInt32:
+ TF_LITE_ENSURE(context, TF_LITE_MAX(reference_ops, int));
+ break;
+ case kTfLiteInt64:
+ TF_LITE_ENSURE(context, TF_LITE_MAX(reference_ops, int64_t));
+ break;
+ case kTfLiteUInt8:
+ TF_LITE_ENSURE_EQ(context, op_context.input->params.scale,
+ op_context.output->params.scale);
+ TF_LITE_ENSURE_EQ(context, op_context.input->params.zero_point,
+ op_context.output->params.zero_point);
+ TF_LITE_ENSURE(context, TF_LITE_MAX(reference_ops, uint8_t));
+ break;
+ default:
+ return kTfLiteError;
+ }
+ }
+#undef TF_LITE_MAX
+ return kTfLiteOk;
+}
+
} // namespace reduce
TfLiteRegistration* Register_MEAN_REF() {
@@ -331,9 +428,27 @@ TfLiteRegistration* Register_SUM_REF() {
return &r;
}
+TfLiteRegistration* Register_REDUCE_PROD_REF() {
+ static TfLiteRegistration r = {reduce::Init, reduce::Free,
+ reduce::PrepareSimple,
+ reduce::EvalProd<reduce::kReference>};
+ return &r;
+}
+
+TfLiteRegistration* Register_REDUCE_MAX_REF() {
+ static TfLiteRegistration r = {reduce::Init, reduce::Free,
+ reduce::PrepareSimple,
+ reduce::EvalMax<reduce::kReference>};
+ return &r;
+}
+
// TODO(kanlig): add optimized implementation of Mean.
TfLiteRegistration* Register_MEAN() { return Register_MEAN_REF(); }
TfLiteRegistration* Register_SUM() { return Register_SUM_REF(); }
+TfLiteRegistration* Register_REDUCE_PROD() {
+ return Register_REDUCE_PROD_REF();
+}
+TfLiteRegistration* Register_REDUCE_MAX() { return Register_REDUCE_MAX_REF(); }
} // namespace builtin
} // namespace ops
diff --git a/tensorflow/contrib/lite/kernels/reduce_test.cc b/tensorflow/contrib/lite/kernels/reduce_test.cc
index 9e946822c6..5d432d34ef 100644
--- a/tensorflow/contrib/lite/kernels/reduce_test.cc
+++ b/tensorflow/contrib/lite/kernels/reduce_test.cc
@@ -22,13 +22,14 @@ namespace tflite {
namespace {
using ::testing::ElementsAreArray;
+using ::testing::IsEmpty;
class BaseOpModel : public SingleOpModel {
public:
- void SetAxis(std::initializer_list<int> data) { PopulateTensor(axis_, data); }
+ void SetAxis(const std::vector<int>& data) { PopulateTensor(axis_, data); }
template <class T>
- void SetInput(std::initializer_list<T> data) {
+ void SetInput(std::vector<T> data) {
PopulateTensor(input_, data);
}
@@ -110,14 +111,72 @@ class SumOpDynamicModel : public BaseOpModel {
}
};
+// Model for the tests case where axis is a const tensor.
+class ProdOpConstModel : public BaseOpModel {
+ public:
+ ProdOpConstModel(const TensorData& input, const TensorData& output,
+ std::initializer_list<int> axis_shape,
+ std::initializer_list<int> axis, bool keep_dims) {
+ input_ = AddInput(input);
+ axis_ = AddConstInput(TensorType_INT32, axis, axis_shape);
+ output_ = AddOutput(output);
+ SetBuiltinOp(BuiltinOperator_REDUCE_PROD, BuiltinOptions_ReducerOptions,
+ CreateReducerOptions(builder_, keep_dims).Union());
+ BuildInterpreter({GetShape(input_)});
+ }
+};
+
+// Model for the tests case where axis is a dynamic tensor.
+class ProdOpDynamicModel : public BaseOpModel {
+ public:
+ ProdOpDynamicModel(const TensorData& input, const TensorData& output,
+ const TensorData& axis, bool keep_dims) {
+ input_ = AddInput(input);
+ axis_ = AddInput(axis);
+ output_ = AddOutput(output);
+ SetBuiltinOp(BuiltinOperator_REDUCE_PROD, BuiltinOptions_ReducerOptions,
+ CreateReducerOptions(builder_, keep_dims).Union());
+ BuildInterpreter({GetShape(input_)});
+ }
+};
+
+// Model for the tests case where axis is a const tensor.
+class MaxOpConstModel : public BaseOpModel {
+ public:
+ MaxOpConstModel(const TensorData& input, const TensorData& output,
+ std::initializer_list<int> axis_shape,
+ std::initializer_list<int> axis, bool keep_dims) {
+ input_ = AddInput(input);
+ axis_ = AddConstInput(TensorType_INT32, axis, axis_shape);
+ output_ = AddOutput(output);
+ SetBuiltinOp(BuiltinOperator_REDUCE_MAX, BuiltinOptions_ReducerOptions,
+ CreateReducerOptions(builder_, keep_dims).Union());
+ BuildInterpreter({GetShape(input_)});
+ }
+};
+
+// Model for the tests case where axis is a dynamic tensor.
+class MaxOpDynamicModel : public BaseOpModel {
+ public:
+ MaxOpDynamicModel(const TensorData& input, const TensorData& output,
+ const TensorData& axis, bool keep_dims) {
+ input_ = AddInput(input);
+ axis_ = AddInput(axis);
+ output_ = AddOutput(output);
+ SetBuiltinOp(BuiltinOperator_REDUCE_MAX, BuiltinOptions_ReducerOptions,
+ CreateReducerOptions(builder_, keep_dims).Union());
+ BuildInterpreter({GetShape(input_)});
+ }
+};
+
// for quantized Add, the error shouldn't exceed step
float GetTolerance(int min, int max) { return (max - min) / 255.0; }
// Tests for reduce_mean
TEST(ConstFloatMeanOpTest, NotKeepDims) {
- std::initializer_list<float> data = {
- 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
- 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+ std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+ 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
+ 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
MeanOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {2}},
{4}, {1, 0, -3, -3}, false);
m.SetInput(data);
@@ -127,9 +186,9 @@ TEST(ConstFloatMeanOpTest, NotKeepDims) {
}
TEST(ConstFloatMeanOpTest, KeepDims) {
- std::initializer_list<float> data = {
- 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
- 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+ std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+ 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
+ 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
MeanOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {3}},
{2}, {0, 2}, true);
m.SetInput(data);
@@ -139,14 +198,24 @@ TEST(ConstFloatMeanOpTest, KeepDims) {
ElementsAreArray(ArrayFloatNear({10.5, 12.5, 14.5})));
}
+TEST(ConstFloatMeanOpTest, Scalar) {
+ std::vector<float> data = {3.27};
+ MeanOpConstModel m({TensorType_FLOAT32, {}}, {TensorType_FLOAT32, {}}, {},
+ {0}, true);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), IsEmpty());
+ EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({3.27})));
+}
+
TEST(DynamicFloatMeanOpTest, NotKeepDims) {
- std::initializer_list<float> data = {
- 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
- 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+ std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+ 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
+ 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
MeanOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}},
{TensorType_FLOAT32, {2}}, {TensorType_INT32, {4}},
false);
- std::initializer_list<int> axis = {1, 0, -3, -3};
+ std::vector<int> axis = {1, 0, -3, -3};
m.SetAxis(axis);
m.SetInput(data);
m.Invoke();
@@ -155,13 +224,13 @@ TEST(DynamicFloatMeanOpTest, NotKeepDims) {
}
TEST(DynamicFloatMeanOpTest, KeepDims) {
- std::initializer_list<float> data = {
- 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
- 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+ std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+ 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
+ 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
MeanOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}},
{TensorType_FLOAT32, {3}}, {TensorType_INT32, {2}},
true);
- std::initializer_list<int> axis = {0, 2};
+ std::vector<int> axis = {0, 2};
m.SetAxis(axis);
m.SetInput(data);
m.Invoke();
@@ -171,10 +240,10 @@ TEST(DynamicFloatMeanOpTest, KeepDims) {
}
TEST(DynamicFloatMeanOpTest, Scale) {
- std::initializer_list<float> data = {9.527};
+ std::vector<float> data = {9.527};
MeanOpDynamicModel m({TensorType_FLOAT32, {1}}, {TensorType_FLOAT32, {1}},
{TensorType_INT32, {1}}, true);
- std::initializer_list<int> axis = {0};
+ std::vector<int> axis = {0};
m.SetAxis(axis);
m.SetInput(data);
m.Invoke();
@@ -185,7 +254,7 @@ TEST(DynamicFloatMeanOpTest, Scale) {
TEST(ConstUint8MeanOpTest, NotKeepDims) {
float kQuantizedTolerance = GetTolerance(-1.0, 1.0);
- std::initializer_list<float> data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6};
+ std::vector<float> data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6};
MeanOpConstModel m({TensorType_UINT8, {1, 3, 2}, -1.0, 1.0},
{TensorType_UINT8, {2}, -1.0, 1.0}, {1}, {1}, false);
m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
@@ -197,7 +266,7 @@ TEST(ConstUint8MeanOpTest, NotKeepDims) {
TEST(ConstUint8MeanOpTest, KeepDims) {
float kQuantizedTolerance = GetTolerance(-1.0, 1.0);
- std::initializer_list<float> data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6};
+ std::vector<float> data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6};
MeanOpConstModel m({TensorType_UINT8, {3, 2}, -1.0, 1.0},
{TensorType_UINT8, {3}, -1.0, 1.0}, {1}, {1}, true);
m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
@@ -210,11 +279,11 @@ TEST(ConstUint8MeanOpTest, KeepDims) {
TEST(DynamicUint8MeanOpTest, NotKeepDims) {
float kQuantizedTolerance = GetTolerance(-5.0, 2.0);
- std::initializer_list<float> data = {1.3, -4.8, -3.6, 0.24};
+ std::vector<float> data = {1.3, -4.8, -3.6, 0.24};
MeanOpDynamicModel m({TensorType_UINT8, {2, 2}, -5.0, 2.0},
{TensorType_UINT8, {2}, -5.0, 2.0},
{TensorType_INT32, {1}}, false);
- std::initializer_list<int> axis = {1};
+ std::vector<int> axis = {1};
m.SetAxis(axis);
m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
m.Invoke();
@@ -226,11 +295,11 @@ TEST(DynamicUint8MeanOpTest, NotKeepDims) {
TEST(DynamicUint8MeanOpTest, KeepDims) {
float kQuantizedTolerance = GetTolerance(-10.0, 12.0);
- std::initializer_list<float> data = {11.14, -0.14, 7.423, 0.879};
+ std::vector<float> data = {11.14, -0.14, 7.423, 0.879};
MeanOpDynamicModel m({TensorType_UINT8, {2, 2}, -10.0, 12.0},
{TensorType_UINT8, {2}, -10.0, 12.0},
{TensorType_INT32, {1}}, true);
- std::initializer_list<int> axis = {0};
+ std::vector<int> axis = {0};
m.SetAxis(axis);
m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
m.Invoke();
@@ -243,9 +312,9 @@ TEST(DynamicUint8MeanOpTest, KeepDims) {
// Tests for reduce_sum
TEST(ConstFloatSumOpTest, NotKeepDims) {
- std::initializer_list<float> data = {
- 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
- 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+ std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+ 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
+ 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
SumOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {2}},
{4}, {1, 0, -3, -3}, false);
m.SetInput(data);
@@ -256,9 +325,9 @@ TEST(ConstFloatSumOpTest, NotKeepDims) {
}
TEST(ConstFloatSumOpTest, KeepDims) {
- std::initializer_list<float> data = {
- 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
- 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+ std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+ 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
+ 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
SumOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {3}},
{2}, {0, 2}, true);
m.SetInput(data);
@@ -269,13 +338,13 @@ TEST(ConstFloatSumOpTest, KeepDims) {
}
TEST(DynamicFloatSumOpTest, NotKeepDims) {
- std::initializer_list<float> data = {
- 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
- 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+ std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+ 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
+ 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
SumOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}},
{TensorType_FLOAT32, {2}}, {TensorType_INT32, {4}},
false);
- std::initializer_list<int> axis = {1, 0, -3, -3};
+ std::vector<int> axis = {1, 0, -3, -3};
m.SetAxis(axis);
m.SetInput(data);
m.Invoke();
@@ -284,13 +353,23 @@ TEST(DynamicFloatSumOpTest, NotKeepDims) {
ElementsAreArray(ArrayFloatNear({144, 156})));
}
+TEST(ConstFloatSumOpTest, Scalar) {
+ std::vector<float> data = {17.};
+ SumOpConstModel m({TensorType_FLOAT32, {}}, {TensorType_FLOAT32, {}}, {}, {0},
+ false);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), IsEmpty());
+ EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({17.})));
+}
+
TEST(DynamicFloatSumOpTest, KeepDims) {
- std::initializer_list<float> data = {
- 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
- 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+ std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+ 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
+ 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
SumOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}},
{TensorType_FLOAT32, {3}}, {TensorType_INT32, {2}}, true);
- std::initializer_list<int> axis = {0, 2};
+ std::vector<int> axis = {0, 2};
m.SetAxis(axis);
m.SetInput(data);
m.Invoke();
@@ -300,10 +379,10 @@ TEST(DynamicFloatSumOpTest, KeepDims) {
}
TEST(DynamicFloatSumOpTest, Scale) {
- std::initializer_list<float> data = {9.527};
+ std::vector<float> data = {9.527};
SumOpDynamicModel m({TensorType_FLOAT32, {1}}, {TensorType_FLOAT32, {1}},
{TensorType_INT32, {1}}, true);
- std::initializer_list<int> axis = {0};
+ std::vector<int> axis = {0};
m.SetAxis(axis);
m.SetInput(data);
m.Invoke();
@@ -313,7 +392,7 @@ TEST(DynamicFloatSumOpTest, Scale) {
TEST(ConstUint8SumOpTest, NotKeepDims) {
float kQuantizedTolerance = GetTolerance(-1.0, 1.0);
- std::initializer_list<float> data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6};
+ std::vector<float> data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6};
SumOpConstModel m({TensorType_UINT8, {1, 3, 2}, -1.0, 1.0},
{TensorType_UINT8, {2}, -1.0, 1.0}, {1}, {1}, false);
m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
@@ -326,7 +405,7 @@ TEST(ConstUint8SumOpTest, NotKeepDims) {
TEST(ConstUint8SumOpTest, KeepDims) {
float kQuantizedTolerance = GetTolerance(-1.0, 1.0);
- std::initializer_list<float> data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6};
+ std::vector<float> data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6};
SumOpConstModel m({TensorType_UINT8, {3, 2}, -1.0, 1.0},
{TensorType_UINT8, {3}, -1.0, 1.0}, {1}, {1}, true);
m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
@@ -339,11 +418,11 @@ TEST(ConstUint8SumOpTest, KeepDims) {
TEST(DynamicUint8SumOpTest, NotKeepDims) {
float kQuantizedTolerance = GetTolerance(-5.0, 2.0);
- std::initializer_list<float> data = {1.3, -4.8, -3.6, 0.24};
+ std::vector<float> data = {1.3, -4.8, -3.6, 0.24};
SumOpDynamicModel m({TensorType_UINT8, {2, 2}, -5.0, 2.0},
{TensorType_UINT8, {2}, -5.0, 2.0},
{TensorType_INT32, {1}}, false);
- std::initializer_list<int> axis = {1};
+ std::vector<int> axis = {1};
m.SetAxis(axis);
m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
m.Invoke();
@@ -355,11 +434,11 @@ TEST(DynamicUint8SumOpTest, NotKeepDims) {
TEST(DynamicUint8SumOpTest, KeepDims) {
float kQuantizedTolerance = GetTolerance(-10.0, 12.0);
- std::initializer_list<float> data = {11.14, -0.14, 7.423, 0.879};
+ std::vector<float> data = {11.14, -0.14, 7.423, 0.879};
SumOpDynamicModel m({TensorType_UINT8, {2, 2}, -10.0, 12.0},
{TensorType_UINT8, {2}, -10.0, 12.0},
{TensorType_INT32, {1}}, true);
- std::initializer_list<int> axis = {0};
+ std::vector<int> axis = {0};
m.SetAxis(axis);
m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
m.Invoke();
@@ -369,6 +448,223 @@ TEST(DynamicUint8SumOpTest, KeepDims) {
ElementsAreArray(ArrayFloatNear({6.47059, 10.698}, kQuantizedTolerance)));
}
+// Tests for reduce_prod
+
+TEST(ConstFloatProdOpTest, NotKeepDims) {
+ std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+ 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
+ 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+ ProdOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {2}},
+ {4}, {1, 0, -3, -3}, false);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+ EXPECT_THAT(
+ m.GetOutput<float>(),
+ ElementsAreArray(ArrayFloatNear({3.162341376e+11, 1.9619905536e+12})));
+}
+
+TEST(ConstFloatProdOpTest, KeepDims) {
+ std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+ 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
+ 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+ ProdOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {3}},
+ {2}, {0, 2}, true);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1}));
+ EXPECT_THAT(m.GetOutput<float>(),
+ ElementsAreArray(
+ ArrayFloatNear({7.74592e+06, 1.197504e+08, 6.6889152e+08})));
+}
+
+TEST(DynamicFloatProdOpTest, NotKeepDims) {
+ std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+ 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
+ 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+ ProdOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}},
+ {TensorType_FLOAT32, {2}}, {TensorType_INT32, {4}},
+ false);
+ std::vector<int> axis = {1, 0, -3, -3};
+ m.SetAxis(axis);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+ EXPECT_THAT(
+ m.GetOutput<float>(),
+ ElementsAreArray(ArrayFloatNear({3.16234143225e+11, 1.9619905536e+12})));
+}
+
+TEST(DynamicFloatProdOpTest, KeepDims) {
+ std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+ 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
+ 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+ ProdOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}},
+ {TensorType_FLOAT32, {3}}, {TensorType_INT32, {2}},
+ true);
+ std::vector<int> axis = {0, 2};
+ m.SetAxis(axis);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1}));
+ EXPECT_THAT(m.GetOutput<float>(),
+ ElementsAreArray(
+ ArrayFloatNear({7.74592e+06, 1.197504e+08, 6.6889152e+08})));
+}
+
+TEST(DynamicFloatProdOpTest, Scale) {
+ std::vector<float> data = {9.527};
+ ProdOpDynamicModel m({TensorType_FLOAT32, {1}}, {TensorType_FLOAT32, {1}},
+ {TensorType_INT32, {1}}, true);
+ std::vector<int> axis = {0};
+ m.SetAxis(axis);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1}));
+ EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({9.527})));
+}
+
+// Tests for reduce_max
+
+TEST(ConstFloatMaxOpTest, NotKeepDims) {
+ std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+ 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
+ 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+ MaxOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {2}},
+ {4}, {1, 0, -3, -3}, false);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+ EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({23, 24})));
+}
+
+TEST(ConstFloatMaxOpTest, KeepDims) {
+ std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+ 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
+ 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+ MaxOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {3}},
+ {2}, {0, 2}, true);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1}));
+ EXPECT_THAT(m.GetOutput<float>(),
+ ElementsAreArray(ArrayFloatNear({20, 22, 24})));
+}
+
+TEST(DynamicFloatMaxOpTest, NotKeepDims) {
+ std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+ 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
+ 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+ MaxOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}},
+ {TensorType_FLOAT32, {2}}, {TensorType_INT32, {4}},
+ false);
+ std::vector<int> axis = {1, 0, -3, -3};
+ m.SetAxis(axis);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+ EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({23, 24})));
+}
+
+TEST(DynamicFloatMaxOpTest, KeepDims) {
+ std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+ 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
+ 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+ MaxOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}},
+ {TensorType_FLOAT32, {3}}, {TensorType_INT32, {2}}, true);
+ std::vector<int> axis = {0, 2};
+ m.SetAxis(axis);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1}));
+ EXPECT_THAT(m.GetOutput<float>(),
+ ElementsAreArray(ArrayFloatNear({20, 22, 24})));
+}
+
+TEST(DynamicFloatMaxOpTest, Scale) {
+ std::vector<float> data = {9.527};
+ MaxOpDynamicModel m({TensorType_FLOAT32, {1}}, {TensorType_FLOAT32, {1}},
+ {TensorType_INT32, {1}}, true);
+ std::vector<int> axis = {0};
+ m.SetAxis(axis);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1}));
+ EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({9.527})));
+}
+
+TEST(ConstUint8MaxOpTest, NotKeepDims) {
+ float kQuantizedTolerance = GetTolerance(-1.0, 1.0);
+ std::vector<float> data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6};
+ MaxOpConstModel m({TensorType_UINT8, {1, 3, 2}, -1.0, 1.0},
+ {TensorType_UINT8, {2}, -1.0, 1.0}, {1}, {1}, false);
+ m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2}));
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(
+ ArrayFloatNear({0.501961, 0.603922}, kQuantizedTolerance)));
+}
+
+TEST(ConstUint8MaxOpTest, KeepDims) {
+ float kQuantizedTolerance = GetTolerance(-1.0, 1.0);
+ std::vector<float> data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6};
+ MaxOpConstModel m({TensorType_UINT8, {3, 2}, -1.0, 1.0},
+ {TensorType_UINT8, {3}, -1.0, 1.0}, {1}, {1}, true);
+ m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1}));
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(
+ ArrayFloatNear({0.4, 0.4, 0.603922}, kQuantizedTolerance)));
+}
+
+TEST(DynamicUint8MaxOpTest, NotKeepDims) {
+ float kQuantizedTolerance = GetTolerance(-5.0, 2.0);
+ std::vector<float> data = {1.3, -4.8, -3.6, 0.24};
+ MaxOpDynamicModel m({TensorType_UINT8, {2, 2}, -5.0, 2.0},
+ {TensorType_UINT8, {2}, -5.0, 2.0},
+ {TensorType_INT32, {1}}, false);
+ std::vector<int> axis = {1};
+ m.SetAxis(axis);
+ m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(
+ ArrayFloatNear({1.2902, 0.247059}, kQuantizedTolerance)));
+}
+
+TEST(DynamicUint8MaxOpTest, KeepDims) {
+ float kQuantizedTolerance = GetTolerance(-10.0, 12.0);
+ std::vector<float> data = {11.14, -0.14, 7.423, 0.879};
+ MaxOpDynamicModel m({TensorType_UINT8, {2, 2}, -10.0, 12.0},
+ {TensorType_UINT8, {2}, -10.0, 12.0},
+ {TensorType_INT32, {1}}, true);
+ std::vector<int> axis = {0};
+ m.SetAxis(axis);
+ m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2}));
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(
+ ArrayFloatNear({11.1294, 0.862745}, kQuantizedTolerance)));
+}
+
+TEST(DynamicUint8MaxOpTest, Scalar) {
+ float kQuantizedTolerance = GetTolerance(-10.0, 12.0);
+ std::vector<float> data = {11.14};
+ MaxOpDynamicModel m({TensorType_UINT8, {}, -10.0, 12.0},
+ {TensorType_UINT8, {}, -10.0, 12.0},
+ {TensorType_INT32, {1}}, true);
+ std::vector<int> axis = {0};
+ m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), IsEmpty());
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear({11.1294}, kQuantizedTolerance)));
+}
+
} // namespace
} // namespace tflite
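As a sanity check on the new reduce_prod expectations, the expected outputs are simply the products of the inputs along the reduced axes. A standalone sketch (plain C++, outside the test harness) that reproduces the NotKeepDims numbers:

#include <cstdio>
#include <vector>

int main() {
  // The 4x3x2 input used by the Prod tests is 1..24 in row-major order, so
  // element i sits at last-dimension index i % 2. Reducing over axes 0 and 1
  // (the {1, 0, -3, -3} axis list) therefore leaves two products.
  std::vector<double> data(24);
  for (int i = 0; i < 24; ++i) data[i] = i + 1;

  double prod[2] = {1.0, 1.0};
  for (int i = 0; i < 24; ++i) prod[i % 2] *= data[i];

  // Prints 316234143225 and 1961990553600; compared as float32 these match
  // the ArrayFloatNear expectations in the NotKeepDims Prod tests above.
  std::printf("%.0f %.0f\n", prod[0], prod[1]);
  return 0;
}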
diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc
index 1994e85ce3..0b70bed308 100644
--- a/tensorflow/contrib/lite/kernels/register.cc
+++ b/tensorflow/contrib/lite/kernels/register.cc
@@ -91,6 +91,8 @@ TfLiteRegistration* Register_FLOOR();
TfLiteRegistration* Register_TILE();
TfLiteRegistration* Register_NEG();
TfLiteRegistration* Register_SUM();
+TfLiteRegistration* Register_REDUCE_PROD();
+TfLiteRegistration* Register_REDUCE_MAX();
TfLiteRegistration* Register_SELECT();
TfLiteRegistration* Register_SLICE();
TfLiteRegistration* Register_SIN();
@@ -104,6 +106,7 @@ TfLiteRegistration* Register_RSQRT();
TfLiteRegistration* Register_SHAPE();
TfLiteRegistration* Register_POW();
TfLiteRegistration* Register_FAKE_QUANT();
+TfLiteRegistration* Register_PACK();
BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_RELU, Register_RELU());
@@ -182,6 +185,8 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSE_CONV());
AddBuiltin(BuiltinOperator_TILE, Register_TILE());
AddBuiltin(BuiltinOperator_SUM, Register_SUM());
+ AddBuiltin(BuiltinOperator_REDUCE_PROD, Register_REDUCE_PROD());
+ AddBuiltin(BuiltinOperator_REDUCE_MAX, Register_REDUCE_MAX());
AddBuiltin(BuiltinOperator_EXPAND_DIMS, Register_EXPAND_DIMS());
AddBuiltin(BuiltinOperator_SPARSE_TO_DENSE, Register_SPARSE_TO_DENSE());
AddBuiltin(BuiltinOperator_EQUAL, Register_EQUAL());
@@ -190,7 +195,8 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_RSQRT, Register_RSQRT());
AddBuiltin(BuiltinOperator_SHAPE, Register_SHAPE());
AddBuiltin(BuiltinOperator_POW, Register_POW());
- AddBuiltin(BuiltinOperator_FAKE_QUANT, Register_FAKE_QUANT());
+ AddBuiltin(BuiltinOperator_FAKE_QUANT, Register_FAKE_QUANT(), 1, 2);
+ AddBuiltin(BuiltinOperator_PACK, Register_PACK());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.
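With the registrations above in place, the new REDUCE_PROD, REDUCE_MAX and PACK builtins resolve through the usual interpreter construction path. A minimal usage sketch, assuming the contrib-era headers and a hypothetical model file; the helper name is illustrative only:

#include <memory>

#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"

// Builds an interpreter whose resolver knows about the newly added builtins.
std::unique_ptr<tflite::Interpreter> BuildInterpreter(const char* model_path) {
  auto model = tflite::FlatBufferModel::BuildFromFile(model_path);
  if (!model) return nullptr;
  tflite::ops::builtin::BuiltinOpResolver resolver;
  std::unique_ptr<tflite::Interpreter> interpreter;
  if (tflite::InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) {
    return nullptr;
  }
  return interpreter;
}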
diff --git a/tensorflow/contrib/lite/kernels/reshape.cc b/tensorflow/contrib/lite/kernels/reshape.cc
index 3287040695..99ecc16093 100644
--- a/tensorflow/contrib/lite/kernels/reshape.cc
+++ b/tensorflow/contrib/lite/kernels/reshape.cc
@@ -25,16 +25,11 @@ namespace builtin {
namespace reshape {
constexpr int kInputTensor = 0;
+constexpr int kShapeTensor = 1;
constexpr int kOutputTensor = 0;
-TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
- auto* params = reinterpret_cast<TfLiteReshapeParams*>(node->builtin_data);
-
- // TODO(ahentz): we are often given a tensor with the shape but we only pay
- // attention to what the shape specified in 'params'.
- TF_LITE_ENSURE(context, NumInputs(node) == 1 || NumInputs(node) == 2);
- TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-
+TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node,
+ TfLiteIntArray* output_shape) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
@@ -47,32 +42,76 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
num_input_elements *= SizeOfDimension(input, i);
}
- TfLiteIntArray* output_size = TfLiteIntArrayCreate(params->num_dimensions);
int num_output_elements = 1;
int stretch_dim = -1;
- for (int i = 0; i < params->num_dimensions; ++i) {
- int value = params->shape[i];
+ for (int i = 0; i < output_shape->size; ++i) {
+ int value = output_shape->data[i];
if (value == -1) {
TF_LITE_ENSURE_EQ(context, stretch_dim, -1);
stretch_dim = i;
} else {
num_output_elements *= value;
- output_size->data[i] = value;
}
}
if (stretch_dim != -1) {
- output_size->data[stretch_dim] = num_input_elements / num_output_elements;
- num_output_elements *= output_size->data[stretch_dim];
+ output_shape->data[stretch_dim] = num_input_elements / num_output_elements;
+ num_output_elements *= output_shape->data[stretch_dim];
}
TF_LITE_ENSURE_EQ(context, num_input_elements, num_output_elements);
- return context->ResizeTensor(context, output, output_size);
+ return context->ResizeTensor(context, output, output_shape);
+}
+
+TfLiteStatus ResizeOutputWithShapeTensor(TfLiteContext* context,
+ TfLiteNode* node) {
+ const TfLiteTensor* shape = GetInput(context, node, kShapeTensor);
+
+ TfLiteIntArray* output_shape = TfLiteIntArrayCreate(shape->dims->data[0]);
+ for (int i = 0; i < output_shape->size; ++i) {
+ output_shape->data[i] = shape->data.i32[i];
+ }
+ return ResizeOutput(context, node, output_shape);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLiteReshapeParams*>(node->builtin_data);
+
+ TF_LITE_ENSURE(context, NumInputs(node) == 1 || NumInputs(node) == 2);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ // Attempt to use shape tensor if it exists.
+ if (NumInputs(node) == 2) {
+ const TfLiteTensor* shape = GetInput(context, node, kShapeTensor);
+ // Check if the shape tensor is valid.
+ if (shape->dims->size == 1 && shape->type == kTfLiteInt32) {
+      // Set the output tensor as dynamic if the shape isn't constant.
+ if (!IsConstantTensor(shape)) {
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ SetTensorToDynamic(output);
+ return kTfLiteOk;
+ }
+ // Shape is constant. Resize now.
+ return ResizeOutputWithShapeTensor(context, node);
+ }
+ }
+  // If the shape tensor was usable, the function has already returned above.
+  // Otherwise, fall back to the shape parameter in `TfLiteReshapeParams`.
+
+ TfLiteIntArray* output_shape = TfLiteIntArrayCreate(params->num_dimensions);
+ for (int i = 0; i < params->num_dimensions; ++i) {
+ output_shape->data[i] = params->shape[i];
+ }
+ return ResizeOutput(context, node, output_shape);
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ if (IsDynamicTensor(output)) {
+ TF_LITE_ENSURE_OK(context, ResizeOutputWithShapeTensor(context, node));
+ }
+
memcpy(output->data.raw, input->data.raw, input->bytes);
return kTfLiteOk;
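The stretch-dimension handling in ResizeOutput above follows the usual reshape rule: at most one dimension may be -1, and it absorbs whatever element count the explicit dimensions leave over. A small standalone sketch of the same arithmetic (InferShape is a hypothetical helper for illustration, not kernel code):

#include <cassert>
#include <vector>

// Infers a -1 ("stretch") dimension the way reshape::ResizeOutput does:
// only one -1 is permitted, and it is set to the remaining element count.
std::vector<int> InferShape(int num_input_elements, std::vector<int> shape) {
  int num_output_elements = 1;
  int stretch_dim = -1;
  for (int i = 0; i < static_cast<int>(shape.size()); ++i) {
    if (shape[i] == -1) {
      assert(stretch_dim == -1);  // at most one -1 allowed
      stretch_dim = i;
    } else {
      num_output_elements *= shape[i];
    }
  }
  if (stretch_dim != -1) {
    shape[stretch_dim] = num_input_elements / num_output_elements;
    num_output_elements *= shape[stretch_dim];
  }
  assert(num_input_elements == num_output_elements);
  return shape;
}

// e.g. InferShape(24, {4, -1, 2}) yields {4, 3, 2}.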
diff --git a/tensorflow/contrib/lite/kernels/sparse_to_dense.cc b/tensorflow/contrib/lite/kernels/sparse_to_dense.cc
index 404c32ad9c..7be5e66c16 100644
--- a/tensorflow/contrib/lite/kernels/sparse_to_dense.cc
+++ b/tensorflow/contrib/lite/kernels/sparse_to_dense.cc
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <unistd.h>
#include <cassert>
#include <cmath>
#include <cstdio>
diff --git a/tensorflow/contrib/lite/kernels/sub.cc b/tensorflow/contrib/lite/kernels/sub.cc
index 1247525d41..77a1f59689 100644
--- a/tensorflow/contrib/lite/kernels/sub.cc
+++ b/tensorflow/contrib/lite/kernels/sub.cc
@@ -78,29 +78,47 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
template <KernelType kernel_type>
-void EvalFloat(TfLiteContext* context, TfLiteNode* node,
- TfLiteSubParams* params, const OpData* data,
- const TfLiteTensor* input1, const TfLiteTensor* input2,
- TfLiteTensor* output) {
- float output_activation_min, output_activation_max;
- CalculateActivationRange(params->activation, &output_activation_min,
- &output_activation_max);
-#define TF_LITE_SUB(type, opname) \
- type::opname(GetTensorData<float>(input1), GetTensorDims(input1), \
- GetTensorData<float>(input2), GetTensorDims(input2), \
- output_activation_min, output_activation_max, \
- GetTensorData<float>(output), GetTensorDims(output))
- if (kernel_type == kReference) {
- if (data->requires_broadcast) {
- TF_LITE_SUB(reference_ops, BroadcastSub);
+void EvalSub(TfLiteContext* context, TfLiteNode* node, TfLiteSubParams* params,
+ const OpData* data, const TfLiteTensor* input1,
+ const TfLiteTensor* input2, TfLiteTensor* output) {
+#define TF_LITE_SUB(type, opname, data_type) \
+ data_type output_activation_min, output_activation_max; \
+ CalculateActivationRange(params->activation, &output_activation_min, \
+ &output_activation_max); \
+ tflite::ArithmeticParams op_params; \
+ SetActivationParams(output_activation_min, output_activation_max, \
+ &op_params); \
+ type::opname(op_params, GetTensorShape(input1), \
+ GetTensorData<data_type>(input1), GetTensorShape(input2), \
+ GetTensorData<data_type>(input2), GetTensorShape(output), \
+ GetTensorData<data_type>(output))
+ if (output->type == kTfLiteInt32) {
+ if (kernel_type == kReference) {
+ if (data->requires_broadcast) {
+ TF_LITE_SUB(reference_ops, BroadcastSub4DSlow, int32_t);
+ } else {
+ TF_LITE_SUB(reference_ops, SubWithActivation, int32_t);
+ }
} else {
- TF_LITE_SUB(reference_ops, Sub);
+ if (data->requires_broadcast) {
+ TF_LITE_SUB(optimized_ops, BroadcastSub4DSlow, int32_t);
+ } else {
+ TF_LITE_SUB(optimized_ops, SubWithActivation, int32_t);
+ }
}
- } else {
- if (data->requires_broadcast) {
- TF_LITE_SUB(optimized_ops, BroadcastSub);
+ } else if (output->type == kTfLiteFloat32) {
+ if (kernel_type == kReference) {
+ if (data->requires_broadcast) {
+ TF_LITE_SUB(reference_ops, BroadcastSub4DSlow, float);
+ } else {
+ TF_LITE_SUB(reference_ops, SubWithActivation, float);
+ }
} else {
- TF_LITE_SUB(optimized_ops, Sub);
+ if (data->requires_broadcast) {
+ TF_LITE_SUB(optimized_ops, BroadcastSub4DSlow, float);
+ } else {
+ TF_LITE_SUB(optimized_ops, SubWithActivation, float);
+ }
}
}
#undef TF_LITE_SUB
@@ -128,36 +146,43 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
int input1_shift;
QuantizeMultiplierSmallerThanOneExp(real_input1_multiplier,
&input1_multiplier, &input1_shift);
- input1_shift *= -1;
int32 input2_multiplier;
int input2_shift;
QuantizeMultiplierSmallerThanOneExp(real_input2_multiplier,
&input2_multiplier, &input2_shift);
- input2_shift *= -1;
int32 output_multiplier;
int output_shift;
QuantizeMultiplierSmallerThanOneExp(real_output_multiplier,
&output_multiplier, &output_shift);
- output_shift *= -1;
int32 output_activation_min, output_activation_max;
CalculateActivationRangeUint8(params->activation, output,
&output_activation_min, &output_activation_max);
-#define TF_LITE_SUB(type, opname) \
- type::opname(left_shift, GetTensorData<uint8_t>(input1), \
- GetTensorDims(input1), input1_offset, input1_multiplier, \
- input1_shift, GetTensorData<uint8_t>(input2), \
- GetTensorDims(input2), input2_offset, input2_multiplier, \
- input2_shift, output_offset, output_multiplier, output_shift, \
- output_activation_min, output_activation_max, \
- GetTensorData<uint8_t>(output), GetTensorDims(output));
+#define TF_LITE_SUB(type, opname) \
+ tflite::ArithmeticParams op_params; \
+ op_params.left_shift = left_shift; \
+ op_params.input1_offset = input1_offset; \
+ op_params.input1_multiplier = input1_multiplier; \
+ op_params.input1_shift = input1_shift; \
+ op_params.input2_offset = input2_offset; \
+ op_params.input2_multiplier = input2_multiplier; \
+ op_params.input2_shift = input2_shift; \
+ op_params.output_offset = output_offset; \
+ op_params.output_multiplier = output_multiplier; \
+ op_params.output_shift = output_shift; \
+ SetActivationParams(output_activation_min, output_activation_max, \
+ &op_params); \
+ type::opname(op_params, GetTensorShape(input1), \
+ GetTensorData<uint8_t>(input1), GetTensorShape(input2), \
+ GetTensorData<uint8_t>(input2), GetTensorShape(output), \
+ GetTensorData<uint8_t>(output))
// The quantized version of Sub doesn't support activations, so we
// always use BroadcastSub.
if (kernel_type == kReference) {
- TF_LITE_SUB(reference_ops, BroadcastSub);
+ TF_LITE_SUB(reference_ops, BroadcastSub4DSlow);
} else {
- TF_LITE_SUB(optimized_ops, BroadcastSub);
+ TF_LITE_SUB(optimized_ops, BroadcastSub4DSlow);
}
#undef TF_LITE_SUB
}
@@ -171,14 +196,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- if (output->type == kTfLiteFloat32) {
- EvalFloat<kernel_type>(context, node, params, data, input1, input2, output);
+ if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) {
+ EvalSub<kernel_type>(context, node, params, data, input1, input2, output);
} else if (output->type == kTfLiteUInt8) {
EvalQuantized<kernel_type>(context, node, params, data, input1, input2,
output);
} else {
context->ReportError(
- context, "output type %d is not supported, requires float|uint8 types.",
+ context,
+ "output type %d is not supported, requires float|uint8|int32 types.",
output->type);
return kTfLiteError;
}
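For reference, the SubWithActivation variants dispatched above implement plain elementwise subtraction clamped to the activation range produced by CalculateActivationRange (for example [-1, 1] for RELU_N1_TO_1). A type-generic sketch of those semantics only; the real kernels additionally handle broadcasting and, for uint8, the quantized rescaling set up in EvalQuantized:

#include <algorithm>
#include <cstddef>

// out[i] = clamp(a[i] - b[i], act_min, act_max), for float or int32 inputs.
template <typename T>
void SubWithActivationSketch(const T* a, const T* b, size_t n, T act_min,
                             T act_max, T* out) {
  for (size_t i = 0; i < n; ++i) {
    out[i] = std::min(std::max(a[i] - b[i], act_min), act_max);
  }
}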
diff --git a/tensorflow/contrib/lite/kernels/sub_test.cc b/tensorflow/contrib/lite/kernels/sub_test.cc
index ff07aeec49..5978c574d3 100644
--- a/tensorflow/contrib/lite/kernels/sub_test.cc
+++ b/tensorflow/contrib/lite/kernels/sub_test.cc
@@ -52,6 +52,13 @@ class FloatSubOpModel : public BaseSubOpModel {
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
};
+class IntegerSubOpModel : public BaseSubOpModel {
+ public:
+ using BaseSubOpModel::BaseSubOpModel;
+
+ std::vector<int32_t> GetOutput() { return ExtractVector<int32_t>(output_); }
+};
+
class QuantizedSubOpModel : public BaseSubOpModel {
public:
using BaseSubOpModel::BaseSubOpModel;
@@ -125,6 +132,57 @@ TEST(FloatSubOpModel, WithBroadcast) {
}
}
+TEST(IntegerSubOpModel, NoActivation) {
+ IntegerSubOpModel m({TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {1, 2, 2, 1}}, {TensorType_INT32, {}},
+ ActivationFunctionType_NONE);
+ m.PopulateTensor<int32_t>(m.input1(), {-20, 2, 7, 8});
+ m.PopulateTensor<int32_t>(m.input2(), {1, 2, 3, 5});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({-21, 0, 4, 3}));
+}
+
+TEST(IntegerSubOpModel, ActivationRELU_N1_TO_1) {
+ IntegerSubOpModel m({TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {1, 2, 2, 1}}, {TensorType_INT32, {}},
+ ActivationFunctionType_RELU_N1_TO_1);
+ m.PopulateTensor<int32_t>(m.input1(), {-20, 2, 7, 8});
+ m.PopulateTensor<int32_t>(m.input2(), {1, 2, 3, 5});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1, 0, 1, 1}));
+}
+
+TEST(IntegerSubOpModel, VariousInputShapes) {
+ std::vector<std::initializer_list<int>> test_shapes = {
+ {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
+ for (int i = 0; i < test_shapes.size(); ++i) {
+ IntegerSubOpModel m({TensorType_INT32, test_shapes[i]},
+ {TensorType_INT32, test_shapes[i]},
+ {TensorType_INT32, {}}, ActivationFunctionType_NONE);
+ m.PopulateTensor<int32_t>(m.input1(), {-20, 2, 7, 8, 11, 20});
+ m.PopulateTensor<int32_t>(m.input2(), {1, 2, 3, 5, 11, 1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({-21, 0, 4, 3, 0, 19}))
+ << "With shape number " << i;
+ }
+}
+
+TEST(IntegerSubOpModel, WithBroadcast) {
+ std::vector<std::initializer_list<int>> test_shapes = {
+ {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
+ for (int i = 0; i < test_shapes.size(); ++i) {
+ IntegerSubOpModel m({TensorType_INT32, test_shapes[i]},
+ {TensorType_INT32, {}}, // always a scalar
+ {TensorType_INT32, {}}, ActivationFunctionType_NONE);
+ m.PopulateTensor<int32_t>(m.input1(), {-20, 2, 7, 8, 11, 20});
+ m.PopulateTensor<int32_t>(m.input2(), {1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear({-21, 1, 6, 7, 10, 19})))
+ << "With shape number " << i;
+ }
+}
+
TEST(QuantizedSubOpModel, QuantizedTestsNoActivation) {
float kQuantizedTolerance = GetTolerance(-1.0, 1.0);
std::vector<std::initializer_list<float>> inputs1 = {
diff --git a/tensorflow/contrib/lite/kernels/svdf.cc b/tensorflow/contrib/lite/kernels/svdf.cc
index 22eebdd4ce..6d4912ce3a 100644
--- a/tensorflow/contrib/lite/kernels/svdf.cc
+++ b/tensorflow/contrib/lite/kernels/svdf.cc
@@ -16,7 +16,6 @@ limitations under the License.
// SVDF op that compresses a fully connected op via low-rank matrix
// factorization. See https://research.google.com/pubs/archive/43813.pdf for
// details.
-#include <unistd.h>
#include <cassert>
#include <cmath>
#include <cstdio>
@@ -105,7 +104,7 @@ constexpr int kStateTensor = 0;
constexpr int kOutputTensor = 1;
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
- auto* op_data = new OpData;
+ auto* op_data = new OpData();
op_data->float_weights_time_initialized = false;
context->AddTensors(context, /*tensors_to_add=*/4,
&op_data->scratch_tensor_index);
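The `new OpData()` change above swaps default-initialization for value-initialization: members of OpData that lack an in-class initializer are zero-initialized instead of being left indeterminate. A tiny illustration with a hypothetical struct (not the actual OpData):

struct Example {
  int counter;  // no in-class initializer
  bool ready;   // no in-class initializer
};

Example* MakeDefault() { return new Example; }      // members indeterminate
Example* MakeValueInit() { return new Example(); }  // members zeroed (0, false)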
diff --git a/tensorflow/contrib/lite/kernels/transpose_conv.cc b/tensorflow/contrib/lite/kernels/transpose_conv.cc
index 7182374a6f..a9baa5c698 100644
--- a/tensorflow/contrib/lite/kernels/transpose_conv.cc
+++ b/tensorflow/contrib/lite/kernels/transpose_conv.cc
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <unistd.h>
#include <cassert>
#include <cmath>
#include <cstdio>
@@ -22,7 +21,6 @@ limitations under the License.
#include "tensorflow/contrib/lite/builtin_op_data.h"
#include "tensorflow/contrib/lite/context.h"
-#include "tensorflow/contrib/lite/kernels/eigen_support.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
@@ -39,35 +37,9 @@ constexpr int kWeightsTensor = 1;
constexpr int kDataInputTensor = 2;
constexpr int kOutputTensor = 0;
-const int kTensorNotAllocated = -1;
-
-struct OpData {
- // IDs are the arbitrary identifiers used by TF Lite to identify and access
- // memory buffers.
- int im2col_id = kTensorNotAllocated;
-
- // im2col is the only temporary currently tracked, therefore always index 0.
- // If more temporaries are added, they should be properly tracked.
- int32_t im2col_index = 0;
-};
-
-void* Init(TfLiteContext* context, const char* buffer, size_t length) {
- // This is a builtin op, so we don't use the contents in 'buffer', if any.
- // Instead, we allocate a new object to use as scratch space for im2col, and
- // to carry information from Prepare() to Eval().
- auto* data = new OpData;
- eigen_support::IncrementUsageCounter(context);
- return data;
-}
-
-void Free(TfLiteContext* context, void* buffer) {
- eigen_support::DecrementUsageCounter(context);
- delete reinterpret_cast<OpData*>(buffer);
-}
-
-TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
- const TfLiteTensor* output_shape,
- TfLiteTensor* output) {
+TfLiteStatus ResizeOutputShape(TfLiteContext* context,
+ const TfLiteTensor* output_shape,
+ TfLiteTensor* output) {
// Currently only support int32 for output shape.
if (output_shape->type != kTfLiteInt32) {
context->ReportError(context, "Output shape is %d, not int32.",
@@ -83,60 +55,15 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
return context->ResizeTensor(context, output, output_shape_array);
}
-// Allocate temporary im2col tensor.
-static TfLiteStatus AllocateIm2colTensor(TfLiteContext* context,
- TfLiteNode* node) {
- OpData* data = reinterpret_cast<OpData*>(node->user_data);
- if (data->im2col_id == kTensorNotAllocated) {
- context->AddTensors(context, 1, &data->im2col_id);
- }
-
- TfLiteIntArrayFree(node->temporaries);
- node->temporaries = TfLiteIntArrayCreate(1);
- node->temporaries->data[data->im2col_index] = data->im2col_id;
-
- return kTfLiteOk;
-}
-
-TfLiteStatus ResizeIm2ColTensor(TfLiteContext* context,
- const TfLiteTensor* output_shape,
- const TfLiteTensor* weights,
- const TfLiteTensor* input,
- TfLiteTensor* im2col) {
- if (output_shape->type != kTfLiteInt32) {
- context->ReportError(context, "im2col shape is %d, not int32.",
- output_shape->type);
- return kTfLiteError;
- }
- TF_LITE_ENSURE_EQ(context, NumElements(output_shape), 4);
- TfLiteIntArray* im2col_shape_array = TfLiteIntArrayCreate(4);
- im2col_shape_array->data[0] = output_shape->data.i32[0];
- im2col_shape_array->data[1] = output_shape->data.i32[1];
- im2col_shape_array->data[2] = output_shape->data.i32[2];
- const int input_depth = SizeOfDimension(input, 3);
- const int filter_width = SizeOfDimension(weights, 1);
- const int filter_height = SizeOfDimension(weights, 2);
- im2col_shape_array->data[3] = input_depth * filter_height * filter_width;
-
- im2col->type = input->type;
- im2col->allocation_type = kTfLiteArenaRw;
- return context->ResizeTensor(context, im2col, im2col_shape_array);
-}
-
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
- TF_LITE_ENSURE_STATUS(AllocateIm2colTensor(context, node));
-
const TfLiteTensor* output_shape =
GetInput(context, node, kOutputShapeTensor);
const TfLiteTensor* weights = GetInput(context, node, kWeightsTensor);
const TfLiteTensor* input = GetInput(context, node, kDataInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- OpData* user_data = reinterpret_cast<OpData*>(node->user_data);
- TfLiteTensor* im2col =
- &context->tensors[node->temporaries->data[user_data->im2col_index]];
TF_LITE_ENSURE_EQ(context, NumDimensions(output_shape), 1);
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
@@ -153,15 +80,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, SizeOfDimension(input, 3),
SizeOfDimension(weights, 3));
- if (IsConstantTensor(output_shape)) {
- TF_LITE_ENSURE_STATUS(ResizeOutputTensor(context, output_shape, output));
- TF_LITE_ENSURE_STATUS(
- ResizeIm2ColTensor(context, output_shape, weights, input, im2col));
- } else {
- // Defer resizing until Eval().
+ if (!IsConstantTensor(output_shape)) {
SetTensorToDynamic(output);
+ return kTfLiteOk;
}
- return kTfLiteOk;
+ return ResizeOutputShape(context, output_shape, output);
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
@@ -170,19 +93,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* weights = GetInput(context, node, kWeightsTensor);
const TfLiteTensor* input = GetInput(context, node, kDataInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- OpData* user_data = reinterpret_cast<OpData*>(node->user_data);
- TfLiteTensor* im2col =
- &context->tensors[node->temporaries->data[user_data->im2col_index]];
+
const auto* params =
reinterpret_cast<TfLiteTransposeConvParams*>(node->builtin_data);
if (IsDynamicTensor(output)) {
TF_LITE_ENSURE_OK(context,
- ResizeOutputTensor(context, output_shape, output));
- }
- if (IsDynamicTensor(im2col)) {
- TF_LITE_ENSURE_OK(context, ResizeIm2ColTensor(context, output_shape,
- weights, input, im2col));
+ ResizeOutputShape(context, output_shape, output));
}
// Get height and width of the output image.
@@ -201,12 +118,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// Currently only support float32.
switch (input->type) {
case kTfLiteFloat32:
- optimized_ops::TransposeConv(
+ reference_ops::TransposeConv(
GetTensorData<float>(input), GetTensorDims(input),
GetTensorData<float>(weights), GetTensorDims(weights), stride_width,
stride_height, padding_size.width, padding_size.height,
GetTensorData<float>(output), GetTensorDims(output),
- GetTensorData<float>(im2col), GetTensorDims(im2col));
+ // Last two args specify im2col which reference_ops ignores.
+ // (Note this does not lead to a performance regression, as the
+ // previous optimized version was just a copy of the reference code.)
+ // TODO(b/110208176): Allocate im2col tensors and switch to
+ // optimized_ops.
+ GetTensorData<float>(output), GetTensorDims(output));
break;
default:
context->ReportError(context, "Type %d, not currently supported.",
@@ -219,8 +141,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace transpose_conv
TfLiteRegistration* Register_TRANSPOSE_CONV() {
- static TfLiteRegistration r = {transpose_conv::Init, transpose_conv::Free,
- transpose_conv::Prepare, transpose_conv::Eval};
+ static TfLiteRegistration r = {nullptr, nullptr, transpose_conv::Prepare,
+ transpose_conv::Eval};
return &r;
}
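The reference TransposeConv call above follows the standard scatter-add formulation: each input element adds a scaled copy of the filter into a stride-spaced window of the output. A minimal single-channel, VALID-padding sketch of that idea (illustrative only, not the library kernel); with a 2x2 input, a 3x3 filter and stride 2 it produces the 5x5 output layout exercised by StrideValidTest below:

#include <vector>

// Single-channel transpose convolution, VALID padding. Output size per
// spatial dimension is (in - 1) * stride + filter.
std::vector<float> TransposeConv2dSketch(const std::vector<float>& in, int in_h,
                                         int in_w,
                                         const std::vector<float>& filter,
                                         int f_h, int f_w, int stride) {
  const int out_h = (in_h - 1) * stride + f_h;
  const int out_w = (in_w - 1) * stride + f_w;
  std::vector<float> out(out_h * out_w, 0.0f);
  for (int iy = 0; iy < in_h; ++iy) {
    for (int ix = 0; ix < in_w; ++ix) {
      for (int fy = 0; fy < f_h; ++fy) {
        for (int fx = 0; fx < f_w; ++fx) {
          // Scatter input(iy, ix) * filter(fy, fx) into the output window.
          out[(iy * stride + fy) * out_w + (ix * stride + fx)] +=
              in[iy * in_w + ix] * filter[fy * f_w + fx];
        }
      }
    }
  }
  return out;
}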
diff --git a/tensorflow/contrib/lite/kernels/transpose_conv_test.cc b/tensorflow/contrib/lite/kernels/transpose_conv_test.cc
index c741df19de..55df897180 100644
--- a/tensorflow/contrib/lite/kernels/transpose_conv_test.cc
+++ b/tensorflow/contrib/lite/kernels/transpose_conv_test.cc
@@ -14,7 +14,6 @@ limitations under the License.
==============================================================================*/
#include <cstdarg>
#include <gtest/gtest.h>
-#include "absl/memory/memory.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
@@ -25,49 +24,9 @@ namespace {
using ::testing::ElementsAreArray;
-class ConstTransposeConvOpModel : public SingleOpModel {
- // Just to be extra confusing, transpose_conv has an _input_ named
- // "output_shape". This input sets the shape of the output tensor of the op.
- // In this version of the test class, "output_shape" is a constant that must
- // be specified in the constructor.
- public:
- ConstTransposeConvOpModel(TfLiteRegistration* registration,
- std::initializer_list<int> input_shape,
- std::initializer_list<int> filter_shape,
- std::initializer_list<int> output_shape_data,
- Padding padding, int stride_w, int stride_h) {
- output_shape_ = AddConstInput(TensorType_INT32, output_shape_data,
- {static_cast<int>(output_shape_data.size())});
- filter_ = AddInput(TensorType_FLOAT32);
- input_ = AddInput(TensorType_FLOAT32);
- output_ = AddOutput(TensorType_FLOAT32);
- SetBuiltinOp(
- BuiltinOperator_TRANSPOSE_CONV, BuiltinOptions_TransposeConvOptions,
- CreateTransposeConvOptions(builder_, padding, stride_w, stride_h)
- .Union());
- resolver_ = absl::make_unique<SingleOpResolver>(
- BuiltinOperator_TRANSPOSE_CONV, registration);
- BuildInterpreter({{4}, filter_shape, input_shape});
- }
-
- int output_shape() { return output_shape_; }
- int filter() { return filter_; }
- int input() { return input_; }
-
- std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
- std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
-
- private:
- int output_shape_;
- int filter_;
- int input_;
- int output_;
-};
-
class TransposeConvOpModel : public SingleOpModel {
public:
- TransposeConvOpModel(TfLiteRegistration* registration,
- std::initializer_list<int> input_shape,
+ TransposeConvOpModel(std::initializer_list<int> input_shape,
std::initializer_list<int> filter_shape, Padding padding,
int stride_w, int stride_h) {
output_shape_ = AddInput(TensorType_INT32);
@@ -78,8 +37,6 @@ class TransposeConvOpModel : public SingleOpModel {
BuiltinOperator_TRANSPOSE_CONV, BuiltinOptions_TransposeConvOptions,
CreateTransposeConvOptions(builder_, padding, stride_w, stride_h)
.Union());
- resolver_ = absl::make_unique<SingleOpResolver>(
- BuiltinOperator_TRANSPOSE_CONV, registration);
BuildInterpreter({{4}, filter_shape, input_shape});
}
@@ -97,15 +54,6 @@ class TransposeConvOpModel : public SingleOpModel {
int output_;
};
-const auto kKernelMap = new std::map<string, TfLiteRegistration*>({});
-
-class TransposeConvOpTest : public SingleOpTest {
- protected:
- const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
- return *kKernelMap;
- }
-};
-
// Test case:
// output = tf.nn.conv2d_backprop_input(
// tf.constant([ 1, 4, 4, 1 ]),
@@ -113,9 +61,8 @@ class TransposeConvOpTest : public SingleOpTest {
// tf.constant(np.arange(1, 17), shape=[ 1, 4, 4, 1 ], dtype=tf.float32),
// [1, 1, 1, 1 ],
// "SAME")
-TEST_P(TransposeConvOpTest, SimpleTest) {
- TransposeConvOpModel m(GetRegistration(), {1, 4, 4, 1}, {1, 3, 3, 1},
- Padding_SAME, 1, 1);
+TEST(TransposeConvOpModelTest, SimpleTest) {
+ TransposeConvOpModel m({1, 4, 4, 1}, {1, 3, 3, 1}, Padding_SAME, 1, 1);
m.PopulateTensor<int>(m.output_shape(), {1, 4, 4, 1});
m.PopulateTensor<float>(m.filter(), {1, 2, 3, 4, 5, 6, 7, 8, 9});
m.PopulateTensor<float>(
@@ -128,21 +75,6 @@ TEST_P(TransposeConvOpTest, SimpleTest) {
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
}
-// Test case: Same as above, but with a const "output_shape"
-TEST_P(TransposeConvOpTest, ConstSimpleTest) {
- ConstTransposeConvOpModel m(GetRegistration(), {1, 4, 4, 1}, {1, 4, 4, 1},
- {1, 3, 3, 1}, Padding_SAME, 1, 1);
- m.PopulateTensor<float>(m.filter(), {1, 2, 3, 4, 5, 6, 7, 8, 9});
- m.PopulateTensor<float>(
- m.input(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
- m.Invoke();
-
- EXPECT_THAT(m.GetOutput(),
- ElementsAreArray({29, 62, 83, 75, 99, 192, 237, 198, 207, 372,
- 417, 330, 263, 446, 485, 365}));
- EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
-}
-
// Test case:
// filter = tf.constant(np.arange(1, 19),
// shape=[ 3, 3, 1, 2 ],
@@ -155,9 +87,8 @@ TEST_P(TransposeConvOpTest, ConstSimpleTest) {
// "SAME")
// And filter value is derived by:
// filter = tf.reshape(tf.transpose(filter, perm=[3, 0, 1, 2]), shape=[18, 1])
-TEST_P(TransposeConvOpTest, TwoFiltersTest) {
- TransposeConvOpModel m(GetRegistration(), {1, 4, 4, 2}, {1, 3, 3, 2},
- Padding_SAME, 1, 1);
+TEST(TransposeConvOpModelTest, TwoFiltersTest) {
+ TransposeConvOpModel m({1, 4, 4, 2}, {1, 3, 3, 2}, Padding_SAME, 1, 1);
m.PopulateTensor<int>(m.output_shape(), {1, 4, 4, 1});
m.PopulateTensor<float>(m.filter(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18});
@@ -185,9 +116,8 @@ TEST_P(TransposeConvOpTest, TwoFiltersTest) {
// "VALID")
// And filter value is derived by:
// filter = tf.reshape(tf.transpose(filter, perm=[3, 0, 1, 2]), shape=[1, 18])
-TEST_P(TransposeConvOpTest, PaddingValidTest) {
- TransposeConvOpModel m(GetRegistration(), {1, 4, 4, 2}, {1, 3, 3, 2},
- Padding_VALID, 1, 1);
+TEST(TransposeConvOpModelTest, PaddingValidTest) {
+ TransposeConvOpModel m({1, 4, 4, 2}, {1, 3, 3, 2}, Padding_VALID, 1, 1);
m.PopulateTensor<int>(m.output_shape(), {1, 6, 6, 1});
m.PopulateTensor<float>(m.filter(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18});
@@ -216,9 +146,8 @@ TEST_P(TransposeConvOpTest, PaddingValidTest) {
// tf.constant(np.arange(1, 5), shape=[ 1, 2, 2, 1 ], dtype=tf.float32),
// [1, 2, 2, 1 ],
// "VALID")
-TEST_P(TransposeConvOpTest, StrideValidTest) {
- TransposeConvOpModel m(GetRegistration(), {1, 2, 2, 1}, {1, 3, 3, 1},
- Padding_VALID, 2, 2);
+TEST(TransposeConvOpModelTest, StrideValidTest) {
+ TransposeConvOpModel m({1, 2, 2, 1}, {1, 3, 3, 1}, Padding_VALID, 2, 2);
m.PopulateTensor<int>(m.output_shape(), {1, 5, 5, 1});
m.PopulateTensor<float>(m.filter(), {1, 2, 3, 4, 5, 6, 7, 8, 9});
m.PopulateTensor<float>(m.input(), {1, 2, 3, 4});
@@ -241,9 +170,8 @@ TEST_P(TransposeConvOpTest, StrideValidTest) {
// tf.constant(np.arange(1, 5), shape=[ 1, 2, 2, 1 ], dtype=tf.float32),
// [1, 2, 2, 1 ],
// "VALID")
-TEST_P(TransposeConvOpTest, MultiChannelTest) {
- TransposeConvOpModel m(GetRegistration(), {1, 2, 2, 1}, {2, 3, 3, 1},
- Padding_VALID, 2, 2);
+TEST(TransposeConvOpModelTest, MultiChannelTest) {
+ TransposeConvOpModel m({1, 2, 2, 1}, {2, 3, 3, 1}, Padding_VALID, 2, 2);
m.PopulateTensor<int>(m.output_shape(), {1, 5, 5, 2});
m.PopulateTensor<float>(m.filter(), {1, 3, 5, 7, 9, 11, 13, 15, 17, 2, 4, 6,
8, 10, 12, 14, 16, 18});
@@ -259,24 +187,6 @@ TEST_P(TransposeConvOpTest, MultiChannelTest) {
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 5, 5, 2}));
}
-// Test case: Same as above, but with a const "output_shape"
-TEST_P(TransposeConvOpTest, ConstMultiChannelTest) {
- ConstTransposeConvOpModel m(GetRegistration(), {1, 2, 2, 1}, {2, 3, 3, 1},
- {1, 5, 5, 2}, Padding_VALID, 2, 2);
- m.PopulateTensor<float>(m.filter(), {1, 3, 5, 7, 9, 11, 13, 15, 17, 2, 4, 6,
- 8, 10, 12, 14, 16, 18});
- m.PopulateTensor<float>(m.input(), {1, 2, 3, 4});
- m.Invoke();
-
- EXPECT_THAT(
- m.GetOutput(),
- ElementsAreArray({1, 2, 3, 4, 7, 10, 6, 8, 10, 12, 7, 8, 9,
- 10, 25, 28, 18, 20, 22, 24, 16, 20, 24, 28, 62, 72,
- 42, 48, 54, 60, 21, 24, 27, 30, 61, 68, 36, 40, 44,
- 48, 39, 42, 45, 48, 103, 110, 60, 64, 68, 72}));
- EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 5, 5, 2}));
-}
-
// Test case:
// filter = tf.constant(np.random.randint(1, 10, size=9),
// shape=[ 3, 3, 1, 1 ],
@@ -289,9 +199,8 @@ TEST_P(TransposeConvOpTest, ConstMultiChannelTest) {
// "SAME")
// And filter value is derived by:
// filter = tf.reshape(tf.transpose(filter, perm=[3, 0, 1, 2]), shape=[-1])
-TEST_P(TransposeConvOpTest, AccuracyTest) {
- TransposeConvOpModel m(GetRegistration(), {1, 1, 2, 1}, {1, 3, 3, 1},
- Padding_SAME, 3, 3);
+TEST(TransposeConvOpModelTest, AccuracyTest) {
+ TransposeConvOpModel m({1, 1, 2, 1}, {1, 3, 3, 1}, Padding_SAME, 3, 3);
m.PopulateTensor<int>(m.output_shape(), {1, 3, 4, 1});
m.PopulateTensor<float>(m.filter(), {9, 5, 6, 9, 8, 5, 3, 1, 4});
m.PopulateTensor<float>(m.input(), {323, 521});
@@ -303,10 +212,6 @@ TEST_P(TransposeConvOpTest, AccuracyTest) {
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 4, 1}));
}
-INSTANTIATE_TEST_CASE_P(
- TransposeConvOpTest, TransposeConvOpTest,
- ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap)));
-
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
index 32daf2bb02..0acd705950 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <unistd.h>
#include <cassert>
#include <cmath>
#include <cstdio>
@@ -274,7 +273,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const int n_output = recurrent_to_output_weights->dims->data[1];
// Check that input tensor dimensions matches with each other.
- CheckInputTensorDimensions(context, node, n_input, n_output, n_cell);
+ TF_LITE_ENSURE_OK(context, CheckInputTensorDimensions(context, node, n_input,
+ n_output, n_cell));
// Get the pointer to output, output_state and cell_state buffer tensors.
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
index 164a0cbd08..0d6d29a171 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <unistd.h>
#include <cassert>
#include <cmath>
#include <cstdio>